Skip to content

Commit 9691025

Browse files
Mikhail Zolotukhinfacebook-github-bot
authored andcommitted
schema_matching.cpp: improve error messages.
Summary: Pull Request resolved: #21141 Differential Revision: D15769066 Pulled By: ZolotukhinM fbshipit-source-id: 5853e0360581c44e42b068add3bf2bc68e671b2b
1 parent 28adca8 commit 9691025

File tree

4 files changed

+41
-44
lines changed

4 files changed

+41
-44
lines changed

aten/src/ATen/core/function_schema.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ struct Argument {
6464
"because it was not annotated with an explicit type.\n");
6565
}
6666
return c10::str(
67-
"expected a value of type '",
67+
"Expected a value of type '",
6868
type()->python_str(),
6969
"' for argument '",
7070
name(),

test/cpp/api/jit.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
6262
} catch (const c10::Error& error) {
6363
AT_ASSERT(
6464
std::string(error.what_without_backtrace())
65-
.find("nested_loop() expected a value of type 'List[List[Tensor]]'"
65+
.find("nested_loop() Expected a value of type 'List[List[Tensor]]'"
6666
" for argument 'a' but instead found type "
6767
"'List[List[List[t]]]'") == 0);
6868
};
@@ -81,7 +81,7 @@ TEST(TorchScriptTest, TestNestedIValueModuleArgMatching) {
8181
//so the error message is not helpful here.
8282
AT_ASSERT(
8383
std::string(error.what_without_backtrace())
84-
.find("nested_loop() expected a value of type "
84+
.find("nested_loop() Expected a value of type "
8585
"'List[List[Tensor]]' for argument 'a' but "
8686
"instead found type 'List[List[Tensor]]'") == 0);
8787
};

test/test_jit.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2904,7 +2904,7 @@ def test_sequence_parsing(self):
29042904
("return [x x]", "expected ]"),
29052905
("return x, x,", True),
29062906
("return bar(x, x,)", True),
2907-
("return bar()", "argument x not provided"),
2907+
("return bar()", "Argument x not provided"),
29082908
("for a, b, in x, x,:\n pass", "List of iterables"),
29092909
("a, b, = x, x,\n return a + b", True)
29102910
]
@@ -4509,7 +4509,7 @@ def comp(l1, l2):
45094509
self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7])
45104510

45114511
def test_comprehensions_wrong_expr_type(self):
4512-
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
4512+
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
45134513
@torch.jit.script
45144514
def comp(l):
45154515
# type: (List[int]) -> List[float]
@@ -5599,14 +5599,14 @@ def test_manual_unwrap_opt(x):
55995599
x = torch.jit._unwrap_optional(x)
56005600
return x # noqa: T484
56015601

5602-
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
5602+
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
56035603
@torch.jit.script
56045604
def or_error(x, y):
56055605
# type: (Optional[int], Optional[int]) -> None
56065606
if x is None or y is None:
56075607
print(x + y) # noqa: T484
56085608

5609-
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
5609+
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
56105610
@torch.jit.script
56115611
def and_error(x, y):
56125612
# type: (Optional[int], Optional[int]) -> None
@@ -5615,15 +5615,15 @@ def and_error(x, y):
56155615
else:
56165616
print(x + y) # noqa: T484
56175617

5618-
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
5618+
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
56195619
@torch.jit.script
56205620
def named_var(x):
56215621
# type: (Optional[int]) -> None
56225622
x_none = x is not None
56235623
if x_none:
56245624
print(x + 1) # noqa: T484
56255625

5626-
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
5626+
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
56275627
@torch.jit.script
56285628
def named_var_and(x, y):
56295629
# type: (Optional[int], Optional[int]) -> None
@@ -8551,7 +8551,7 @@ def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
85518551
f.write(code)
85528552
fn = get_fn('test_type_annotation_py3', script_path)
85538553

8554-
with self.assertRaisesRegex(RuntimeError, r"expected a value of type 'Tensor' for argument"
8554+
with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument"
85558555
r" '0' but instead found type 'Tuple\[Tensor,"):
85568556
@torch.jit.script
85578557
def bad_fn(x):
@@ -8626,13 +8626,13 @@ def method(self, x):
86268626
y = self.baz(x)
86278627
return x
86288628

8629-
with self.assertRaisesRegex(RuntimeError, "expected at most 1 arguments but found 2"):
8629+
with self.assertRaisesRegex(RuntimeError, "Expected at most 1 arguments but found 2"):
86308630
ModuleTooMany()
8631-
with self.assertRaisesRegex(RuntimeError, "argument 1 not provided"):
8631+
with self.assertRaisesRegex(RuntimeError, "Argument 1 not provided"):
86328632
ModuleTooFew()
86338633
with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
86348634
ModuleTooManyAssign()
8635-
with self.assertRaisesRegex(RuntimeError, "argument 1 not provided."):
8635+
with self.assertRaisesRegex(RuntimeError, "Argument 1 not provided."):
86368636
ModuleDefault()
86378637

