Commit ae2c668

Revert "[dynamo][api] Better support of torch.nn.Module (#88629)"
This reverts commit c833485. Reverted #88629 on behalf of https://github.com/anijain2305 due to a job failing on master: https://github.com/pytorch/pytorch/actions/runs/3449914495/jobs/5758267231
1 parent 6b775c4 commit ae2c668

File tree

test/dynamo/test_modules.py
torch/_dynamo/__init__.py
torch/_dynamo/debug_utils.py
torch/_dynamo/eval_frame.py
torch/_dynamo/testing.py

5 files changed: +20 -204 lines changed

test/dynamo/test_modules.py

Lines changed: 0 additions & 127 deletions

@@ -904,133 +904,6 @@ def forward(self, x):
         self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))


-class MockModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.relu = torch.nn.ReLU()
-        self.linear = torch.nn.Linear(10, 10)
-        self.register_buffer("buf0", torch.randn(10, 10))
-
-    def forward(self, x):
-        return self.relu(self.linear(x) + self.buf0)
-
-
-class OptimizedModuleTest(torch._dynamo.test_case.TestCase):
-    def test_nn_module(self):
-        mod = MockModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_mod = torch._dynamo.optimize(cnt)(mod)
-        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
-
-        x = torch.randn(10, 10)
-        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
-        self.assertEqual(cnt.frame_count, 1)
-
-    def test_to(self):
-        mod = MockModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_mod = torch._dynamo.optimize(cnt)(mod)
-        x = torch.randn(10, 10)
-        self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x)))
-        self.assertEqual(cnt.frame_count, 1)
-
-        # Ensure that there is no recompilation
-        opt_mod(x)
-        self.assertEqual(cnt.frame_count, 1)
-
-        opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64)
-        self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule)
-        x = torch.randn(10, 10).to(dtype=torch.float64)
-        opt_mod(x)
-        # Ensure that there is a recompilation
-        self.assertEqual(cnt.frame_count, 2)
-
-    def test_attr(self):
-        class MockModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.linear = torch.nn.Linear(10, 10)
-                self.register_buffer("buf0", torch.randn(10, 10))
-
-            def forward(self, x):
-                return self.r(torch.sin(x)) + self.buf0
-
-        mod = MockModule()
-        opt_mod = torch._dynamo.optimize("eager")(mod)
-
-        # Check parameteres and buffers
-        for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()):
-            self.assertTrue(id(p1) == id(p2))
-
-    def test_recursion(self):
-        mod = MockModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_mod = torch._dynamo.optimize(cnt)(mod)
-
-        for _ in range(5):
-            opt_mod = torch._dynamo.optimize(cnt)(opt_mod)
-        opt_mod(torch.randn(10, 10))
-        self.assertEqual(cnt.frame_count, 1)
-
-    def test_composition(self):
-        class InnerModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                return self.relu(torch.sin(x))
-
-        opt_inner_mod = InnerModule()
-
-        class OuterModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.mod = opt_inner_mod
-
-            def forward(self, x):
-                return self.mod(torch.cos(x))
-
-        outer_mod = OuterModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
-
-        x = torch.randn(4)
-        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
-        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
-        self.assertEqual(cnt.frame_count, 1)
-
-    def test_composition_with_opt_mod(self):
-        class InnerModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.relu = torch.nn.ReLU()
-
-            def forward(self, x):
-                return self.relu(torch.sin(x))
-
-        inner_mod = InnerModule()
-        cnt = torch._dynamo.testing.CompileCounter()
-        opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod)
-
-        class OuterModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.mod = opt_inner_mod
-
-            def forward(self, x):
-                return self.mod(torch.cos(x))
-
-        outer_mod = OuterModule()
-        opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod)
-
-        x = torch.randn(4)
-        self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule)
-        self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x)))
-        # There will be a graph break for the inner mod being OptimizedModule
-        self.assertEqual(cnt.frame_count, 2)
-
-
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
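The deleted OptimizedModuleTest above calls torch._dynamo.optimize directly on an nn.Module instance and counts compiled frames. Below is a rough, standalone sketch of that usage pattern with the OptimizedModule-specific assertions dropped (the class no longer exists after this revert); same and CompileCounter are the torch._dynamo.testing helpers the tests themselves use:

import torch
import torch._dynamo
import torch._dynamo.testing


class MockModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.relu(self.linear(x))


mod = MockModule()
cnt = torch._dynamo.testing.CompileCounter()  # backend that only counts compiled frames
opt_mod = torch._dynamo.optimize(cnt)(mod)

x = torch.randn(10, 10)
assert torch._dynamo.testing.same(mod(x), opt_mod(x))  # wrapped forward matches eager
assert cnt.frame_count == 1                            # forward compiled exactly once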

torch/_dynamo/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -7,7 +7,6 @@
     export,
     optimize,
     optimize_assert,
-    OptimizedModule,
     reset_code,
     run,
     skip,
@@ -26,7 +25,6 @@
     "reset",
     "list_backends",
    "skip",
-    "OptimizedModule",
 ]
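With both the re-export above and the class definition in eval_frame.py removed, torch._dynamo.OptimizedModule does not resolve at this revision. A small illustrative guard for downstream code written against the reverted API, assuming a build of exactly this commit:

import torch._dynamo

# At this revision the name is absent from both the module and __all__.
print(hasattr(torch._dynamo, "OptimizedModule"))                   # False
print("OptimizedModule" in getattr(torch._dynamo, "__all__", []))  # False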

torch/_dynamo/debug_utils.py

Lines changed: 0 additions & 8 deletions

@@ -486,16 +486,8 @@ def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
     """
     Check two models have same accuracy.
     """
-    from .eval_frame import OptimizedModule
-    from .testing import named_parameters_for_optimized_module
     from .utils import same
 
-    if isinstance(gm, OptimizedModule):
-        gm.named_parameters = named_parameters_for_optimized_module(gm)
-
-    if isinstance(opt_gm, OptimizedModule):
-        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
-
     ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
 
     try:

torch/_dynamo/eval_frame.py

Lines changed: 20 additions & 54 deletions

@@ -5,7 +5,6 @@
 import logging
 import os
 import sys
-import textwrap
 import threading
 import traceback
 import types
@@ -45,27 +44,6 @@
 most_recent_backend = None
 
 
-class OptimizedModule(torch.nn.Module):
-    """
-    Wraps the original nn.Module object and later patches its
-    forward method to optimized self.forward method.
-    """
-
-    def __init__(self, mod):
-        super().__init__()
-        # Installs the params/buffer
-        self._orig_mod = mod
-
-    def __getattr__(self, name):
-        if name == "_orig_mod":
-            return self._modules["_orig_mod"]
-        return getattr(self._orig_mod, name)
-
-    def forward(self, *args, **kwargs):
-        # This will be monkey patched later
-        raise RuntimeError("Should not be here")
-
-
 def remove_from_cache(f):
     """
     Make sure f.__code__ is not cached to force a recompile
@@ -140,15 +118,31 @@ def __call__(self, fn):
         # Optimize the forward method of torch.nn.Module object
         if isinstance(fn, torch.nn.Module):
             mod = fn
-            new_mod = OptimizedModule(mod)
-            new_mod.forward = self(mod.forward)
+            optimized_forward = self(mod.forward)
+
+            class TorchDynamoNNModuleWrapper:
+                """
+                A wrapper that redirects the forward call to the optimized
+                forward, while for rest it redirects the calls to the original
+                module.
+                """
+
+                def __getattr__(self, name):
+                    return getattr(mod, name)
+
+                def forward(self, *args, **kwargs):
+                    return optimized_forward(*args, **kwargs)
+
+                def __call__(self, *args, **kwargs):
+                    return self.forward(*args, **kwargs)
+
+            new_mod = TorchDynamoNNModuleWrapper()
             # Save the function pointer to find the original callable while nesting
             # of decorators.
-            new_mod._torchdynamo_orig_callable = mod.forward
+            new_mod._torchdynamo_orig_callable = mod
             return new_mod
 
         assert callable(fn)
-
         callback = self.callback
         on_enter = self.on_enter
         backend_ctx_ctor = self.extra_ctx_ctor
@@ -190,34 +184,6 @@ def _fn(*args, **kwargs):
         # If the function is called using torch._dynamo.optimize decorator, we
         # should prevent any type of skipping.
         if callback not in (None, False):
-            if not hasattr(fn, "__code__"):
-                raise RuntimeError(
-                    textwrap.dedent(
-                        """
-
-                        torch._dynamo.optimize is called on a non function object.
-                        If this is a callable class, please optimize the individual methods that you are interested in optimizing.
-
-                        >> class CallableClass:
-                        >>   def __init__(self):
-                        >>     super().__init__()
-                        >>     self.relu = torch.nn.ReLU()
-                        >>
-                        >>   def __call__(self, x):
-                        >>     return self.relu(torch.sin(x))
-                        >>
-                        >>   def print_hello(self):
-                        >>     print("Hello world")
-                        >>
-                        >> mod = CallableClass()
-
-                        If you want to optimize the __call__ function
-
-                        >> mod.__call__ = torch._dynamo.optimize(mod.__call__)
-
-                        """
-                    )
-                )
             always_optimize_code_objects[fn.__code__] = True
 
         return _fn
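The restored TorchDynamoNNModuleWrapper is duck-typed: it is not an nn.Module, it runs the compiled forward on calls, and every other attribute lookup falls through to the wrapped module via __getattr__. A minimal, self-contained sketch of that delegation pattern (wrap_module and its arguments are illustrative stand-ins, not dynamo's API):

import torch


def wrap_module(mod, optimized_forward):
    """Return a plain object that calls optimized_forward but otherwise looks like mod."""

    class Wrapper:
        def __getattr__(self, name):
            # Anything the wrapper does not define itself is served by the original module.
            return getattr(mod, name)

        def __call__(self, *args, **kwargs):
            return optimized_forward(*args, **kwargs)

    return Wrapper()


mod = torch.nn.Linear(4, 4)
opt = wrap_module(mod, mod.forward)          # stand-in for a dynamo-compiled forward

x = torch.randn(2, 4)
assert torch.equal(opt(x), mod(x))           # calls are redirected to optimized_forward
assert not isinstance(opt, torch.nn.Module)  # duck typing only
assert opt.to(torch.float64) is mod          # .to() falls through to (and returns) the original module

The removed OptimizedModule took the opposite approach: subclass nn.Module, register the original module as the _orig_mod submodule, and patch forward, which is what made isinstance checks, parameters(), and .to() behave as if the wrapper were the module itself.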

torch/_dynamo/testing.py

Lines changed: 0 additions & 13 deletions

@@ -32,17 +32,6 @@ def clone_me(x):
     return x.detach().clone().requires_grad_(x.requires_grad)
 
 
-def named_parameters_for_optimized_module(mod):
-    assert isinstance(mod, eval_frame.OptimizedModule)
-    return mod._orig_mod.named_parameters
-
-
-def remove_optimized_module_prefix(name):
-    prefix = "_orig_mod."
-    assert name.startswith(prefix)
-    return name[len(prefix) :]
-
-
 def collect_results(model, prediction, loss, example_inputs):
     results = []
     results.append(prediction)
@@ -55,8 +44,6 @@ def collect_results(model, prediction, loss, example_inputs):
     grads = dict()
     params = dict()
     for name, param in model.named_parameters():
-        if isinstance(model, eval_frame.OptimizedModule):
-            name = remove_optimized_module_prefix(name)
         param_copy = param
         grad = param.grad
         # Treat None and zero grad as same
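Both deleted helpers existed because OptimizedModule stored the wrapped module as a submodule named _orig_mod, so every parameter reported by named_parameters() carried an "_orig_mod." prefix that comparison code then had to strip. A standalone illustration of the prefixing and the stripping (Holder is just an example container, not dynamo code):

import torch


class Holder(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self._orig_mod = mod  # assigning an nn.Module attribute registers it as a submodule


inner = torch.nn.Linear(2, 2)
holder = Holder(inner)

print([name for name, _ in inner.named_parameters()])   # ['weight', 'bias']
print([name for name, _ in holder.named_parameters()])  # ['_orig_mod.weight', '_orig_mod.bias']

prefix = "_orig_mod."
stripped = [name[len(prefix):] for name, _ in holder.named_parameters()]
print(stripped)  # ['weight', 'bias'] -- the same stripping remove_optimized_module_prefix performed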
