Skip to content

Commit cb6faf7

Browse files
author
root
committed
Update on "extend torch.jit._overload to module methods"
Follow up to #23886, add the same overload api specified in PEP 484 to module methods to reduce the friction of adding method overloads that was brought up in #23266. The usage is: ``` @torch.jit.overload def add(self, y: int) -> int: ... @torch.jit.overload def add(self, y: float) -> float: ... def add(): ... ```
1 parent e58f6c6 commit cb6faf7

File tree

1 file changed

+21
-17
lines changed

1 file changed

+21
-17
lines changed

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -218,25 +218,29 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
218218
std::vector<NamedValue> new_inputs = inputs.vec();
219219
new_inputs.insert(new_inputs.begin(), module_);
220220

221-
for (const std::string& method_name : method_names_) {
222-
auto cls = module_->type()->expect<ClassType>();
223-
const auto fn = cls->getMethod(method_name);
224-
auto match = tryMatchSchema(
225-
fn->getSchema(),
226-
loc,
227-
*caller.graph().get(),
228-
c10::nullopt,
229-
new_inputs,
230-
attributes,
231-
&err,
232-
true);
233-
if (match) {
234-
return MethodValue(module_, method_name)
235-
.call(loc, caller, inputs, attributes, n_binders);
221+
std::stringstream failure_messages;
222+
for (bool allow_conversions : {false, true}) {
223+
// clear previous error messages
224+
failure_messages.str("");
225+
for (const std::string& method_name : method_names_) {
226+
auto cls = module_->type()->expect<ClassType>();
227+
const auto fn = cls->getMethod(method_name);
228+
auto match = tryMatchSchema(
229+
fn->getSchema(),
230+
loc,
231+
*caller.graph().get(),
232+
c10::nullopt,
233+
new_inputs,
234+
attributes,
235+
&err,
236+
allow_conversions);
237+
if (match) {
238+
return MethodValue(module_, method_name)
239+
.call(loc, caller, inputs, attributes, n_binders);
240+
}
236241
}
237242
}
238-
throw ErrorReport(loc) << "Could not find any matching overloads\n"
239-
<< err.str();
243+
throw ErrorReport(loc) << failure_messages.str();
240244
}
241245

242246
std::shared_ptr<SugaredValue> OverloadedFunctionValue::call(

0 commit comments

Comments
 (0)