86388638
def test_script_define_order(self):
@@ -9462,12 +9462,12 @@ def foo(a, b):
94629462

94639463
def test_builtin_args_fails(self):
94649464

9465-
with self.assertRaisesRegex(RuntimeError, 'expected at most'):
9465+
with self.assertRaisesRegex(RuntimeError, 'xpected at most'):
94669466
@torch.jit.script
94679467
def f0(a):
94689468
torch.sum(a, a, a, a)
94699469

9470-
with self.assertRaisesRegex(RuntimeError, 'argument self not provided'):
9470+
with self.assertRaisesRegex(RuntimeError, 'Argument self not provided'):
94719471
@torch.jit.script
94729472
def f1(a):
94739473
torch.sum(foo=4)
@@ -9833,7 +9833,7 @@ def forward(self, x):
98339833
ReassignSelfRHS()
98349834

98359835
def test_unknown_builtin(self):
9836-
with self.assertRaisesRegex(RuntimeError, 'unknown builtin op'):
9836+
with self.assertRaisesRegex(RuntimeError, 'Unknown builtin op'):
98379837
@torch.jit.script
98389838
def unknown_builtin(x):
98399839
return x.splork(3)
@@ -9916,7 +9916,7 @@ def multi_reduction(x):
99169916
''')
99179917

99189918
def test_invalid_call_arguments(self):
9919-
with self.assertRaisesRegex(RuntimeError, 'arguments for call are not valid'):
9919+
with self.assertRaisesRegex(RuntimeError, 'Arguments for call are not valid'):
99209920
@torch.jit.script
99219921
def invalid_call_arguments(x):
99229922
return torch.unsqueeze(3, 4, 5, 6, 7, 8)
@@ -10000,7 +10000,7 @@ def wrong_module_attr_lookup():
1000010000
return io.BytesIO
1000110001

1000210002
def test_wrong_method_call_inputs(self):
10003-
with self.assertRaisesRegex(RuntimeError, 'argument y not provided'):
10003+
with self.assertRaisesRegex(RuntimeError, 'Argument y not provided'):
1000410004
class SomeModule(torch.jit.ScriptModule):
1000510005

1000610006
@torch.jit.script_method
@@ -10033,7 +10033,7 @@ def test():
1003310033
''')
1003410034

1003510035
def test_call_ge(self):
10036-
with self.assertRaisesRegex(RuntimeError, 'expected at most 1 arguments but found 3'):
10036+
with self.assertRaisesRegex(RuntimeError, 'Expected at most 1 arguments but found 3'):
1003710037
@_trace(torch.zeros(1, 2, 3))
1003810038
def foo(x):
1003910039
return x
@@ -10824,7 +10824,7 @@ def return_tup(x):
1082410824
return x, x # noqa: T484
1082510825

1082610826
def test_annotated_script_fn_arg_mismatch(self):
10827-
with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
10827+
with self.assertRaisesRegex(RuntimeError, r"Arguments for call are not valid"):
1082810828
@torch.jit.script
1082910829
def tuple_arg(x):
1083010830
# type: (Tuple[Tensor, Tensor]) -> Tensor
@@ -11318,7 +11318,7 @@ def forward(self, x):
1131811318
foo(torch.ones([123])) # wrong size
1131911319

1132011320
def test_builtin_error_messsage(self):
11321-
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
11321+
with self.assertRaisesRegex(RuntimeError, "Arguments for call are not valid"):
1132211322
@torch.jit.script
1132311323
def close_match(x):
1132411324
return x.masked_fill(True)
@@ -15644,7 +15644,7 @@ def set_non_initialized(self, y):
1564415644
self.bar = y # can't assign to non-initialized attr
1564515645

1564615646
def test_type_annotations(self):
15647-
with self.assertRaisesRegex(RuntimeError, "expected a value of type \'bool"):
15647+
with self.assertRaisesRegex(RuntimeError, "Expected a value of type \'bool"):
1564815648
@torch.jit.script # noqa: B903
1564915649
class FooTest(object):
1565015650
def __init__(self, x):

torch/csrc/jit/script/schema_matching.cpp

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ static Value* tryMatchArgument(
149149
if (failure_messages) {
150150
err() << "Could not match type " << value->type()->python_str() << " to "
151151
<< arg.type()->python_str() << " in argument '" << arg.name()
152-
<< "': " << matched_type.errMsg << "\n"
153-
<< named_value.locOr(loc);
152+
<< "': " << matched_type.errMsg << ".\n";
154153
}
155154
return nullptr;
156155
}
@@ -169,16 +168,15 @@ static Value* tryMatchArgument(
169168
if (v->getElementType()->isSubtypeOf(TensorType::get())) {
170169
ostream << "Empty lists default to List[Tensor]. Use torch.jit."
171170
"annotate(List[my_type], []) to create an empty list of"
172-
" another type\n";
171+
" another type.\n";
173172
}
174173
}
175174

