Skip to content

Commit 483786c

Browse files
author
root
committed
extend torch.jit._overload to modules
ghstack-source-id: af601e7 Pull Request resolved: #24259
1 parent 1daac9c commit 483786c

File tree

7 files changed

+177
-44
lines changed

7 files changed

+177
-44
lines changed

test/test_jit.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13166,31 +13166,26 @@ def test_uses():
1316613166

1316713167
self.checkScript(test_uses, ())
1316813168

13169-
@unittest.skipIf(True, "Removing weak script")
13170-
def test_overloading(self):
13171-
@torch._jit_internal.weak_module
13169+
def test_method_overloading(self):
1317213170
class W(torch.nn.Module):
13173-
__overloads__ = {'forward': ['forward_tuple', 'forward_tensor']}
13174-
1317513171
def __init__(self):
1317613172
super(W, self).__init__()
1317713173

13178-
@torch._jit_internal.weak_script_method
13179-
def forward_tuple(self, x):
13174+
@torch.jit._overload_method # noqa: F811
13175+
def forward(self, x): # noqa: F811
1318013176
# type: (Tuple[Tensor, Tensor]) -> Tensor
13181-
return x[0] + 5
13182-
13183-
def forward(self, x):
13184-
# manually do argument switching
13185-
if isinstance(x, tuple):
13186-
return self.forward_tuple(x)
13187-
else:
13188-
return self.forward_tensor(x)
13177+
pass
1318913178

13190-
@torch._jit_internal.weak_script_method
13191-
def forward_tensor(self, x):
13179+
@torch.jit._overload_method # noqa: F811
13180+
def forward(self, x): # noqa: F811
1319213181
# type: (Tensor) -> Tensor
13193-
return x + 20
13182+
pass
13183+
13184+
def forward(self, x): # noqa: F811
13185+
if isinstance(x, Tensor):
13186+
return x + 20
13187+
else:
13188+
return x[0] + 5
1319413189

1319513190
class S(torch.jit.ScriptModule):
1319613191
def __init__(self):
@@ -13201,14 +13196,61 @@ def __init__(self):
1320113196
def forward(self, x):
1320213197
return self.weak(x) + self.weak((x, x))
1320313198

13204-
s = S()
13199+
s_mod = S()
1320513200
x = torch.ones(1)
13206-
self.assertEqual(s(x), x + 20 + 5 + x)
13201+
self.assertEqual(s_mod(x), x + 20 + 5 + x)
1320713202

1320813203
w = W()
1320913204
self.assertEqual(w((x, x)), x + 5)
1321013205
self.assertEqual(w((x)), x + 20)
1321113206

13207+
class Unannotated(torch.nn.Module):
13208+
def __init__(self):
13209+
super(Unannotated, self).__init__()
13210+
13211+
@torch.jit._overload_method # noqa: F811
13212+
def hello(self, x): # noqa: F811
13213+
pass
13214+
13215+
@torch.jit._overload_method # noqa: F811
13216+
def hello(self, x): # noqa: F811
13217+
# type: (int) -> (int)
13218+
pass
13219+
13220+
def hello(self, x): # noqa: F811
13221+
return x + 3
13222+
13223+
def forward(self):
13224+
return self.hello(1), self.hello(.5)
13225+
13226+
w = Unannotated()
13227+
with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
13228+
torch.jit.script(w)
13229+
13230+
class CompileOverloadError(torch.nn.Module):
13231+
def __init__(self):
13232+
super(CompileOverloadError, self).__init__()
13233+
13234+
@torch.jit._overload_method # noqa: F811
13235+
def hello(self, x): # noqa: F811
13236+
# type: (str) -> (int)
13237+
pass
13238+
13239+
@torch.jit._overload_method # noqa: F811
13240+
def hello(self, x): # noqa: F811
13241+
# type: (int) -> (int)
13242+
pass
13243+
13244+
def hello(self, x): # noqa: F811
13245+
return x + 1
13246+
13247+
def forward(self):
13248+
return self.hello("hi"), self.hello(.5)
13249+
13250+
w = CompileOverloadError()
13251+
with self.assertRaisesRegex(Exception, "but instead found type \'str\'"):
13252+
torch.jit.script(w)
13253+
1321213254
def test_select_after_chunk(self):
1321313255
def foo(x):
1321413256
chunked = torch.chunk(x, 1)

