Skip to content

Commit 5dacf6b

Browse files
suofacebook-github-bot
authored andcommitted
improve error message on inferred type (#21058)
Summary: Pull Request resolved: #21058 ghimport-source-id: e7d6e08 Differential Revision: D15534670 Pulled By: suo fbshipit-source-id: 8bbfd6e9c1afbc3006d7d55ed633e18618e05021
1 parent 6ea9044 commit 5dacf6b

File tree

13 files changed

+153
-77
lines changed

13 files changed

+153
-77
lines changed

aten/src/ATen/core/function_schema.h

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,20 @@ struct Argument {
1919
c10::optional<int32_t> N = c10::nullopt,
2020
c10::optional<IValue> default_value = c10::nullopt,
2121
bool kwarg_only = false,
22-
c10::optional<AliasInfo> alias_info = c10::nullopt)
22+
c10::optional<AliasInfo> alias_info = c10::nullopt,
23+
bool is_inferred_type = false)
2324
: name_(std::move(name)),
2425
type_(type ? type : TensorType::get()),
2526
N_(std::move(N)),
2627
default_value_(std::move(default_value)),
2728
kwarg_only_(kwarg_only),
28-
alias_info_(std::move(alias_info)) {
29-
if (default_value_ && default_value_->isTensor()) {
30-
auto t = default_value_->toTensor();
31-
AT_ASSERT(!t.defined() || t.is_variable());
32-
}
33-
}
29+
alias_info_(std::move(alias_info)),
30+
is_inferred_type_(is_inferred_type) {
31+
if (default_value_ && default_value_->isTensor()) {
32+
auto t = default_value_->toTensor();
33+
AT_ASSERT(!t.defined() || t.is_variable());
34+
}
35+
}
3436
const std::string& name() const {
3537
return name_;
3638
}
@@ -49,6 +51,28 @@ struct Argument {
4951
const c10::optional<AliasInfo>& alias_info() const {
5052
return alias_info_;
5153
}
54+
bool is_inferred_type() const {
55+
return is_inferred_type_;
56+
}
57+
std::string formatTypeMismatchMsg(const std::string& actual_type) const {
58+
std::string inferred_type_hint;
59+
if (is_inferred_type()) {
60+
inferred_type_hint = c10::str(
61+
"Inferred '",
62+
name(),
63+
"' to be of type 'Tensor' ",
64+
"because it was not annotated with an explicit type.\n");
65+
}
66+
return c10::str(
67+
"expected a value of type '",
68+
type()->python_str(),
69+
"' for argument '",
70+
name(),
71+
"' but instead found type '",
72+
actual_type,
73+
"'.\n",
74+
inferred_type_hint);
75+
}
5276

5377
Argument cloneWithType(TypePtr new_type) const {
5478
return Argument(name_, new_type, N_, default_value_, kwarg_only_, alias_info_);
@@ -67,6 +91,7 @@ struct Argument {
6791
// is this only specifyable as a keyword argument?
6892
bool kwarg_only_;
6993
c10::optional<AliasInfo> alias_info_;
94+
bool is_inferred_type_;
7095
};
7196

7297
namespace detail {
@@ -182,7 +207,14 @@ struct FunctionSchema {
182207
is_varret());
183208
}
184209

185-
FunctionSchema cloneWithRemappedTypes(const std::function<TypePtr(TypePtr)> type_map) const;
210+
std::string formatTypeMismatchMsg(
211+
const Argument& expected,
212+
const std::string& actual_type,
213+
c10::optional<size_t> position = c10::nullopt,
214+
c10::optional<std::string> value = c10::nullopt) const;
215+
216+
FunctionSchema cloneWithRemappedTypes(
217+
const std::function<TypePtr(TypePtr)> type_map) const;
186218

187219
// Check that inputs have the correct types and appends any missing default
188220
// values.

aten/src/ATen/core/function_schema_inl.h

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,39 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema)
4242
return out;
4343
}
4444

45-
inline void FunctionSchema::checkArg(const IValue& value, const Argument& argument, optional<size_t> pos) const {
45+
inline std::string FunctionSchema::formatTypeMismatchMsg(
46+
const Argument& expected,
47+
const std::string& actual_type,
48+
c10::optional<size_t> position,
49+
c10::optional<std::string> value) const {
50+
std::string position_str;
51+
if (position) {
52+
position_str = c10::str("Position: ", *position, "\n");
53+
}
54+
std::string value_str;
55+
if (value) {
56+
value_str = c10::str("Value: ", *value, "\n");
57+
}
58+
return c10::str(
59+
name(),
60+
"() ",
61+
expected.formatTypeMismatchMsg(actual_type),
62+
position_str,
63+
value_str,
64+
"Declaration: ",
65+
*this);
66+
}
67+
68+
inline void FunctionSchema::checkArg(
69+
const IValue& value,
70+
const Argument& argument,
71+
optional<size_t> pos) const {
4672
if (!isSubvalueOf(value, argument.type())) {
4773
std::string position = pos ? ::c10::str(" in position ", *pos) : "";
48-
AT_ERROR(
49-
"Expected value of type ",
50-
*argument.type(),
51-
" for argument '",
52-
argument.name(),
53-
"'",
54-
position,
55-
", but instead got value of type ",
56-
attemptToRecoverType(value)->str(),
57-
". Declaration: ",
58-
*this);
74+
TORCH_CHECK(
75+
false,
76+
formatTypeMismatchMsg(
77+
argument, attemptToRecoverType(value)->python_str(), pos));
5978
}
6079
}
6180

test/cpp/api/jit.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,9 @@ TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
6262
} catch (const c10::Error& error) {
6363
AT_ASSERT(
6464
std::string(error.what_without_backtrace())
65-
.find("Expected value of type Tensor[][] for argument 'a' in "
66-
"position 0, but instead got value of type t[][][]") == 0);
67-
65+
.find("nested_loop() expected a value of type 'List[List[Tensor]]'"
66+
" for argument 'a' but instead found type "
67+
"'List[List[List[t]]]'") == 0);
6868
};
6969

