Skip to content

Commit 98df96d

Browse files
author
root
committed
Update on "[JIT] add support for overloading functions"
This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266. Support for overloaded functions following the specification in [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading). The usage is: ``` @torch.jit.overload def add(x: int, y: int) -> int: ... @torch.jit.overload def add(x: int, y: int) -> int: ... def add: return x + y ``` Follow up PRs: - Add same API for methods - A couple of cleanups for functions: - don't require default params specified on the overload as well - potentially error if invocation could be matched to multiple overloads. now it just chooses the first one, mypy does the same thing currently
1 parent 397328b commit 98df96d

File tree

3 files changed

+50
-29
lines changed

3 files changed

+50
-29
lines changed

test/test_jit.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12914,8 +12914,30 @@ def test():
1291412914
impl_compile_failure("one", "two")
1291512915

1291612916

12917-
with self.assertRaisesRegex(Exception, r"# type: (str, str) -> (str"):
12918-
torch.jit.script(impl_compile_failure)
12917+
with self.assertRaisesRegex(Exception, "Arguments for call are not valid"):
12918+
torch.jit.script(test)
12919+
12920+
def test_function_overloading_isinstance(self):
12921+
@torch.jit.overload # noqa: F811
12922+
def my_conv(x, y): # noqa: F811
12923+
# type: (float, str) -> (float)
12924+
pass
12925+
12926+
@torch.jit.overload # noqa: F811
12927+
def my_conv(x, y=2.0): # noqa: F811
12928+
# type: (float, float) -> (float)
12929+
pass
12930+
12931+
def my_conv(x, y=2.0): # noqa: F811
12932+
if isinstance(y, str):
12933+
return 4.0 - x
12934+
else:
12935+
return 2.0 + x
12936+
12937+
def test_uses():
12938+
return my_conv(1.5), my_conv(1.5, "hi"), my_conv(1.5, 5.0)
12939+
12940+
self.checkScript(test_uses, ())
1291912941

1292012942
@unittest.skipIf(True, "Removing weak script")
1292112943
def test_overloading(self):

torch/csrc/jit/script/init.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,13 @@ void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
180180

181181
TORCH_INTERNAL_ASSERT(
182182
new_params.size() == old_params.size(),
183-
"Overload must have same number of parameters",
183+
"Overload must have same number of parameters\n",
184184
new_decl.range(),
185185
old_decl.range());
186186
for (size_t i = 0; i < new_decl.params().size(); ++i) {
187187
TORCH_INTERNAL_ASSERT(
188188
new_params[i].ident().name() == old_params[i].ident().name(),
189-
"Overload parameters must have the same names",
189+
"Overload parameters must have the same names\n",
190190
new_params[i].ident(),
191191
old_params[i].ident());
192192
}

torch/jit/__init__.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,22 +1157,28 @@ def forward(self, input):
11571157
check_directly_compile_overloaded(obj)
11581158
ast = get_jit_def(obj)
11591159
if _rcb is None:
1160-
closure_rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
1161-
stack_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
1162-
1163-
def _rcb(name):
1164-
# since type comments aren't captured in the function's closures,
1165-
# we still need to try to the rcb based on stack frames if the
1166-
# closure rcb fails
1167-
result = closure_rcb(name)
1168-
if result:
1169-
return result
1170-
return stack_rcb(name)
1160+
_rcb = _gen_rcb(obj, _frames_up)
11711161
fn = torch._C._jit_script_compile(qualified_name, ast, _rcb, get_default_args(obj))
11721162
# Forward docstrings
11731163
fn.__doc__ = obj.__doc__
11741164
return fn
11751165

1166+
def _gen_rcb(obj, _frames_up):
1167+
_frames_up = _frames_up + 1 # for invoking _gen_rcb()
1168+
1169+
closure_rcb = _jit_internal.createResolutionCallbackFromClosure(obj)
1170+
stack_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
1171+
1172+
def _rcb(name):
1173+
# since type comments aren't captured in the function's closures,
1174+
# we still need to try to the rcb based on stack frames if the
1175+
# closure rcb fails
1176+
result = closure_rcb(name)
1177+
if result:
1178+
return result
1179+
return stack_rcb(name)
1180+
1181+
return _rcb
11761182

11771183
ScriptMethodStub = namedtuple('ScriptMethodStub', ('resolution_callback', 'def_', 'original_method'))
11781184

@@ -2088,14 +2094,14 @@ def _get_script_class(name):
20882094
def overload(func):
20892095
qual_name = _qualified_name(func)
20902096
global _overloaded_fns
2091-
li = _overloaded_fns.get(qual_name)
2092-
if li is None:
2093-
li = []
2094-
_overloaded_fns[qual_name] = li
2097+
fn_overload_list = _overloaded_fns.get(qual_name)
2098+
if fn_overload_list is None:
2099+
fn_overload_list = []
2100+
_overloaded_fns[qual_name] = fn_overload_list
20952101
signature = torch.jit.annotations.get_signature(func)
20962102
if signature is None:
20972103
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {obj}").format(func)
2098-
li.append((torch.jit.get_jit_def(func).decl(), get_default_args(func)))
2104+
fn_overload_list.append((torch.jit.get_jit_def(func).decl(), get_default_args(func)))
20992105
return func
21002106

21012107
def compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_defaults):
@@ -2105,15 +2111,8 @@ def compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_d
21052111
# and refactor with above usage
21062112
closure_rcb = _jit_internal.createResolutionCallbackFromClosure(impl_fn)
21072113
stack_rcb = _jit_internal.createResolutionCallback(_frames_up + 1)
2108-
2109-
def _rcb(name):
2110-
# since type comments aren't captured in the function's closures,
2111-
# we still need to try to the rcb based on stack frames if the
2112-
# closure rcb fails
2113-
result = closure_rcb(name)
2114-
if result:
2115-
return result
2116-
return stack_rcb(name)
2114+
_frames_up = 0
2115+
_rcb = _gen_rcb(impl_fn, _frames_up)
21172116
fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb, overload_defaults)
21182117
return fn
21192118

0 commit comments

Comments
 (0)