Skip to content

Commit 67eeeed

Browse files
author
davidriazati
committed
[jit] Support recursive ModuleList / Sequential
Adds support for recursively compiling `nn.Sequential` and `nn.ModuleList`. When either is used, it is converted to a `jit._ConstModuleList` or `jit._ConstSequential` as necessary. Due to this, we don't need to add it to `__constants__` since it's made constant on demand. This PR also moves the recursive script tests out to their own class `TestRecursiveScript` (the added test is called `test_iterable_modules`)
1 parent 37fed9b commit 67eeeed

File tree

3 files changed

+155
-85
lines changed

3 files changed

+155
-85
lines changed

test/test_jit.py

Lines changed: 119 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -13133,86 +13133,6 @@ def is_tensor_value(item):
1313313133
continue
1313413134
self.assertEqual(item[1], loaded_item)
1313513135

13136-
def test_script_recurse(self):
13137-
def a_python_fn(a, b, c):
13138-
return a + b + c
13139-
13140-
with torch.jit._enable_recursive_script():
13141-
@torch.jit.script
13142-
def a_script_fn(d, e, f):
13143-
return a_python_fn(d, e, f)
13144-
13145-
graph = str(a_script_fn.graph)
13146-
FileCheck().check("aten::add").run(graph)
13147-
FileCheck().check_not("a_python_fn").run(graph)
13148-
t = torch.ones(2, 2)
13149-
self.assertEqual(a_script_fn(t, t, t), t + t + t)
13150-
13151-
def test_module_recursive(self):
13152-
class Other(torch.nn.Module):
13153-
__constants__ = ['x']
13154-
13155-
def __init__(self, x):
13156-
super(Other, self).__init__()
13157-
self.x = x
13158-
self.param = torch.nn.Parameter(torch.ones(2, 2))
13159-
13160-
def some_unscriptable_method(self):
13161-
a = 2
13162-
a = [2]
13163-
return a
13164-
13165-
def forward(self, t):
13166-
return t + self.x + self.param
13167-
13168-
13169-
class M(torch.nn.Module):
13170-
__constants__ = ['x']
13171-
13172-
def __init__(self):
13173-
super(M, self).__init__()
13174-
self.other = Other(200)
13175-
13176-
def forward(self, t):
13177-
return self.other(t) * 2
13178-
13179-
with torch.jit._enable_recursive_script():
13180-
sm = torch.jit.script(M())
13181-
13182-
self.assertExportImportModule(sm, (torch.ones(2, 2),))
13183-
13184-
def test_module_function_export(self):
13185-
class Other(torch.nn.Module):
13186-
__constants__ = ['x']
13187-
13188-
def __init__(self, x):
13189-
super(Other, self).__init__()
13190-
self.x = x
13191-
self.param = torch.nn.Parameter(torch.ones(2, 2))
13192-
13193-
@torch.jit.export
13194-
def some_entry_point(self, y):
13195-
return y + 20
13196-
13197-
def forward(self, t):
13198-
return t + self.x + self.param
13199-
13200-
13201-
class M(torch.nn.Module):
13202-
__constants__ = ['x']
13203-
13204-
def __init__(self):
13205-
super(M, self).__init__()
13206-
self.other = Other(200)
13207-
13208-
def forward(self, t):
13209-
return self.other(t) * 2
13210-
13211-
with torch.jit._enable_recursive_script():
13212-
sm = torch.jit.script(M())
13213-
13214-
self.assertExportImportModule(sm, (torch.ones(2, 2),))
13215-
1321613136
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
1321713137
def test_old_models_bc(self):
1321813138
model = {
@@ -13475,6 +13395,125 @@ def foo(a):
1347513395
foo(1)
1347613396

1347713397

13398+
class TestRecursiveScript(JitTestCase):
    """
    Tests in this class are all run under `with torch.jit._enable_recursive_script()`
    """
    def run(self, result=None):
        # Wrap every test method so recursive scripting is enabled throughout.
        with torch.jit._enable_recursive_script():
            super(TestRecursiveScript, self).run(result)

    def checkModule(self, nn_module, args):
        """
        Check that a nn.Module's results in Script mode match eager and that it
        can be exported
        """
        sm = torch.jit.script(nn_module)

        eager_out = nn_module(*args)
        script_out = sm(*args)

        self.assertEqual(eager_out, script_out)
        self.assertExportImportModule(sm, args)

        return sm

    def test_script_basic(self):
        def a_python_fn(a, b, c):
            return a + b + c

        @torch.jit.script
        def a_script_fn(d, e, f):
            return a_python_fn(d, e, f)

        # The python helper should be inlined into the scripted graph.
        graph = str(a_script_fn.graph)
        FileCheck().check("aten::add").run(graph)
        FileCheck().check_not("a_python_fn").run(graph)
        t = torch.ones(2, 2)
        self.assertEqual(a_script_fn(t, t, t), t + t + t)

    def test_module_basic(self):
        class Other(torch.nn.Module):
            __constants__ = ['x']

            def __init__(self, x):
                super(Other, self).__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            # Never called from forward, so it may stay unscriptable.
            def some_unscriptable_method(self):
                a = 2
                a = [2]
                return a

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            __constants__ = ['x']

            def __init__(self):
                super(M, self).__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        sm = torch.jit.script(M())

        self.assertExportImportModule(sm, (torch.ones(2, 2),))

    def test_module_function_export(self):
        class Other(torch.nn.Module):
            __constants__ = ['x']

            def __init__(self, x):
                super(Other, self).__init__()
                self.x = x
                self.param = torch.nn.Parameter(torch.ones(2, 2))

            @torch.jit.export
            def some_entry_point(self, y):
                return y + 20

            def forward(self, t):
                return t + self.x + self.param

        class M(torch.nn.Module):
            __constants__ = ['x']

            def __init__(self):
                super(M, self).__init__()
                self.other = Other(200)

            def forward(self, t):
                return self.other(t) * 2

        sm = torch.jit.script(M())
        self.assertExportImportModule(sm, (torch.ones(2, 2),))

    def test_iterable_modules(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                # Nested Sequential / ModuleList exercise the recursive
                # constant-ification path added in this commit.
                self.sequential = nn.Sequential(
                    nn.Linear(5, 5),
                    nn.Linear(5, 5),
                    nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5))
                )
                self.module_list = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])

            def forward(self, x):
                for mod in self.module_list:
                    x += mod(x)
                x += self.sequential(x)
                return x

        self.checkModule(M(), (torch.randn(5, 5),))