7070
std::vector<torch::jit::IValue> gen_list;
@@ -81,9 +81,9 @@ TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
8181
//so the error message is not helpful here.
8282
AT_ASSERT(
8383
std::string(error.what_without_backtrace())
84-
.find("Expected value of type Tensor[][] for argument 'a' in "
85-
"position 0, but instead got value of type Tensor[][]") == 0);
86-
84+
.find("nested_loop() expected a value of type "
85+
"'List[List[Tensor]]' for argument 'a' but "
86+
"instead found type 'List[List[Tensor]]'") == 0);
8787
};
8888
}
8989

test/expect/TestScript.test_python_frontend.expect

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44
(list
55
(param
66
(ident x)
7-
(variable (ident Tensor))
7+
(option)
88
(option)
99
(False))
1010
(param
1111
(ident y)
12-
(variable (ident Tensor))
12+
(option)
1313
(option)
1414
(False))
1515
(param
1616
(ident z)
17-
(variable (ident Tensor))
17+
(option)
1818
(option)
1919
(False)))
2020
(option))

test/test_jit.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8734,8 +8734,8 @@ def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
87348734
f.write(code)
87358735
fn = get_fn('test_type_annotation_py3', script_path)
87368736

8737-
with self.assertRaisesRegex(RuntimeError, r"Expected a value of type Tensor for argument"
8738-
r" '0' but found Tuple\[Tensor,"):
8737+
with self.assertRaisesRegex(RuntimeError, r"expected a value of type 'Tensor' for argument"
8738+
r" '0' but instead found type 'Tuple\[Tensor,"):
87398739
@torch.jit.script
87408740
def bad_fn(x):
87418741
x, y = fn((x, x), x, x)
@@ -9668,12 +9668,12 @@ def f2(a):
96689668
def f3(a):
96699669
torch.sum(dim=4)
96709670

9671-
with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but found Tensor'):
9671+
with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but instead found type \'Tensor'):
96729672
@torch.jit.script
96739673
def f4(a):
96749674
torch.cat(a)
96759675

9676-
with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but found List\[int\]'):
9676+
with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but instead found type \'List\[int\]'):
96779677
@torch.jit.script
96789678
def f5(a):
96799679
torch.cat([3])
@@ -9683,7 +9683,7 @@ def f5(a):
96839683
def f6(a):
96849684
a.expand(size=[3, [4]])
96859685

9686-
with self.assertRaisesRegex(RuntimeError, 'xpected a value of type Tensor for argument \'self\''):
9686+
with self.assertRaisesRegex(RuntimeError, 'xpected a value of type \'Tensor\' for argument \'self\''):
96879687
@torch.jit.script
96889688
def f7(a):
96899689
torch.sum([4])
@@ -13230,6 +13230,18 @@ def test_python_op_name(self):
1323013230
def fn():
1323113231
return random.randint()
1323213232

13233+
def test_inferred_error_msg(self):
13234+
"""
13235+
Test that when we get a type mismatch on a function where we inferred
13236+
the type to be tensor, a good error message is given.
13237+
"""
13238+
@torch.jit.script
13239+
def foo(a):
13240+
return a
13241+
13242+
with self.assertRaisesRegex(RuntimeError, "Inferred \'a\' to be of type \'Tensor"):
13243+
foo(1)
13244+
1323313245

1323413246
class MnistNet(nn.Module):
1323513247
def __init__(self):
@@ -15569,7 +15581,7 @@ def set_non_initialized(self, y):
1556915581
self.bar = y # can't assign to non-initialized attr
1557015582

1557115583
def test_type_annotations(self):
15572-
with self.assertRaisesRegex(RuntimeError, "Expected a value of type bool"):
15584+
with self.assertRaisesRegex(RuntimeError, "expected a value of type \'bool"):
1557315585
@torch.jit.script # noqa: B903
1557415586
class FooTest(object):
1557515587
def __init__(self, x):
@@ -15785,7 +15797,7 @@ def test_list_no_reverse():
1578515797

