Skip to content

Commit 74c864f

Browse files
author
root
committed
extend torch.jit._overload to modules
ghstack-source-id: c4fbf4b Pull Request resolved: #24259
1 parent 1daac9c commit 74c864f

File tree

8 files changed

+279
-53
lines changed

8 files changed

+279
-53
lines changed

test/test_jit.py

Lines changed: 136 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13166,48 +13166,158 @@ def test_uses():
1316613166

1316713167
self.checkScript(test_uses, ())
1316813168

13169-
@unittest.skipIf(True, "Removing weak script")
13170-
def test_overloading(self):
13171-
@torch._jit_internal.weak_module
13172-
class W(torch.nn.Module):
13173-
__overloads__ = {'forward': ['forward_tuple', 'forward_tensor']}
13174-
13169+
def test_method_overloading(self):
13170+
class Over(torch.nn.Module):
1317513171
def __init__(self):
13176-
super(W, self).__init__()
13172+
super(Over, self).__init__()
1317713173

13178-
@torch._jit_internal.weak_script_method
13179-
def forward_tuple(self, x):
13174+
@torch.jit._overload_method # noqa: F811
13175+
def forward(self, x): # noqa: F811
1318013176
# type: (Tuple[Tensor, Tensor]) -> Tensor
13181-
return x[0] + 5
13182-
13183-
def forward(self, x):
13184-
# manually do argument switching
13185-
if isinstance(x, tuple):
13186-
return self.forward_tuple(x)
13187-
else:
13188-
return self.forward_tensor(x)
13177+
pass
1318913178

13190-
@torch._jit_internal.weak_script_method
13191-
def forward_tensor(self, x):
13179+
@torch.jit._overload_method # noqa: F811
13180+
def forward(self, x): # noqa: F811
1319213181
# type: (Tensor) -> Tensor
13193-
return x + 20
13182+
pass
13183+
13184+
def forward(self, x): # noqa: F811
13185+
if isinstance(x, Tensor):
13186+
return x + 20
13187+
else:
13188+
return x[0] + 5
1319413189

1319513190
class S(torch.jit.ScriptModule):
1319613191
def __init__(self):
1319713192
super(S, self).__init__()
13198-
self.weak = W()
13193+
self.weak = Over()
1319913194

1320013195
@torch.jit.script_method
1320113196
def forward(self, x):
1320213197
return self.weak(x) + self.weak((x, x))
1320313198

