
Commit 8e3c021

Elias Ellison authored and facebook-github-bot committed
extend torch.jit._overload to module methods (#24259)
Summary: Pull Request resolved: #24259

Follow-up to #23886: add the same overload API specified in PEP 484 to module methods, to reduce the friction of adding method overloads that was brought up in #23266. The usage is:

```python
@torch.jit._overload_method
def add(self, y: int) -> int: ...

@torch.jit._overload_method
def add(self, y: float) -> float: ...

def add(self, y): ...
```

Test Plan: Imported from OSS

Differential Revision: D16921304

Pulled By: eellison

fbshipit-source-id: 784e2f26f7ca9a330a434a603c86b53725c3dc71
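For illustration only (not part of the commit message), here is a minimal runnable sketch of the new module-method overload API, closely following the `test_method_overloading` test added below; the `UseOver` wrapper class name is made up here:

```python
import torch
from torch import Tensor
from typing import Tuple

class Over(torch.nn.Module):
    @torch.jit._overload_method  # noqa: F811
    def forward(self, x):  # noqa: F811
        # type: (Tuple[Tensor, Tensor]) -> Tensor
        pass

    @torch.jit._overload_method  # noqa: F811
    def forward(self, x):  # noqa: F811
        # type: (Tensor) -> Tensor
        pass

    def forward(self, x):  # noqa: F811
        # single implementation; the type comments above select the overload
        if isinstance(x, Tensor):
            return x + 20
        return x[0] + 5

class UseOver(torch.jit.ScriptModule):  # hypothetical wrapper, mirrors class S in the test
    def __init__(self):
        super(UseOver, self).__init__()
        self.inner = Over()

    @torch.jit.script_method
    def forward(self, x):
        # each call site is matched against the declared overload signatures
        return self.inner(x) + self.inner((x, x))

x = torch.ones(1)
print(UseOver()(x))  # (x + 20) + (x[0] + 5) -> tensor([27.])
```

Calling the eager `Over()` instance still just runs the final `forward` definition; the overload signatures only matter when the module is compiled.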
1 parent 4b3ea92 commit 8e3c021

File tree

7 files changed: +274, -50 lines


test/test_jit.py

Lines changed: 136 additions & 26 deletions
```diff
@@ -13274,48 +13274,158 @@ def test_uses():
         self.checkScript(test_uses, ())

-    @unittest.skipIf(True, "Removing weak script")
-    def test_overloading(self):
-        @torch._jit_internal.weak_module
-        class W(torch.nn.Module):
-            __overloads__ = {'forward': ['forward_tuple', 'forward_tensor']}
-
+    def test_method_overloading(self):
+        class Over(torch.nn.Module):
             def __init__(self):
-                super(W, self).__init__()
+                super(Over, self).__init__()

-            @torch._jit_internal.weak_script_method
-            def forward_tuple(self, x):
+            @torch.jit._overload_method  # noqa: F811
+            def forward(self, x):  # noqa: F811
                 # type: (Tuple[Tensor, Tensor]) -> Tensor
-                return x[0] + 5
-
-            def forward(self, x):
-                # manually do argument switching
-                if isinstance(x, tuple):
-                    return self.forward_tuple(x)
-                else:
-                    return self.forward_tensor(x)
+                pass

-            @torch._jit_internal.weak_script_method
-            def forward_tensor(self, x):
+            @torch.jit._overload_method  # noqa: F811
+            def forward(self, x):  # noqa: F811
                 # type: (Tensor) -> Tensor
-                return x + 20
+                pass
+
+            def forward(self, x):  # noqa: F811
+                if isinstance(x, Tensor):
+                    return x + 20
+                else:
+                    return x[0] + 5

         class S(torch.jit.ScriptModule):
             def __init__(self):
                 super(S, self).__init__()
-                self.weak = W()
+                self.weak = Over()

             @torch.jit.script_method
             def forward(self, x):
                 return self.weak(x) + self.weak((x, x))

-        s = S()
+        s_mod = S()
         x = torch.ones(1)
-        self.assertEqual(s(x), x + 20 + 5 + x)
+        self.assertEqual(s_mod(x), x + 20 + 5 + x)
+
+        over = Over()
+        self.assertEqual(over((x, x)), x + 5)
+        self.assertEqual(over((x)), x + 20)
+
+        class Unannotated(torch.nn.Module):
+            def __init__(self):
+                super(Unannotated, self).__init__()
+
+            @torch.jit._overload_method  # noqa: F811
+            def hello(self, x):  # noqa: F811
+                pass
+
+            @torch.jit._overload_method  # noqa: F811
+            def hello(self, x):  # noqa: F811
+                # type: (int) -> (int)
+                pass
+
+            def hello(self, x):  # noqa: F811
+                return x + 3
+
+            def forward(self):
+                return self.hello(1), self.hello(.5)
+
+        w = Unannotated()
+        with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
+            torch.jit.script(w)
+
+        class CompileOverloadError(torch.nn.Module):
+            def __init__(self):
+                super(CompileOverloadError, self).__init__()
+
+            @torch.jit._overload_method  # noqa: F811
+            def hello(self, x):  # noqa: F811
+                # type: (str) -> (int)
+                pass
+
+            @torch.jit._overload_method  # noqa: F811
+            def hello(self, x):  # noqa: F811
+                # type: (int) -> (int)
+                pass
+
+            def hello(self, x):  # noqa: F811
+                return x + 1
+
+            def forward(self):
+                return self.hello("hi"), self.hello(.5)
+
+        w = CompileOverloadError()
+        with self.assertRaisesRegex(Exception, "but instead found type \'str\'"):
+            torch.jit.script(w)
+
+        # testing overload declared first, then non-overload
+        with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
+            class W3(torch.nn.Module):
+                def __init__(self):
+                    super(W3, self).__init__()
+
+                @torch.jit._overload_method  # noqa: F811
+                def forward(self, x):  # noqa: F811
+                    # type: (int) -> int
+                    pass
+
+                @torch.jit._overload_method  # noqa: F811
+                def forward(self, x):  # noqa: F811
+                    # type: (Tensor) -> Tensor
+                    pass
+
+                def forward(self, x):  # noqa: F811
+                    return x + 5
+
+            a = W3()
+            b = torch.jit.script(a)
+
+            class W3(torch.nn.Module):
+                def __init__(self):
+                    super(W3, self).__init__()
+
+                def forward(self, x):  # noqa: F811
+                    return x + 5 + 10
+
+            a = W3()
+            b = torch.jit.script(a)
+
+        # testing non-overload declared first, then overload
+        class W2(torch.nn.Module):
+            def __init__(self):
+                super(W2, self).__init__()
+
+            def hello(self, x1, x2):
+                return x1 + x2
+
+            def forward(self, x):
+                return self.hello(x, x)
+
+        a = torch.jit.script(W2())
+        self.assertEqual(a(torch.tensor(1)), torch.tensor(2))
+
+        class W2(torch.nn.Module):
+            def __init__(self):
+                super(W2, self).__init__()
+
+            @torch.jit._overload_method  # noqa: F811
+            def hello(self, x):  # noqa: F811
+                pass
+
+            @torch.jit._overload_method  # noqa: F811
+            def hello(self, x):  # noqa: F811
+                # type: (int) -> (int)
+                pass
+
+            def hello(self, x):  # noqa: F811
+                return x + 5 + 10
+
+            def forward(self, x):
+                return self.hello(1), self.hello(x)

-        w = W()
-        self.assertEqual(w((x, x)), x + 5)
-        self.assertEqual(w((x)), x + 20)
+        with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
+            a = torch.jit.script(W2())

     def test_select_after_chunk(self):
         def foo(x):
```

torch/_jit_internal.py

Lines changed: 68 additions & 0 deletions
```diff
@@ -285,6 +285,74 @@ def _get_fn_overloads(qual_name):
 def _clear_fn_overloads(qual_name):
     del _overloaded_fns[qual_name]

+def get_class_name_lineno(method):
+    current_frame = inspect.currentframe()
+
+    # one for the get_class_name call, one for _overload_method call
+    for i in range(2):
+        current_frame = current_frame.f_back
+    class_name = current_frame.f_code.co_name
+    line_no = current_frame.f_code.co_firstlineno
+    return class_name, line_no
+
+# At the point the decorator is applied to class methods the method
+# has no reference to its owning class. _qualified_name would not include
+# the class it is defined in, so any methods with the same name in the same file
+# would have the same _qualified_name, even if they were defined in different
+# classes. This problem only exists in python 2.
+# We get around this problem by looking at the stack frame and identifying
+# the class name, and throwing an error whenever overloads are used
+# when modules of the same name are in the same file.
+
+# qualified_name => class name => list[overload_functions]
+_overloaded_methods = {}  # noqa: T484
+
+
+# (qualified_name, class name) => class_fileno
+_overloaded_method_class_fileno = {}
+
+def _overload_method(func):
+    qual_name = _qualified_name(func)
+    global _overloaded_methods
+    class_name_map = _overloaded_methods.get(qual_name, None)
+    if class_name_map is None:
+        class_name_map = {}
+        _overloaded_methods[qual_name] = class_name_map
+
+    class_name, line_no = get_class_name_lineno(func)
+    method_overloads = class_name_map.get(class_name, None)
+    if method_overloads is None:
+        method_overloads = []
+        class_name_map[class_name] = method_overloads
+        _overloaded_method_class_fileno[(qual_name, class_name)] = line_no
+    else:
+        existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)]
+        if existing_lineno != line_no:
+            raise RuntimeError("Cannot currently overload the same method name in two different"
+                               " classes with the same name in the same module")
+
+    method_overloads.append(func)
+    return func
+
+def _get_overloaded_methods(method, mod_class):
+    # TODO: __name__ not set for submodules in recursive script
+    if not hasattr(method, "__name__"):
+        return None
+    qual_name = _qualified_name(method)
+    class_name_map = _overloaded_methods.get(qual_name, None)
+    if class_name_map is None:
+        return None
+    overloads = class_name_map.get(mod_class.__name__, None)
+    if overloads is None:
+        return None
+
+    method_line_no = inspect.getsourcelines(method)[1]
+    mod_class_fileno = inspect.getsourcelines(mod_class)[1]
+    mod_end_fileno = mod_class_fileno + len(inspect.getsourcelines(mod_class)[0])
+    if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
+        raise Exception("Overloads are not useable when a module is redeclared within the same file: " + str(method))
+    return overloads
+
 try:
     import typing
     from typing import Tuple, List, Dict, Optional
```
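The frame-walking trick in `get_class_name_lineno` can be seen in isolation with plain `inspect`. This toy decorator is hypothetical and goes back only one frame (there is no intermediate helper call, unlike `_overload_method`, which goes back two), but it recovers the enclosing class name the same way:

```python
import inspect

def record_enclosing_class(func):
    # When a decorator runs inside a class body, the caller's frame is the
    # class body itself, so its code object's co_name is the class name and
    # co_firstlineno is the line where the class statement starts.
    frame = inspect.currentframe().f_back
    func._class_name = frame.f_code.co_name
    func._class_lineno = frame.f_code.co_firstlineno
    return func

class Example(object):
    @record_enclosing_class
    def method(self):
        pass

print(Example.method._class_name)    # "Example"
print(Example.method._class_lineno)  # line number of "class Example(object):"
```

This is what lets the decorator detect two same-named classes in one file (the Python 2 case described in the comment above) and raise instead of silently merging their overloads.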

torch/csrc/jit/script/init.cpp

Lines changed: 8 additions & 1 deletion
```diff
@@ -743,7 +743,14 @@ void initJitScriptBindings(PyObject* module) {
         auto new_def = implementation_def.withDecl(overload_decl);
         return script_compile_function(name, new_def, defaults, std::move(rcb));
       });
-
+  m.def(
+      "_replace_overloaded_method_decl",
+      [](const Decl& overload_decl,
+         const Def& implementation_def,
+         const std::string& new_name) {
+        checkOverloadDecl(overload_decl, implementation_def.decl());
+        return implementation_def.withDecl(overload_decl).withName(new_name);
+      });
   m.def(
       "_create_function_from_trace",
       [](std::string qualname,
```

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 21 additions & 17 deletions
```diff
@@ -218,25 +218,29 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
   std::vector<NamedValue> new_inputs = inputs.vec();
   new_inputs.insert(new_inputs.begin(), module_);

-  for (const std::string& method_name : method_names_) {
-    auto cls = module_->type()->expect<ClassType>();
-    const auto fn = cls->getMethod(method_name);
-    auto match = tryMatchSchema(
-        fn->getSchema(),
-        loc,
-        *caller.graph().get(),
-        c10::nullopt,
-        new_inputs,
-        attributes,
-        &err,
-        true);
-    if (match) {
-      return MethodValue(module_, method_name)
-          .call(loc, caller, inputs, attributes, n_binders);
+  std::stringstream failure_messages;
+  for (bool allow_conversions : {false, true}) {
+    // clear previous error messages
+    failure_messages.str("");
+    for (const std::string& method_name : method_names_) {
+      auto cls = module_->type()->expect<ClassType>();
+      const auto fn = cls->getMethod(method_name);
+      auto match = tryMatchSchema(
+          fn->getSchema(),
+          loc,
+          *caller.graph().get(),
+          c10::nullopt,
+          new_inputs,
+          attributes,
+          &err,
+          allow_conversions);
+      if (match) {
+        return MethodValue(module_, method_name)
+            .call(loc, caller, inputs, attributes, n_binders);
+      }
     }
   }
-  throw ErrorReport(loc) << "Could not find any matching overloads\n"
-                         << err.str();
+  throw ErrorReport(loc) << failure_messages.str();
 }

 std::shared_ptr<SugaredValue> OverloadedFunctionValue::call(
```
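The change above makes overloaded-method resolution run two passes over the candidate schemas: first without implicit conversions, then allowing them, so an exact match always wins over one that needs conversions. A rough Python analogue of that ordering (illustrative only; `matches` is a stand-in for the compiler's `tryMatchSchema`, not a real API):

```python
def resolve_overload(candidates, args, matches):
    # matches(candidate, args, allow_conversions) -> (ok, error_message)
    failure_messages = []
    for allow_conversions in (False, True):
        failure_messages = []  # like failure_messages.str(""): keep only the last pass's errors
        for candidate in candidates:
            ok, err = matches(candidate, args, allow_conversions)
            if ok:
                return candidate
            failure_messages.append(err)
    raise RuntimeError("\n".join(failure_messages))
```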

torch/csrc/jit/script/python_tree_views.cpp

Lines changed: 2 additions & 1 deletion
```diff
@@ -146,7 +146,8 @@ void initTreeViewBindings(PyObject* module) {
         const auto& r = name.range();
         return Def::create(r, name, decl, wrap_list(r, std::move(body)));
       }))
-      .def("decl", [](const Def& def) { return def.decl(); });
+      .def("decl", [](const Def& def) { return def.decl(); })
+      .def("name", [](const Def& def) { return def.name(); });
   py::class_<ClassDef, TreeView>(m, "ClassDef")
       .def(py::init([](const Ident& name, std::vector<Stmt> body) {
         const auto& r = name.range();
```

torch/jit/__init__.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -31,7 +31,7 @@
 from collections import OrderedDict, namedtuple

 # These are imported so users can access them from the `torch.jit` module
-from torch._jit_internal import Final, _overload  # noqa: F401
+from torch._jit_internal import Final, _overload, _overload_method  # noqa: F401
 from torch._jit_internal import ignore, export  # noqa: F401

 if sys.version_info[0] > 2:
@@ -1876,10 +1876,14 @@ def _compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_
     fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb, overload_defaults)
     return fn

-def _get_overload_decl_and_defaults(func):
+def _check_no_signature(func):
     signature = torch.jit.annotations.get_signature(func)
     if signature is None:
-        raise RuntimeError("Must explicitly add type annotations to overloaded functions: {obj}").format(func)
+        qual_name = _qualified_name(func)
+        raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
+
+def _get_overload_decl_and_defaults(func):
+    _check_no_signature(func)
     return (torch.jit.get_jit_def(func).decl(), get_default_args(func))

 def _get_overloads(obj):
```