1578615798
self.assertEqual(test_list_no_reverse(), 1)
1578715799

15788-
with self.assertRaisesRegex(RuntimeError, "bool for argument \'reverse"):
15800+
with self.assertRaisesRegex(RuntimeError, "bool\' for argument \'reverse"):
1578915801
@torch.jit.script
1579015802
def test():
1579115803
li = [Foo(1)]

torch/csrc/jit/named_value.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ namespace jit {
1010

1111
struct Value;
1212

13+
/**
14+
* A value with optional extra name and location information. Used during
15+
* schema matching to provide extra error information and resolve kwargs.
16+
*/
1317
struct NamedValue {
1418
NamedValue(const SourceRange& loc, const std::string& name, Value* value)
1519
: loc_(loc), name_(name), value_(value) {}

torch/csrc/jit/pybind_utils.h

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -341,21 +341,11 @@ inline IValue argumentToIValue(
341341
try {
342342
return toIValue(object, argument.type(), argument.N());
343343
} catch (const py::cast_error& error) {
344-
throw std::runtime_error(c10::str(
345-
schema.name(),
346-
"() expected value of type ",
347-
argument.type()->str(),
348-
" for argument '",
349-
argument.name(),
350-
"' in position ",
351-
argumentPosition,
352-
", but instead got value of type ",
344+
throw std::runtime_error(schema.formatTypeMismatchMsg(
345+
argument,
353346
py::str(object.get_type().attr("__name__")),
354-
".",
355-
"\nValue: ",
356-
py::repr(object),
357-
"\nDeclaration: ",
358-
schema));
347+
argumentPosition,
348+
py::repr(object)));
359349
}
360350
}
361351

torch/csrc/jit/script/compiler.cpp

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ struct to_ir {
671671
auto param = *it;
672672
auto def = param.defaultValue();
673673
if (def.present()) {
674-
default_types.emplace_back(param.type());
674+
default_types.emplace_back(param.type().get());
675675
default_exprs.emplace_back(def.get());
676676
}
677677
}
@@ -684,15 +684,22 @@ struct to_ir {
684684

685685
TypePtr type;
686686
c10::optional<int32_t> N;
687-
688-
// BroadcastList list can only appear at the argument level
689-
if (auto maybe_broad_list =
690-
typeParser_.parseBroadcastList(decl_arg.type())) {
691-
type = maybe_broad_list->first;
692-
N = maybe_broad_list->second;
693-
} else {
694-
type = typeParser_.parseTypeFromExpr(decl_arg.type());
687+
bool is_inferred_type = false;
688+
if (!decl_arg.type().present()) {
689+
// If this param doesn't have a type, default to "tensor"
690+
is_inferred_type = true;
691+
type = TensorType::get();
695692
N = c10::nullopt;
693+
} else {
694+
// BroadcastList list can only appear at the argument level
695+
if (auto maybe_broad_list =
696+
typeParser_.parseBroadcastList(decl_arg.type().get())) {
697+
type = maybe_broad_list->first;
698+
N = maybe_broad_list->second;
699+
} else {
700+
type = typeParser_.parseTypeFromExpr(decl_arg.type().get());
701+
N = c10::nullopt;
702+
}
696703
}
697704
c10::optional<IValue> default_value = c10::nullopt;
698705
if (decl_arg.defaultValue().present()) {
@@ -703,7 +710,9 @@ struct to_ir {
703710
type,
704711
N,
705712
default_value,
706-
decl_arg.kwarg_only());
713+
decl_arg.kwarg_only(),
714+
/*alias_info=*/c10::nullopt,
715+
is_inferred_type);
707716
retval.push_back(arg);
708717
}
709718
return retval;

torch/csrc/jit/script/parser.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -363,9 +363,9 @@ struct ParserImpl {
363363
auto ident = parseIdent();
364364
TreeRef type;
365365
if (L.nextIf(':')) {
366-
type = parseExp();
366+
type = Maybe<Expr>::create(L.cur().range, parseExp());
367367
} else {
368-
type = Var::create(L.cur().range, Ident::create(L.cur().range, "Tensor"));
368+
type = Maybe<Expr>::create(L.cur().range);
369369
}
370370
TreeRef def;
371371
if (L.nextIf('=')) {
@@ -374,15 +374,15 @@ struct ParserImpl {
374374
def = Maybe<Expr>::create(L.cur().range);
375375
}
376376
return Param::create(
377-
type->range(), Ident(ident), Expr(type), Maybe<Expr>(def), kwarg_only);
377+
type->range(), Ident(ident), Maybe<Expr>(type), Maybe<Expr>(def), kwarg_only);
378378
}
379379

380380
Param parseBareTypeAnnotation() {
381381
auto type = parseExp();
382382
return Param::create(
383383
type.range(),
384384
Ident::create(type.range(), ""),
385-
type,
385+
Maybe<Expr>::create(type.range(), type),
386386
Maybe<Expr>::create(type.range()),
387387
/*kwarg_only=*/false);
388388
}

0 commit comments

Comments
 (0)