13204-
s = S()
13199+
s_mod = S()
1320513200
x = torch.ones(1)
13206-
self.assertEqual(s(x), x + 20 + 5 + x)
13201+
self.assertEqual(s_mod(x), x + 20 + 5 + x)
13202+
13203+
over = Over()
13204+
self.assertEqual(over((x, x)), x + 5)
13205+
self.assertEqual(over((x)), x + 20)
13206+
13207+
class Unannotated(torch.nn.Module):
13208+
def __init__(self):
13209+
super(Unannotated, self).__init__()
13210+
13211+
@torch.jit._overload_method # noqa: F811
13212+
def hello(self, x): # noqa: F811
13213+
pass
13214+
13215+
@torch.jit._overload_method # noqa: F811
13216+
def hello(self, x): # noqa: F811
13217+
# type: (int) -> (int)
13218+
pass
13219+
13220+
def hello(self, x): # noqa: F811
13221+
return x + 3
13222+
13223+
def forward(self):
13224+
return self.hello(1), self.hello(.5)
13225+
13226+
w = Unannotated()
13227+
with self.assertRaisesRegex(Exception, "explicitly add type annotations to overloaded functions"):
13228+
torch.jit.script(w)
13229+
13230+
class CompileOverloadError(torch.nn.Module):
13231+
def __init__(self):
13232+
super(CompileOverloadError, self).__init__()
13233+
13234+
@torch.jit._overload_method # noqa: F811
13235+
def hello(self, x): # noqa: F811
13236+
# type: (str) -> (int)
13237+
pass
13238+
13239+
@torch.jit._overload_method # noqa: F811
13240+
def hello(self, x): # noqa: F811
13241+
# type: (int) -> (int)
13242+
pass
13243+
13244+
def hello(self, x): # noqa: F811
13245+
return x + 1
13246+
13247+
def forward(self):
13248+
return self.hello("hi"), self.hello(.5)
13249+
13250+
w = CompileOverloadError()
13251+
with self.assertRaisesRegex(Exception, "but instead found type \'str\'"):
13252+
torch.jit.script(w)
13253+
13254+
# testing overload declared first, then non-overload
13255+
with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
13256+
class W3(torch.nn.Module):
13257+
def __init__(self):
13258+
super(W3, self).__init__()
13259+
13260+
@torch.jit._overload_method # noqa: F811
13261+
def forward(self, x): # noqa: F811
13262+
# type: (int) -> int
13263+
pass
13264+
13265+
@torch.jit._overload_method # noqa: F811
13266+
def forward(self, x): # noqa: F811
13267+
# type: (Tensor) -> Tensor
13268+
pass
13269+
13270+
def forward(self, x): # noqa: F811
13271+
return x + 5
13272+
13273+
a = W3()
13274+
b = torch.jit.script(a)
13275+
13276+
class W3(torch.nn.Module):
13277+
def __init__(self):
13278+
super(W3, self).__init__()
13279+
13280+
def forward(self, x): # noqa: F811
13281+
return x + 5 + 10
13282+
13283+
a = W3()
13284+
b = torch.jit.script(a)
13285+
13286+
# testing non-overload declared first, then overload
13287+
class W2(torch.nn.Module):
13288+
def __init__(self):
13289+
super(W2, self).__init__()
13290+
13291+
def hello(self, x1, x2):
13292+
return x1 + x2
13293+
13294+
def forward(self, x):
13295+
return self.hello(x, x)
13296+
13297+
a = torch.jit.script(W2())
13298+
self.assertEqual(a(torch.tensor(1)), torch.tensor(2))
13299+
13300+
class W2(torch.nn.Module):
13301+
def __init__(self):
13302+
super(W2, self).__init__()
13303+
13304+
@torch.jit._overload_method # noqa: F811
13305+
def hello(self, x): # noqa: F811
13306+
pass
13307+
13308+
@torch.jit._overload_method # noqa: F811
13309+
def hello(self, x): # noqa: F811
13310+
# type: (int) -> (int)
13311+
pass
13312+
13313+
def hello(self, x): # noqa: F811
13314+
return x + 5 + 10
13315+
13316+
def forward(self, x):
13317+
return self.hello(1), self.hello(x)
1320713318

13208-
w = W()
13209-
self.assertEqual(w((x, x)), x + 5)
13210-
self.assertEqual(w((x)), x + 20)
13319+
with self.assertRaisesRegex(Exception, "Overloads are not useable when a module"):
13320+
a = torch.jit.script(W2())
1321113321

1321213322
def test_select_after_chunk(self):
1321313323
def foo(x):

torch/_jit_internal.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,74 @@ def _get_fn_overloads(qual_name):
276276
def _clear_fn_overloads(qual_name):
277277
del _overloaded_fns[qual_name]
278278