torch/_jit_internal.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
circular dependency problems
55
"""
66

7+
import threading
8+
import sys
79
import inspect
810
import weakref
911
import torch._C
@@ -276,6 +278,50 @@ def _get_fn_overloads(qual_name):
276278
def _clear_fn_overloads(qual_name):
277279
del _overloaded_fns[qual_name]
278280

281+
def get_class_name(method):
282+
current_frame = sys._current_frames()[threading.currentThread().ident]
283+
284+
# one for the get_class_name call, one for _overload_method call
285+
for i in range(2):
286+
current_frame = current_frame.f_back
287+
class_name = current_frame.f_code.co_name
288+
return class_name
289+
290+
# At the the point the decorator is applied to class methods the method
291+
# has no reference to its owning class. _qualified_name would not include
292+
# the class it is defined in, so any methods with the same name in the same file
293+
# would have the same _qualified_name, even if they were defined in different
294+
# classes. This problem only exists in python 2.
295+
# We get around this problem by looking at the stack frame and identifying
296+
# the class name.
297+
298+
# qualified_name => class name => list[overload_functions]
299+
_overloaded_methods = {} # noqa: T484
300+
301+
def _overload_method(func):
302+
qual_name = _qualified_name(func)
303+
global _overloaded_methods
304+
class_name_map = _overloaded_methods.get(qual_name, None)
305+
if class_name_map is None:
306+
class_name_map = {}
307+
_overloaded_methods[qual_name] = class_name_map
308+
309+
class_name = get_class_name(func)
310+
method_overloads = class_name_map.get(class_name, None)
311+
if method_overloads is None:
312+
method_overloads = []
313+
class_name_map[class_name] = method_overloads
314+
315+
method_overloads.append(func)
316+
return func
317+
318+
def _get_overloaded_methods(method, class_name):
319+
qual_name = _qualified_name(method)
320+
class_name_map = _overloaded_methods.get(qual_name, None)
321+
if class_name_map is None:
322+
return None
323+
return class_name_map.get(class_name, None)
324+
279325
try:
280326
import typing
281327
from typing import Tuple, List, Dict, Optional

torch/csrc/jit/script/init.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,14 @@ void initJitScriptBindings(PyObject* module) {
750750
auto new_def = implementation_def.withDecl(overload_decl);
751751
return script_compile_function(name, new_def, defaults, std::move(rcb));
752752
});
753-
753+
m.def(
754+
"_replace_overloaded_method_decl",
755+
[](const Decl& overload_decl,
756+
const Def& implementation_def,
757+
const std::string& new_name) {
758+
checkOverloadDecl(overload_decl, implementation_def.decl());
759+
return implementation_def.withDecl(overload_decl).withName(new_name);
760+
});
754761
m.def(
755762
"_create_function_from_trace",
756763
[](std::string qualname,

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -218,25 +218,29 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
218218
std::vector<NamedValue> new_inputs = inputs.vec();
219219
new_inputs.insert(new_inputs.begin(), module_);
220220

221-
for (const std::string& method_name : method_names_) {
222-
auto cls = module_->type()->expect<ClassType>();
223-
const auto fn = cls->getMethod(method_name);
224-
auto match = tryMatchSchema(
225-
fn->getSchema(),
226-
loc,
227-
*caller.graph().get(),
228-
c10::nullopt,
229-
new_inputs,
230-
attributes,
231-
&err,
232-
true);
233-
if (match) {
234-
return MethodValue(module_, method_name)
235-
.call(loc, caller, inputs, attributes, n_binders);
221+
std::stringstream failure_messages;
222+
for (bool allow_conversions : {false, true}) {
223+
// clear previous error messages
224+
failure_messages.str("");
225+
for (const std::string& method_name : method_names_) {
226+
auto cls = module_->type()->expect<ClassType>();
227+
const auto fn = cls->getMethod(method_name);
228+
auto match = tryMatchSchema(
229+
fn->getSchema(),
230+
loc,
231+
*caller.graph().get(),
232+
c10::nullopt,
233+
new_inputs,
234+
attributes,
235+
&err,
236+
allow_conversions);
237+
if (match) {
238+
return MethodValue(module_, method_name)
239+
.call(loc, caller, inputs, attributes, n_binders);
240+
}
236241
}
237242
}
238-
throw ErrorReport(loc) << "Could not find any matching overloads\n"
239-
<< err.str();
243+
throw ErrorReport(loc) << failure_messages.str();
240244
}
241245

242246
std::shared_ptr<SugaredValue> OverloadedFunctionValue::call(

torch/csrc/jit/script/python_tree_views.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ void initTreeViewBindings(PyObject* module) {
150150
const auto& r = name.range();
151151
return Def::create(r, name, decl, wrap_list(r, std::move(body)));
152152
}))
153-
.def("decl", [](const Def& def) { return def.decl(); });
153+
.def("decl", [](const Def& def) { return def.decl(); })
154+
.def("name", [](const Def& def) { return def.name(); });
154155
py::class_<ClassDef, TreeView>(m, "ClassDef")
155156
.def(py::init([](const Ident& name, std::vector<Stmt> body) {
156157
const auto& r = name.range();

torch/jit/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from collections import OrderedDict, namedtuple
3232

3333
# These are imported so users can access them from the `torch.jit` module
34-
from torch._jit_internal import Final, _overload # noqa: F401
34+
from torch._jit_internal import Final, _overload, _overload_method # noqa: F401
3535
from torch._jit_internal import ignore, export # noqa: F401
3636

3737
if sys.version_info[0] > 2:
@@ -1903,10 +1903,14 @@ def _compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_
19031903
fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb, overload_defaults)
19041904
return fn
19051905

1906-
def _get_overload_decl_and_defaults(func):
1906+
def _check_no_signature(func):
19071907
signature = torch.jit.annotations.get_signature(func)
19081908
if signature is None:
1909-
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {obj}").format(func)
1909+
qual_name = _qualified_name(func)
1910+
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
1911+
1912+
def _get_overload_decl_and_defaults(func):
1913+
_check_no_signature(func)
19101914
return (torch.jit.get_jit_def(func).decl(), get_default_args(func))
19111915

19121916
def _get_overloads(obj):

torch/jit/_recursive.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,48 @@ def recursive_script(mod):
117117
if not _jit_internal.is_ignored_fn(mod.forward):
118118
methods = ('forward',)
119119
exported = []
120+
overloads = []
120121
for name in dir(mod):
121122
item = getattr(mod, name)
122123
if callable(item):
123124
if _jit_internal.get_torchscript_modifier(item) is _jit_internal.FunctionModifiers.EXPORT:
124125
exported.append(name)
126+
127+
# builtin functions like repr() in python 2 do not have __module__ defined
128+
if hasattr(item, "__module__") and item.__module__ is not None:
129+
method_overloads = _jit_internal._get_overloaded_methods(item, mod._get_name())
130+
if method_overloads is not None:
131+
overloads.append((item, method_overloads))
132+
125133
methods = methods + tuple(exported)
126134

135+
overload_name_mappings = dict(getattr(mod, "__overloads__", {}))
136+
overload_stubs = []
137+
138+
for orig_fn, overload_fns in overloads:
139+
orig_ast = torch.jit.get_jit_def(orig_fn, self_name="ScriptModule")
140+
names = list(map(lambda i: orig_ast.name().name + "__" + str(i), range(len(overload_fns))))
141+
overload_name_mappings[orig_ast.name().name] = names
142+
for overload_fn, name in zip(overload_fns, names):
143+
torch.jit._check_no_signature(overload_fn)
144+
over_ast = torch.jit.get_jit_def(overload_fn, self_name="ScriptModule")
145+
new_ast = torch._C._replace_overloaded_method_decl(over_ast.decl(), orig_ast, name)
146+
_rcb = _jit_internal.createResolutionCallbackFromClosure(orig_fn)
147+
overload_stubs.append(torch.jit.ScriptMethodStub(_rcb, new_ast, overload_fn))
148+
149+
mod.__overloads__ = overload_name_mappings
150+
151+
# we shouldn't directly compile overloaded methods, just its overloads
152+
def ignore_overloaded(method_name):
153+
return method_name not in overload_name_mappings
154+
127155
def make_stub(method):
128156
func = get_function_from_type(type(mod), method)
129157
return torch.jit.script_method(func, _jit_internal.createResolutionCallbackFromClosure(func))
130158

131-
stubs = list(map(make_stub, methods))
132-
return copy_to_script_module(mod, stubs)
159+
filtered_methods = filter(ignore_overloaded, methods)
160+
stubs = list(map(make_stub, filtered_methods))
161+
return copy_to_script_module(mod, overload_stubs + stubs)
133162

134163

135164
def create_method_from_fn(module, fn):

0 commit comments

Comments
 (0)