176175
if (value->type() == NumberType::get() &&
177176
value->node()->kind() == aten::item) {
178177
ostream << "Use int(tensor) or float(tensor) to retrieve item() from a "
179-
<< "tensor with the appropriate type\n";
178+
<< "tensor with the appropriate type.\n";
180179
}
181-
ostream << named_value.locOr(loc);
182180
}
183181

184182
return nullptr;
@@ -263,7 +261,7 @@ c10::optional<MatchedSchema> tryMatchSchema(
263261
std::ostream* failure_messages,
264262
bool allow_conversions) {
265263
auto err = [&]() -> std::ostream& {
266-
*failure_messages << "\nfor operator " << schema << ":\n";
264+
*failure_messages << "\n" << schema << ":\n";
267265
return *failure_messages;
268266
};
269267

@@ -320,9 +318,8 @@ c10::optional<MatchedSchema> tryMatchSchema(
320318
const NamedValue& nv = kwargs[*kwarg_idx];
321319
if (used_kwarg[*kwarg_idx]) {
322320
if (failure_messages) {
323-
err() << "argument " << nv.name()
324-
<< " specified twice in schema, submit a bug report!\n"
325-
<< nv.locOr(loc);
321+
err() << "Argument " << nv.name()
322+
<< " specified twice in schema, submit a bug report!\n";
326323
}
327324
return c10::nullopt;
328325
}
@@ -334,9 +331,8 @@ c10::optional<MatchedSchema> tryMatchSchema(
334331
actual_named_value = NamedValue(*arg.default_value());
335332
} else {
336333
if (failure_messages) {
337-
err() << "argument " << schema.arguments()[schema_i].name()
338-
<< " not provided.\n"
339-
<< loc;
334+
err() << "Argument " << schema.arguments()[schema_i].name()
335+
<< " not provided.\n";
340336
}
341337
return c10::nullopt;
342338
}
@@ -358,7 +354,7 @@ c10::optional<MatchedSchema> tryMatchSchema(
358354
}
359355
// check for unused self argument
360356
if (self != c10::nullopt && failure_messages) {
361-
err() << "provided self argument not used in schema\n";
357+
err() << "Provided self argument not used in schema.\n";
362358
}
363359

364360
if (schema.is_vararg()) {
@@ -370,9 +366,8 @@ c10::optional<MatchedSchema> tryMatchSchema(
370366
// check for unused positional arguments
371367
if (used_args < args.size()) {
372368
if (failure_messages) {
373-
err() << "expected at most " << used_args << " arguments "
374-
<< "but found " << args.size() << " positional arguments.\n"
375-
<< loc << "\n";
369+
err() << "Expected at most " << used_args << " arguments "
370+
<< "but found " << args.size() << " positional arguments.\n";
376371
}
377372
return c10::nullopt;
378373
}
@@ -382,9 +377,9 @@ c10::optional<MatchedSchema> tryMatchSchema(
382377
if (!used_kwarg[i]) {
383378
if (failure_messages) {
384379
if (!schema.argumentIndexWithName(nv.name())) {
385-
err() << "keyword argument " << nv.name() << " unknown\n";
380+
err() << "Keyword argument " << nv.name() << " unknown.\n";
386381
} else {
387-
err() << "keyword argument " << nv.name() << " specified twice\n";
382+
err() << "Keyword argument " << nv.name() << " specified twice.\n";
388383
}
389384
}
390385
return c10::nullopt;
@@ -557,23 +552,25 @@ Value* emitBuiltinCall(
557552
const auto close_symbols = findSimilarOperators(name);
558553
auto error = ErrorReport(loc);
559554
const auto& user_function_name = name.toQualString();
560-
error << "unknown builtin op: " << user_function_name << "\n";
555+
error << "Unknown builtin op: " << user_function_name << ".\n";
561556
if (close_symbols.size() == 0) {
562557
error
563558
<< "Could not find any similar ops to " << user_function_name
564-
<< ". This op may not exist or may not be currently supported in TorchScript\n";
559+
<< ". This op may not exist or may not be currently supported in TorchScript.\n";
565560
} else {
566561
error << "Here are some suggestions: \n";
567562
for (const auto& sym : close_symbols) {
568563
error << "\t" << sym.toQualString() << "\n";
569564
}
565+
error << "\nThe original call is";
570566
}
571567
throw error;
572568
}
573569

574-
throw ErrorReport(loc) << "arguments for call are not valid:\n"
570+
throw ErrorReport(loc) << "Arguments for call are not valid.\n"
571+
<< "The following operator variants are available:\n"
575572
<< prefixLine(failure_messages.str(), " ")
576-
<< "for call at";
573+
<< "\nThe original call is";
577574
}
578575
} // namespace script
579576
} // namespace jit

0 commit comments

Comments
 (0)