279+
def get_class_name_lineno(method):
280+
current_frame = inspect.currentframe()
281+
282+
# one for the get_class_name call, one for _overload_method call
283+
for i in range(2):
284+
current_frame = current_frame.f_back
285+
class_name = current_frame.f_code.co_name
286+
line_no = current_frame.f_code.co_firstlineno
287+
return class_name, line_no
288+
289+
# At the the point the decorator is applied to class methods the method
290+
# has no reference to its owning class. _qualified_name would not include
291+
# the class it is defined in, so any methods with the same name in the same file
292+
# would have the same _qualified_name, even if they were defined in different
293+
# classes. This problem only exists in python 2.
294+
# We get around this problem by looking at the stack frame and identifying
295+
# the class name, and throwing an error whenever overloads are used
296+
# when modules of the same name are in the same file
297+
298+
# qualified_name => class name => list[overload_functions]
299+
_overloaded_methods = {} # noqa: T484
300+
301+
302+
# (qualified_name, class name) => class_fileno
303+
_overloaded_method_class_fileno = {}
304+
305+
def _overload_method(func):
306+
qual_name = _qualified_name(func)
307+
global _overloaded_methods
308+
class_name_map = _overloaded_methods.get(qual_name, None)
309+
if class_name_map is None:
310+
class_name_map = {}
311+
_overloaded_methods[qual_name] = class_name_map
312+
313+
class_name, line_no = get_class_name_lineno(func)
314+
method_overloads = class_name_map.get(class_name, None)
315+
if method_overloads is None:
316+
method_overloads = []
317+
class_name_map[class_name] = method_overloads
318+
_overloaded_method_class_fileno[(qual_name, class_name)] = line_no
319+
else:
320+
existing_lineno = _overloaded_method_class_fileno[(qual_name, class_name)]
321+
if existing_lineno != line_no:
322+
raise RuntimeError("Cannot currently overload the same method name in two different"
323+
" classes with the same name in the same module")
324+
325+
method_overloads.append(func)
326+
return func
327+
328+
def _get_overloaded_methods(method, mod_class):
329+
# TODO: __name__ not set for submodules in recursive script
330+
if not hasattr(method, "__name__"):
331+
return None
332+
qual_name = _qualified_name(method)
333+
class_name_map = _overloaded_methods.get(qual_name, None)
334+
if class_name_map is None:
335+
return None
336+
overloads = class_name_map.get(mod_class.__name__, None)
337+
if overloads is None:
338+
return None
339+
340+
method_line_no = inspect.getsourcelines(method)[1]
341+
mod_class_fileno = inspect.getsourcelines(mod_class)[1]
342+
mod_end_fileno = mod_class_fileno + len(inspect.getsourcelines(mod_class)[0])
343+
if not (method_line_no >= mod_class_fileno and method_line_no <= mod_end_fileno):
344+
raise Exception("Overloads are not useable when a module is redaclared within the same file: " + str(method))
345+
return overloads
346+
279347
try:
280348
import typing
281349
from typing import Tuple, List, Dict, Optional

