Skip to content

Commit 142bd59

Browse files
author
root
committed
Update on "[JIT] add support for overloading functions"
This is a series of PRs that will allow us to support adding [padding to conv](#22484) and also reduce the friction of adding method overloads that was brought up in #23266. Support for overloaded functions following the specification in [PEP 484](https://www.python.org/dev/peps/pep-0484/#function-method-overloading). The usage is: ``` @torch.jit.overload def add(x: int, y: int) -> int: ... @torch.jit.overload def add(x: float, y: float) -> float: ... def add: return x + y ``` Follow up PRs: - Add same API for methods - A couple of cleanups for functions: - don't require default params specified on the overload as well - potentially error if invocation could be matched to multiple overloads. now it just chooses the first one, mypy does the same thing currently Differential Revision: [D16694863](https://our.internmc.facebook.com/intern/diff/D16694863)
1 parent d5ea0a7 commit 142bd59

File tree

4 files changed

+47
-51
lines changed

4 files changed

+47
-51
lines changed

test/test_jit.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12837,12 +12837,12 @@ def test_function_overloads(self):
1283712837
# decorators. This is fixed on master but not on version 2.1.1.
1283812838
# Next version update remove noqa and add @typing.overload annotation
1283912839

12840-
@torch.jit.overload # noqa: F811
12840+
@torch.jit._overload # noqa: F811
1284112841
def test_simple(x1): # noqa: F811
1284212842
# type: (int) -> int
1284312843
pass
1284412844

12845-
@torch.jit.overload # noqa: F811
12845+
@torch.jit._overload # noqa: F811
1284612846
def test_simple(x1): # noqa: F811
1284712847
# type: (float) -> float
1284812848
pass
@@ -12856,20 +12856,20 @@ def invoke_function():
1285612856
self.checkScript(invoke_function, ())
1285712857

1285812858
# testing that the functions are cached
12859-
compiled_fns_1 = torch.jit.get_overloads(test_simple)
12860-
compiled_fns_2 = torch.jit.get_overloads(test_simple)
12859+
compiled_fns_1 = torch.jit._get_overloads(test_simple)
12860+
compiled_fns_2 = torch.jit._get_overloads(test_simple)
1286112861
for a, b in zip(compiled_fns_1, compiled_fns_2):
1286212862
self.assertIs(a, b)
1286312863

1286412864
# currently we take the default values have to be specified in the
1286512865
# overload as well - TODO take them from implementation and apply
1286612866
# where the type is valid.
12867-
@torch.jit.overload # noqa: F811
12867+
@torch.jit._overload # noqa: F811
1286812868
def identity(x1): # noqa: F811
1286912869
# type: (str) -> str
1287012870
pass
1287112871

12872-
@torch.jit.overload # noqa: F811
12872+
@torch.jit._overload # noqa: F811
1287312873
def identity(x1=1.0): # noqa: F811
1287412874
# type: (float) -> float
1287512875
pass
@@ -12890,19 +12890,18 @@ def schema_match_failure():
1289012890
torch.jit.script(schema_match_failure)
1289112891
except Exception as e:
1289212892
thrown = True
12893-
e_msg = str(e)
12894-
self.assertTrue(r"of type 'str'" in e_msg and r"of type 'float" in e_msg)
12893+
self.assertTrue(r"of type 'str'" in str(e) and r"of type 'float" in str(e))
1289512894
self.assertTrue(thrown)
1289612895

1289712896
with self.assertRaisesRegex(Exception, "cannot be directly compiled"):
1289812897
torch.jit.script(identity)
1289912898

12900-
@torch.jit.overload # noqa: F811
12899+
@torch.jit._overload # noqa: F811
1290112900
def impl_compile_failure(x, y): # noqa: F811
1290212901
# type: (str, str) -> (str)
1290312902
pass
1290412903

12905-
@torch.jit.overload # noqa: F811
12904+
@torch.jit._overload # noqa: F811
1290612905
def impl_compile_failure(x, y): # noqa: F811
1290712906
# type: (int, int) -> (int)
1290812907
pass
@@ -12918,19 +12917,22 @@ def test():
1291812917
torch.jit.script(test)
1291912918

1292012919
def test_function_overloading_isinstance(self):
12921-
@torch.jit.overload # noqa: F811
12920+
@torch.jit._overload # noqa: F811
1292212921
def my_conv(x, y): # noqa: F811
1292312922
# type: (float, str) -> (float)
1292412923
pass
1292512924

12926-
@torch.jit.overload # noqa: F811
12925+
@torch.jit._overload # noqa: F811
1292712926
def my_conv(x, y=2.0): # noqa: F811
1292812927
# type: (float, float) -> (float)
1292912928
pass
1293012929

1293112930
def my_conv(x, y=2.0): # noqa: F811
1293212931
if isinstance(y, str):
12933-
return 4.0 - x
12932+
if y == "hi":
12933+
return 4.0 - x
12934+
else:
12935+
return 5.0 - x
1293412936
else:
1293512937
return 2.0 + x
1293612938

torch/csrc/jit/script/init.cpp

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ void checkOverloadDecl(const Decl& new_decl, const Decl& old_decl) {
178178
const auto& new_params = new_decl.params();
179179
const auto& old_params = old_decl.params();
180180

181+
// TODO. same number of parameters not strictly necessary.
181182
TORCH_INTERNAL_ASSERT(
182183
new_params.size() == old_params.size(),
183184
"Overload must have same number of parameters\n",
@@ -233,6 +234,27 @@ FunctionSchema getSchemaWithNameAndDefaults(
233234
schema.is_varret());
234235
}
235236

237+
static StrongFunctionPtr script_compile_function(
238+
const c10::QualifiedName& name,
239+
const Def& def,
240+
const FunctionDefaults& defaults,
241+
ResolutionCallback rcb) {
242+
auto cu = get_python_cu();
243+
auto defined_functions = cu->define(
244+
QualifiedName(name.prefix()),
245+
{def},
246+
{pythonResolver(std::move(rcb))},
247+
nullptr,
248+
true);
249+
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
250+
auto& defined = defined_functions[0];
251+
defined->setSchema(getSchemaWithNameAndDefaults(
252+
def.range(), defined->getSchema(), def.name().name(), defaults));
253+
StrongFunctionPtr ret(std::move(cu), defined);
254+
didFinishEmitFunction(ret);
255+
return ret;
256+
}
257+
236258
struct VISIBILITY_HIDDEN ModuleSelf : public Self {
237259
ModuleSelf(const Module& m, py::object& py_m)
238260
: Self(), module_(m), pyModule_(py_m) {}
@@ -705,20 +727,7 @@ void initJitScriptBindings(PyObject* module) {
705727
C10_LOG_API_USAGE_ONCE("torch.script.compile");
706728
const auto name = c10::QualifiedName(qualname);
707729
TORCH_INTERNAL_ASSERT(name.name() == def.name().name());
708-
auto cu = get_python_cu();
709-
auto defined_functions = cu->define(
710-
QualifiedName(name.prefix()),
711-
{def},
712-
{pythonResolver(std::move(rcb))},
713-
nullptr,
714-
true);
715-
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
716-
auto& defined = defined_functions[0];
717-
defined->setSchema(getSchemaWithNameAndDefaults(
718-
def.range(), defined->getSchema(), def.name().name(), defaults));
719-
StrongFunctionPtr ret(std::move(cu), defined);
720-
didFinishEmitFunction(ret);
721-
return ret;
730+
return script_compile_function(name, def, defaults, std::move(rcb));
722731
});
723732
m.def(
724733
"_jit_script_compile_overload",
@@ -728,25 +737,9 @@ void initJitScriptBindings(PyObject* module) {
728737
ResolutionCallback rcb,
729738
const FunctionDefaults& defaults) {
730739
const auto name = c10::QualifiedName(qualname);
731-
auto cu = get_python_cu();
732740
checkOverloadDecl(overload_decl, implementation_def.decl());
733741
auto new_def = implementation_def.withDecl(overload_decl);
734-
auto defined_functions = cu->define(
735-
QualifiedName(name.prefix()),
736-
{new_def},
737-
{pythonResolver(std::move(rcb))},
738-
nullptr,
739-
true);
740-
TORCH_INTERNAL_ASSERT(defined_functions.size() == 1);
741-
auto& defined = defined_functions[0];
742-
defined->setSchema(getSchemaWithNameAndDefaults(
743-
new_def.range(),
744-
defined->getSchema(),
745-
new_def.name().name(),
746-
defaults));
747-
StrongFunctionPtr ret(std::move(cu), defined);
748-
didFinishEmitFunction(ret);
749-
return ret;
742+
return script_compile_function(name, new_def, defaults, std::move(rcb));
750743
});
751744

752745
m.def(

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
573573

574574
py::bool_ isFunction = py::module::import("inspect").attr("isfunction")(obj);
575575
if (py::cast<bool>(isFunction)) {
576-
auto overloads = py::module::import("torch.jit").attr("get_overloads")(obj);
576+
auto overloads =
577+
py::module::import("torch.jit").attr("_get_overloads")(obj);
577578
if (!overloads.is_none()) {
578579
auto compiled_fns = py::cast<std::vector<StrongFunctionPtr>>(overloads);
579580
return std::make_shared<OverloadedFunctionValue>(std::move(compiled_fns));

torch/jit/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,7 +1154,7 @@ def forward(self, input):
11541154
_compile_and_register_class(obj, _rcb, qualified_name)
11551155
return obj
11561156
else:
1157-
check_directly_compile_overloaded(obj)
1157+
_check_directly_compile_overloaded(obj)
11581158
ast = get_jit_def(obj)
11591159
if _rcb is None:
11601160
_rcb = _gen_rcb(obj, _frames_up)
@@ -2091,7 +2091,7 @@ def _get_script_class(name):
20912091
# qualified name => list[compiled fns]
20922092
_compiled_overloaded_fns = {}
20932093

2094-
def overload(func):
2094+
def _overload(func):
20952095
qual_name = _qualified_name(func)
20962096
global _overloaded_fns
20972097
fn_overload_list = _overloaded_fns.get(qual_name)
@@ -2104,14 +2104,14 @@ def overload(func):
21042104
fn_overload_list.append((torch.jit.get_jit_def(func).decl(), get_default_args(func)))
21052105
return func
21062106

2107-
def compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_defaults):
2107+
def _compile_function_with_overload(qual_name, impl_fn, overload_decl, overload_defaults):
21082108
impl_ast = torch.jit.get_jit_def(impl_fn)
21092109
_frames_up = 0
21102110
_rcb = _gen_rcb(impl_fn, _frames_up)
21112111
fn = torch._C._jit_script_compile_overload(qual_name, overload_decl, impl_ast, _rcb, overload_defaults)
21122112
return fn
21132113

2114-
def get_overloads(obj):
2114+
def _get_overloads(obj):
21152115
# check for cached compiled fns
21162116
qual_name = _qualified_name(obj)
21172117
global _compiled_overloaded_fns
@@ -2130,15 +2130,15 @@ def get_overloads(obj):
21302130
# incompatible with a type of parameter in an overload, and other validation.
21312131
# This is still an internal api so for now use defaults from overload
21322132
for overload_decl, overload_defaults in overloads:
2133-
compiled_fn = compile_function_with_overload(qual_name, obj, overload_decl, overload_defaults)
2133+
compiled_fn = _compile_function_with_overload(qual_name, obj, overload_decl, overload_defaults)
21342134
compiled_fns.append(compiled_fn)
21352135

21362136
# cache compilation, remove information stored to do compilation
21372137
_compiled_overloaded_fns[qual_name] = compiled_fns
21382138
del _overloaded_fns[qual_name]
21392139
return compiled_fns
21402140

2141-
def check_directly_compile_overloaded(obj):
2141+
def _check_directly_compile_overloaded(obj):
21422142
qual_name = _qualified_name(obj)
21432143
global _compiled_overloaded_fns
21442144
global _overloaded_fns

0 commit comments

Comments
 (0)