Commit 10c4b98

David Riazati authored and facebook-github-bot committed
Remove weak script (#22212)
Summary:
* Deletes all weak script decorators, associated data structures, and methods.
* To keep supporting the standard library in script, this enables recursive script on any function defined in `torch.nn`.
* Most changes in `torch/nn` are the result of `ag -Q "weak" torch/nn/ -l | xargs sed -i '/weak/d'`; only `rnn.py` needed manual editing, using `ignore` and `export` to continue supporting the overloaded `forward` methods.
* `Sequential`/`ModuleList` no longer need to be added to `__constants__`, since they are compiled on demand.

This should also fix #22212

Pull Request resolved: #22212

Differential Revision: D15988346

Pulled By: driazati

fbshipit-source-id: af223e3ad0580be895377312949997a70e988e4f
1 parent: b93f29d

28 files changed: +109 / -564 lines
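
For illustration, a minimal sketch (not part of this diff) of what the change enables: a `ScriptModule` can now hold an `nn.Sequential` without listing it in `__constants__`, mirroring the updated `test_attr_module_constants` test below. Names and shapes here are made up, and this assumes a PyTorch build that includes this change.

```python
import torch
import torch.nn as nn


class WrapSeq(torch.jit.ScriptModule):
    def __init__(self, mods):
        super(WrapSeq, self).__init__()
        # No `__constants__ = ['mods']` entry is needed anymore; the Sequential
        # and the torch.nn modules inside it are compiled on demand.
        self.mods = mods

    @torch.jit.script_method
    def forward(self, x):
        return self.mods.forward(x)


m = WrapSeq(nn.Sequential(nn.ReLU()))
print(m(torch.randn(2, 2)))
```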

test/test_jit.py

Lines changed: 18 additions & 15 deletions

@@ -6979,7 +6979,7 @@ def forward(self, v):
         with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
             M()

-    def test_script_module_list_sequential_error(self):
+    def test_script_module_list_sequential(self):
         class M(torch.jit.ScriptModule):
             def __init__(self, mod_list):
                 super(M, self).__init__(False)
@@ -6991,25 +6991,21 @@ def forward(self, v):
                     v = m(v)
                 return v

-        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
-            a = M(nn.Sequential(nn.ReLU()))
-        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
-            a = M(nn.ModuleList([nn.ReLU()]))
+        m = M(nn.Sequential(nn.ReLU()))
+        self.assertExportImportModule(m, (torch.randn(2, 2),))

-    def test_attr_module_constants_error(self):
+    def test_attr_module_constants(self):
         class M2(torch.jit.ScriptModule):
             def __init__(self, mod_list):
                 super(M2, self).__init__(False)
                 self.mods = mod_list

             @torch.jit.script_method
-            def forward(self, v):
+            def forward(self, x):
                 return self.mods.forward(x)

-        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
-            M2(nn.Sequential(nn.ReLU()))
-        with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
-            M2(nn.ModuleList([nn.ReLU()]))
+        m = M2(nn.Sequential(nn.ReLU()))
+        self.assertExportImportModule(m, (torch.randn(2, 2),))

     def test_script_sequential_for(self):
         class Sub(torch.jit.ScriptModule):
@@ -11007,6 +11003,7 @@ def foo(cond):
         with self.assertRaisesRegex(torch.jit.Error, "Exception"):
             foo(torch.tensor(0))

+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_script_function(self):
         outer_var = 10
         outer_var2 = 11
@@ -11086,6 +11083,7 @@ def foo(x):
         eg = torch.zeros(3, dtype=torch.uint8)
         self.assertEqual(foo_traced(eg), foo(eg))

+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module(self):

         @torch._jit_internal.weak_module
@@ -11161,6 +11159,7 @@ def forward(self, x):
         self.assertEqual(script_result, expected_result)
         self.assertEqual(script_result, script_result2)

+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_parameters_and_buffers(self):
         weights = torch.randn(10, 10)
         bias = torch.randn(10)
@@ -11219,6 +11218,7 @@ def forward(self, x):
         self.assertEqual(strong_mod(inp), expected_result)
         self.assertExportImportModule(strong_mod, (inp,))

+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_nested(self):
         @torch._jit_internal.weak_module
         class OtherWeak(torch.nn.Module):
@@ -11280,6 +11280,7 @@ def forward(self, x):
             + F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
         self.assertEqual(result, expected_result)

+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_submodule(self):
         @torch._jit_internal.weak_module
         class Weak(torch.nn.Module):
@@ -11319,6 +11320,7 @@ def forward(self, x):
         with self.assertRaisesRegex(RuntimeError, "Cannot call a ScriptModule that is not a submodule of the caller"):
             strong_mod = Strong()

+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_copying(self):
         class Submodule(torch.nn.Module):
             def __init__(self):
@@ -11385,6 +11387,7 @@ def __init__(self):

         m = M()

+    @unittest.skipIf(True, "Removing weak script")
     def test_weak_module_attributes(self):
         tester = self

@@ -11948,6 +11951,7 @@ def test({arg_str}):

         FileCheck().check_not("prim::PythonOp").run(cu.test.graph)

+    @unittest.skipIf(True, "Removing weak script")
     def test_overloading(self):
         @torch._jit_internal.weak_module
         class W(torch.nn.Module):
@@ -13623,6 +13627,9 @@ def forward(self, x, y):
     'test_nn_AdaptiveAvgPool3d_tuple_none',
     'test_nn_AdaptiveMaxPool2d_tuple_none',
     'test_nn_AdaptiveMaxPool3d_tuple_none',
+
+    # Uses Module._backend, so this is not supported
+    'test_nn_CrossMapLRN2d',
 }


@@ -14552,10 +14559,6 @@ def add_nn_module_test(*args, **kwargs):

     module_name = name.split("_")[0]

-    module = getattr(torch.nn, module_name, None)
-    if module is None or torch._jit_internal.weak_types.get(module) is None:
-        return
-
     if 'desc' in kwargs and 'eval' in kwargs['desc']:
         # eval() is not supported, so skip these tests
         return

torch/_jit_internal.py

Lines changed: 19 additions & 44 deletions

@@ -4,29 +4,14 @@
 circular dependency problems
 """

-import weakref
 import inspect
+import weakref
 from torch._six import builtins

-# Tracks standalone weak script functions
-compiled_weak_fns = weakref.WeakKeyDictionary()  # noqa: T484
-
-# Tracks which methods should be converted to strong methods
-weak_script_methods = weakref.WeakKeyDictionary()  # noqa: T484
-
-# Converted modules and their corresponding WeakScriptModuleProxy objects
-weak_modules = weakref.WeakKeyDictionary()  # noqa: T484
-
-# Types that have been declared as weak modules
-weak_types = weakref.WeakKeyDictionary()  # noqa: T484
-
 # Wrapper functions that can call either of 2 functions depending on a boolean
 # argument
 boolean_dispatched = weakref.WeakKeyDictionary()  # noqa: T484

-COMPILATION_PENDING = object()
-COMPILED = object()
-

 def createResolutionCallback(frames_up=0):
     """
@@ -71,51 +56,41 @@ def env(key):
             return f_globals[key]
         elif hasattr(builtins, key):
             return getattr(builtins, key)
-        else:
-            return None

     return env


-def weak_script(fn, _frames_up=0):
+def createResolutionCallbackFromClosure(fn):
     """
-    Marks a function as a weak script function. When used in a script function
-    or ScriptModule, the weak script function will be lazily compiled and
-    inlined in the graph. When not used in a script function, the weak script
-    annotation has no effect.
+    Create a resolutionCallback by introspecting the function instead of
+    looking up the stack for the enclosing scope
     """
-    compiled_weak_fns[fn] = {
-        "status": COMPILATION_PENDING,
-        "compiled_fn": None,
-        "rcb": createResolutionCallback(_frames_up + 1)
-    }
-    return fn
+    var_names = fn.__code__.co_freevars

+    # map of captured name -> value
+    free_vars = {}

-def weak_module(cls):
-    weak_types[cls] = {
-        "method_stubs": None
-    }
-    return cls
+    for index, name in enumerate(var_names):
+        free_vars[name] = fn.__closure__[index].cell_contents
+    f_globals = fn.__globals__

+    def env(key):
+        if key in free_vars:
+            return free_vars[key]
+        elif hasattr(builtins, key):
+            return getattr(builtins, key)
+        else:
+            return f_globals.get(key)

-def weak_script_method(fn):
-    weak_script_methods[fn] = {
-        "rcb": createResolutionCallback(frames_up=2),
-        "original_method": fn
-    }
-    return fn
+    return env


 def boolean_dispatch(arg_name, arg_index, default, if_true, if_false, module_name, func_name):
     """
-    Dispatches to either of 2 weak script functions based on a boolean argument.
+    Dispatches to either of 2 script functions based on a boolean argument.
     In TorchScript, the boolean argument must be constant so that the correct
     function to use can be determined at compile time.
     """
-    if compiled_weak_fns.get(if_true) is None or compiled_weak_fns.get(if_false) is None:
-        raise RuntimeError("both functions must be weak script")
-
     def fn(*args, **kwargs):
         dispatch_flag = False
         if arg_name in kwargs:
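
The new `createResolutionCallbackFromClosure` relies on standard CPython introspection (`__code__.co_freevars`, `__closure__`, `__globals__`) instead of walking the stack. A standalone sketch of the same lookup order, independent of TorchScript; the `make_env` helper and the sample functions are hypothetical:

```python
import builtins


def make_env(fn):
    # Names the function closes over, aligned with the cells in fn.__closure__.
    free_vars = {
        name: cell.cell_contents
        for name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ())
    }

    def env(key):
        # Lookup order mirrors the diff: closure, then builtins, then globals.
        if key in free_vars:
            return free_vars[key]
        elif hasattr(builtins, key):
            return getattr(builtins, key)
        else:
            return fn.__globals__.get(key)

    return env


scale = 3


def outer():
    offset = 2

    def f(x):
        return scale * x + offset

    return f


env = make_env(outer())
print(env("offset"), env("scale"), env("len"))  # 2 3 <built-in function len>
```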

torch/csrc/jit/script/parser.cpp

Lines changed: 2 additions & 1 deletion

@@ -23,7 +23,8 @@ Decl mergeTypesFromTypeComment(
         << "Number of type annotations ("
         << type_annotation_decl.params().size()
         << ") did not match the number of "
-        << "function parameters (" << expected_num_annotations << ")";
+        << (is_method ? "method" : "function")
+        << " parameters (" << expected_num_annotations << ")";
   }
   auto old = decl.params();
   auto _new = type_annotation_decl.params();
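
For reference, the reworded error concerns mypy-style type comments on scripted callables, where one annotation is expected per parameter; a hedged sketch of the usage it guards (assuming current `torch.jit.script` behavior):

```python
import torch


@torch.jit.script
def add(x, y):
    # type: (Tensor, Tensor) -> Tensor
    return x + y


print(add(torch.ones(2), torch.ones(2)))

# Dropping an annotation from the type comment (e.g. `# type: (Tensor) -> Tensor`)
# makes compilation fail with the "did not match the number of ... parameters"
# message, which this hunk now words differently for methods vs. free functions.
```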

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 7 additions & 13 deletions

@@ -244,6 +244,11 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
         << err.str();
 }

+bool should_recurse(py::object obj) {
+  return py::cast<bool>(py::module::import("torch.jit")
+                            .attr("_is_recursive_script_enabled")(obj));
+}
+
 std::shared_ptr<SugaredValue> ModuleValue::attr(
     const SourceRange& loc,
     Function& m,
@@ -307,7 +312,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(

   // If recursive script mode is on, create a ScriptModule and register it as
   // as submodule or register a python method as a script::Method
-  if (getRecursiveScriptMode()) {
+  if (should_recurse(attr)) {
     if (py::isinstance(attr, py::module::import("torch.nn").attr("Module"))) {
       // If the module is a submodule of the py_module, convert it to a
       // ScriptModule and add it as a submodule to the script::Module. This
@@ -471,11 +476,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     }
   }

-  auto weak_obj =
-      py::module::import("torch.jit").attr("_try_get_weak_module")(obj);
-  if (!weak_obj.is_none()) {
-    obj = weak_obj;
-  }
   if (auto callee = as_function(obj)) {
     return std::make_shared<FunctionValue>(callee);
   } else if (py::isinstance<py::module>(obj)) {
@@ -504,12 +504,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
           << "which is currently not supported in Torchscript."
          << "Please open a feature request to add it.";
     }
-
-    auto compiled_fn =
-        py::module::import("torch.jit").attr("_try_compile_weak_script")(obj);
-    if (auto callee = as_function(compiled_fn)) {
-      return std::make_shared<FunctionValue>(callee);
-    }
   }

   py::object dispatched_fn =
@@ -528,7 +522,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
     }
   }

-  if (getRecursiveScriptMode() && py::isinstance<py::function>(obj)) {
+  if (should_recurse(obj) && py::isinstance<py::function>(obj)) {
     auto compiled_fn =
         py::module::import("torch.jit").attr("_try_compile_fn")(obj);
     if (auto callee = as_function(compiled_fn)) {
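
The `should_recurse` helper consults the internal `torch.jit._is_recursive_script_enabled` hook, which is not a public API. From user code the visible effect is that functions defined under `torch.nn` are compiled the first time script code calls them; a minimal, hedged sketch of what that looks like with this change in place:

```python
import torch
import torch.nn.functional as F


class MyModule(torch.jit.ScriptModule):
    @torch.jit.script_method
    def forward(self, x):
        # F.relu lives in torch.nn.functional, so it is compiled recursively
        # on first use from script; no @weak_script annotation is involved.
        return F.relu(x + 1)


print(MyModule()(torch.randn(3)))
```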