torch/csrc/jit/script/class_type.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ ClassType::ClassType(
7777
c10::optional<QualifiedName> name,
7878
std::weak_ptr<CompilationUnit> cu,
7979
bool is_module)
80-
: NamedType(TypeKind::ClassType, name), compilation_unit_(std::move(cu)) {
80+
: NamedType(TypeKind::ClassType, std::move(name)),
81+
compilation_unit_(std::move(cu)) {
8182
if (is_module) {
8283
parameterSlots_ = std::make_shared<std::vector<bool>>();
8384
}

torch/csrc/jit/script/init.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,14 @@ void initJitScriptBindings(PyObject* module) {
750750
auto new_def = implementation_def.withDecl(overload_decl);
751751
return script_compile_function(name, new_def, defaults, std::move(rcb));
752752
});
753-
753+
m.def(
754+
"_replace_overloaded_method_decl",
755+
[](const Decl& overload_decl,
756+
const Def& implementation_def,
757+
const std::string& new_name) {
758+
checkOverloadDecl(overload_decl, implementation_def.decl());
759+
return implementation_def.withDecl(overload_decl).withName(new_name);
760+
});
754761
m.def(
755762
"_create_function_from_trace",
756763
[](std::string qualname,
@@ -846,7 +853,7 @@ void initJitScriptBindings(PyObject* module) {
846853
const std::vector<at::Tensor>& constant_table) {
847854
import_functions(
848855
c10::nullopt,
849-
cu,
856+
std::move(cu),
850857
std::make_shared<Source>(src),
851858
constant_table,
852859
nullptr,

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -218,25 +218,29 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
218218
std::vector<NamedValue> new_inputs = inputs.vec();
219219
new_inputs.insert(new_inputs.begin(), module_);
220220

221-
for (const std::string& method_name : method_names_) {
222-
auto cls = module_->type()->expect<ClassType>();
223-
const auto fn = cls->getMethod(method_name);
224-
auto match = tryMatchSchema(
225-
fn->getSchema(),
226-
loc,
227-
*caller.graph().get(),
228-
c10::nullopt,
229-
new_inputs,
230-
attributes,
231-
&err,
232-
true);
233-
if (match) {
234-
return MethodValue(module_, method_name)
235-
.call(loc, caller, inputs, attributes, n_binders);
221+
std::stringstream failure_messages;
222+
for (bool allow_conversions : {false, true}) {
223+
// clear previous error messages
224+
failure_messages.str("");
225+
for (const std::string& method_name : method_names_) {
226+
auto cls = module_->type()->expect<ClassType>();
227+
const auto fn = cls->getMethod(method_name);
228+
auto match = tryMatchSchema(
229+
fn->getSchema(),
230+
loc,
231+
*caller.graph().get(),
232+
c10::nullopt,
233+
new_inputs,
234+
attributes,
235+
&err,
236+
allow_conversions);
237+
if (match) {
238+
return MethodValue(module_, method_name)
239+
.call(loc, caller, inputs, attributes, n_binders);
240+
}
236241
}
237242
}
238-
throw ErrorReport(loc) << "Could not find any matching overloads\n"
239-
<< err.str();
243+
throw ErrorReport(loc) << failure_messages.str();
240244
}
241245

242246
std::shared_ptr<SugaredValue> OverloadedFunctionValue::call(

torch/csrc/jit/script/python_tree_views.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ void initTreeViewBindings(PyObject* module) {
150150
const auto& r = name.range();
151151
return Def::create(r, name, decl, wrap_list(r, std::move(body)));
152152
}))
153-
.def("decl", [](const Def& def) { return def.decl(); });
153+
.def("decl", [](const Def& def) { return def.decl(); })
154+
.def("name", [](const Def& def) { return def.name(); });
154155
py::class_<ClassDef, TreeView>(m, "ClassDef")
155156
.def(py::init([](const Ident& name, std::vector<Stmt> body) {
156157
const auto& r = name.range();

torch/jit/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from collections import OrderedDict, namedtuple
3232

3333
# These are imported so users can access them from the `torch.jit` module
34-
from torch._jit_internal import Final, _overload # noqa: F401
34+
from torch._jit_internal import Final, _overload, _overload_method # noqa: F401
3535
from torch._jit_internal import ignore, export # noqa: F401
3636

3737
if sys.version_info[0] > 2:
@@ -1443,7 +1443,8 @@ def weighted_kernel_sum(self, weight):
14431443
14441444
.. note::
14451445
1446-
* The first three trace/trace_module calls are equivalent and return ``ScriptModule`` with a single ``forward`` method.
1446+
* The first three trace/trace_module calls are equivalent and return `
1447+
`ScriptModule`` with a single ``forward`` method.
14471448
* The last ``trace_module`` call produces a ``ScriptModule`` with two methods.
14481449
14491450
Tracing only records operations done when the given function is run on the given
@@ -1903,10 +1904,14 @@ def _compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_
19031904
fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb, overload_defaults)
19041905
return fn
19051906

1906-
def _get_overload_decl_and_defaults(func):
1907+
def _check_no_signature(func):
19071908
signature = torch.jit.annotations.get_signature(func)
19081909
if signature is None:
1909-
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {obj}").format(func)
1910+
qual_name = _qualified_name(func)
1911+
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
1912+
1913+
def _get_overload_decl_and_defaults(func):
1914+
_check_no_signature(func)
19101915
return (torch.jit.get_jit_def(func).decl(), get_default_args(func))
19111916

19121917
def _get_overloads(obj):

0 commit comments

Comments
 (0)