Skip to content
33 changes: 33 additions & 0 deletions test/jit/test_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,39 @@ def check(mod):
imported = self.getExportImportCopy(traced)
check(imported.foo)

# Note that Bar's forward can only be traced, but not scripted
class Bar(nn.Module):
    """Test fixture: a module whose forward is traceable but NOT scriptable.

    The lambda in ``forward`` is deliberate — TorchScript cannot compile
    lambdas, so recursive scripting of this module's forward must never be
    attempted; only the explicitly exported method may be scripted.
    """

    def __init__(self):
        super().__init__()

    @torch.jit.export
    def addTwo(self, x):
        # Marked @torch.jit.export so it gets scripted even though the
        # module itself is being traced, not scripted.
        return x + 2

    def forward(self, input):
        # Intentionally unscriptable: TorchScript rejects lambdas, so this
        # forward can only ever be captured by tracing.
        return (lambda a: a + 1)(input)

# When tracing Bar as a submodule, we only want to script the
# exported methods, and we want to keep the forwards still
# being traced.
class WrapperExports(torch.nn.Module):
    """Test fixture: wrapper holding an untraceable-to-script submodule.

    When this wrapper is traced, only the @torch.jit.export-marked methods
    should be scripted; the submodule's forward must remain traced.
    """

    def __init__(self):
        super(WrapperExports, self).__init__()
        # Submodule whose forward is traceable but not scriptable
        # (it uses a lambda) — presumably defined just above in the
        # same test; verify against the enclosing test body.
        self.bar = Bar()

    @torch.jit.export
    def addOne(self, x):
        # Exported: expected to be scripted during tracing.
        return x + 1

    def forward(self, x):
        return self.bar(x)

f = WrapperExports()

traced = torch.jit.trace(f, (torch.rand(3, 4),))
expected_names = ['addOne']
check(traced)

def test_trace_autograd_function(self):
class TestFunc(torch.autograd.Function):
@staticmethod
Expand Down
17 changes: 15 additions & 2 deletions torch/jit/_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,19 @@ def make_stub_from_method(nn_module, method_name):
return make_stub(func, method_name)


def make_stubs_from_exported_methods(mod):
    """Build compilation stubs for every ``@torch.jit.export``-ed method of *mod*.

    Scans ``dir(mod)`` and returns a stub (via ``make_stub_from_method``) for
    each attribute whose TorchScript modifier is ``EXPORT``; everything else —
    including ``forward`` — is ignored, so non-exported methods stay untouched
    (e.g. they remain traced rather than scripted).
    """
    export_marker = _jit_internal.FunctionModifiers.EXPORT
    return [
        make_stub_from_method(mod, attr_name)
        for attr_name in dir(mod)
        # getattr with a None default: dir() can list names that fail lookup.
        if _jit_internal.get_torchscript_modifier(getattr(mod, attr_name, None))
        is export_marker
    ]


# base types that can be constants
# in addition, tuples and lists of these base types are also considered constants
# If you edit this list, then you also need to edit the handlers in
Expand Down Expand Up @@ -371,8 +384,8 @@ def init_fn(script_module):
elif isinstance(orig_value, torch.jit.ScriptModule):
scripted = orig_value
else:
# use the default recursive rule to compile the module
scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
# always reuse the provided stubs_fn to infer the methods to compile
scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)

cpp_module.setattr(name, scripted)
script_module._modules[name] = scripted
Expand Down
7 changes: 5 additions & 2 deletions torch/jit/_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch
import torch._jit_internal as _jit_internal
from torch.utils import set_module
from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module
from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module, infer_methods_to_compile
from torch.nn import Module
from torch.jit._state import _enabled
from torch.jit._builtins import _register_builtin
Expand Down Expand Up @@ -199,7 +199,10 @@ def init_then_script(self, *args, **kwargs):

def make_stubs(module):
    """Collect method stubs to compile for *module*.

    Prefer the class's explicit ``_methods`` table (sorted by method name for
    determinism); when the class declares none, fall back to inferring the
    methods to compile from the module instance itself.
    """
    cls = type(module)
    if not hasattr(cls, "_methods"):
        # No explicit method table — infer reachable methods instead.
        return infer_methods_to_compile(module)
    return [stub for _name, stub in sorted(cls._methods.items())]

self.__dict__[
"_actual_script_module"
Expand Down
21 changes: 4 additions & 17 deletions torch/jit/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from torch.jit._script import ScriptModule, _CachedForward, script
from torch._jit_internal import _qualified_name
from torch.autograd import function
from torch import _jit_internal
from torch.nn import Module

_flatten = torch._C._jit_flatten
Expand Down Expand Up @@ -549,23 +548,11 @@ def make_module(mod, _module_class, _compilation_unit):
return mod
elif torch._jit_internal.module_has_exports(mod):

def make_stubs_from_exported_methods(mod):
exported = []
for name in dir(mod):
item = getattr(mod, name, None)
if (
torch._jit_internal.get_torchscript_modifier(item)
is _jit_internal.FunctionModifiers.EXPORT
):
exported.append(name)

stubs = []
for method in exported:
stubs.append(torch.jit._recursive.make_stub_from_method(mod, method))
return stubs

infer_methods_stubs_fn = torch.jit._recursive.make_stubs_from_exported_methods
return torch.jit._recursive.create_script_module(
mod, make_stubs_from_exported_methods, share_types=False
mod,
infer_methods_stubs_fn,
share_types=False
)
else:
if _module_class is None:
Expand Down