13516+
1347813517
class MnistNet(nn.Module):
1347913518
def __init__(self):
1348013519
super(MnistNet, self).__init__()

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,6 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
275275

276276
// This can also be a call to a non-script module, or a plain
277277
// python method. If so return this as a python value.
278-
279278
py::object overloads =
280279
py_module_.attr("_overloads").attr("get")(field, py::none());
281280
if (!overloads.is_none()) {
@@ -327,7 +326,7 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
327326
module_->register_module(field, submodule);
328327
auto v = module_->find_module(field);
329328
return std::make_shared<ModuleValue>(
330-
m.graph()->insertGetAttr(self_, field), v, attr);
329+
m.graph()->insertGetAttr(self_, field), v, result);
331330
}
332331
} else if (py::isinstance<py::function>(attr)) {
333332
auto stub = py::module::import("torch.jit")

torch/jit/__init__.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -940,11 +940,26 @@ def env(key):
940940

941941
return env
942942

943+
def _create_constant_iterable_module(module):
    """Recursively convert an ``nn.ModuleList``/``nn.Sequential`` into its
    constant counterpart (``_ConstModuleList``/``_ConstSequential``) so that
    the script compiler can iterate over it.

    Arguments:
        module: an ``nn.ModuleList`` or ``nn.Sequential`` instance.

    Returns:
        The equivalent ``_ConstModuleList`` or ``_ConstSequential``.

    Raises:
        RuntimeError: if ``module`` is neither an ``nn.ModuleList`` nor an
            ``nn.Sequential``.
    """
    # Validate up front: the original checked the type only after the loop,
    # so an invalid input could be partially mutated before raising.
    if not isinstance(module, (ModuleList, Sequential)):
        raise RuntimeError("Only nn.ModuleList and nn.Sequential can be made "
                           "into constant modules, found {}".format(module))

    for i, submodule in enumerate(module):
        if isinstance(submodule, (ModuleList, Sequential)):
            # Make each nested iterable container a constant as well
            module[i] = _create_constant_iterable_module(submodule)

    if isinstance(module, Sequential):
        return _ConstSequential(module)
    return _ConstModuleList(module)
956+
943957

944958
def _make_strong_submodule(field, module, parent):
945959
if field not in parent._modules:
946960
# It's not a submodule, don't do anything
947961
return None
962+
948963
return _convert_to_script_module(module)
949964

950965

@@ -953,6 +968,9 @@ def _try_compile_fn(fn):
953968
# Don't do anything for @ignore'd functions
954969
return None
955970

971+
if isinstance(fn, torch.nn.Module):
972+
return None
973+
956974
# We don't have the actual scope where the function was defined, but we can
957975
# extract the necessary info from the closed over variables on the function
958976
# object
@@ -1237,8 +1255,6 @@ def _create_methods_from_stubs(self, stubs):
12371255
# has run. This has to occur after the user-defined __init__ so that
12381256
# submodules and parameters are initialized _before_ the script compiler
12391257
# resolve references to `self.param` or `self.module`.
1240-
1241-
12421258
class ScriptMeta(type):
12431259
# this has to inherit from pybind11's metaclass otherwise we get
12441260
# issues because ScriptModule inherits from torch._C.ScriptModule,
@@ -1260,6 +1276,7 @@ def __init__(cls, name, bases, attrs):
12601276
cls._methods[v.original_method.__name__] = v
12611277

12621278
original_init = getattr(cls, '__init__', lambda self: None)
1279+
print("set overloads", cls, name)   # NOTE(review): leftover debug print added by this commit — should be removed before merge
12631280
cls._overloads = dict(getattr(cls, '__overloads__', {}))
12641281

12651282
# after the user's __init__ register all the script methods
@@ -1537,6 +1554,11 @@ def graph_for(self, *args, **kwargs):
15371554
return self.forward.graph_for(*args, **kwargs)
15381555

15391556
class WeakScriptModuleProxy(ScriptModule):
1557+
# TODO: [weak script refactor]
1558+
# WeakScriptModule proxy should be deleted since its functionality is
1559+
# subsumed by recursive scripting, and the copying code in init moved
1560+
# to a function to create a ScriptModule from an nn.Module without
1561+
# making a WeakScriptModuleProxy
15401562
"""
15411563
Copies the parameters, buffers, constants, attributes, and submodules
15421564
of an nn.Module into itself.
@@ -1568,7 +1590,13 @@ def __init__(self, original, stubs):
15681590
elif isinstance(item, (Parameter, Module, Attribute)):
15691591
if isinstance(item, (ModuleList, Sequential)):
15701592
# These are in __constants__, so ignore them here
1571-
continue
1593+
1594+
if not torch._C._jit_recursive_script():
1595+
# For recursive script, these are constantified after
1596+
# they are used, so they don't need to be in constants.
1597+
# The `continue` here should be deleted along with
1598+
# [weak script refactor]
1599+
continue
15721600
ScriptModule.__setattr__(self, name, item)
15731601

15741602
# Copy buffers
@@ -1651,6 +1679,10 @@ def _convert_to_script_module(mod, methods=None):
16511679
`('forward',)`. Methods accessed in forward are scripted on demand if
16521680
`_enable_recursive_script()` is used.
16531681
"""
1682+
if isinstance(mod, (ModuleList, Sequential)):
1683+
# Create constant versions for the iterable modules
1684+
return _create_constant_iterable_module(mod)
1685+
16541686
if methods is None:
16551687
methods = ('forward',)
16561688
exported = []

0 commit comments

Comments
 (0)