Skip to content

Commit 759157b

Browse files
author
davidriazati
committed
combine rcbs
1 parent 544d77f commit 759157b

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

test/test_jit.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12865,6 +12865,9 @@ def forward(self, x):
1286512865
self.assertEqual(m.int64_max, imported.int64_max)
1286612866
self.assertEqual(m.int64_min, imported.int64_min)
1286712867

12868+
def test_script_scope(self):
12869+
scripted = torch.jit.script(torch.nn.functional.pad)
12870+
1286812871
@unittest.skipIf(IS_WINDOWS, "NYI: TemporaryFileName on Windows")
1286912872
def test_serialization_sharing(self):
1287012873
class M(torch.jit.ScriptModule):

torch/jit/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,8 +1086,6 @@ def _compile_and_register_class(obj, rcb, qualified_name):
10861086
def script(obj, optimize=True, _frames_up=0, _rcb=None):
10871087
if not _enabled:
10881088
return obj
1089-
if _rcb is None:
1090-
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
10911089

10921090
if isinstance(obj, torch.nn.Module):
10931091
return _convert_to_script_module(obj)
@@ -1096,10 +1094,23 @@ def script(obj, optimize=True, _frames_up=0, _rcb=None):
10961094
if inspect.isclass(obj):
10971095
if not _is_new_style_class(obj):
10981096
raise RuntimeError("TorchScript classes must be new-style classes. Please inherit from 'object'")
1097+
if _rcb is None:
1098+
_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
10991099
_compile_and_register_class(obj, _rcb, qualified_name)
11001100
return obj
11011101
else:
11021102
ast = get_jit_def(obj)
1103+
if _rcb is None:
1104+
closure_rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
1105+
stack_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
1106+
def _rcb(name):
1107+
# since type comments aren't captured in the function's closures,
1108+
# we still need to try to the rcb based on stack frames if the
1109+
# closure rcb fails
1110+
result = closure_rcb(name)
1111+
if result:
1112+
return result
1113+
return stack_rcb(name)
11031114
fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
11041115
# Forward docstrings
11051116
fn.__doc__ = obj.__doc__

0 commit comments

Comments
 (0)