@@ -19,18 +19,20 @@ struct Argument {
1919 c10::optional<int32_t > N = c10::nullopt ,
2020 c10::optional<IValue> default_value = c10::nullopt ,
2121 bool kwarg_only = false ,
22- c10::optional<AliasInfo> alias_info = c10::nullopt )
22+ c10::optional<AliasInfo> alias_info = c10::nullopt ,
23+ bool is_inferred_type = false )
2324 : name_(std::move(name)),
2425 type_ (type ? type : TensorType::get()),
2526 N_(std::move(N)),
2627 default_value_(std::move(default_value)),
2728 kwarg_only_(kwarg_only),
28- alias_info_(std::move(alias_info)) {
29- if (default_value_ && default_value_->isTensor ()) {
30- auto t = default_value_->toTensor ();
31- AT_ASSERT (!t.defined () || t.is_variable ());
32- }
33- }
29+ alias_info_(std::move(alias_info)),
30+ is_inferred_type_(is_inferred_type) {
31+ if (default_value_ && default_value_->isTensor ()) {
32+ auto t = default_value_->toTensor ();
33+ AT_ASSERT (!t.defined () || t.is_variable ());
34+ }
35+ }
3436 const std::string& name () const {
3537 return name_;
3638 }
@@ -49,6 +51,28 @@ struct Argument {
4951 const c10::optional<AliasInfo>& alias_info () const {
5052 return alias_info_;
5153 }
54+ bool is_inferred_type () const {
55+ return is_inferred_type_;
56+ }
57+ std::string formatTypeMismatchMsg (const std::string& actual_type) const {
58+ std::string inferred_type_hint;
59+ if (is_inferred_type ()) {
60+ inferred_type_hint = c10::str (
61+ " Inferred '" ,
62+ name (),
63+ " ' to be of type 'Tensor' " ,
64+ " because it was not annotated with an explicit type.\n " );
65+ }
66+ return c10::str (
67+ " expected a value of type '" ,
68+ type ()->python_str (),
69+ " ' for argument '" ,
70+ name (),
71+ " ' but instead found type '" ,
72+ actual_type,
73+ " '.\n " ,
74+ inferred_type_hint);
75+ }
5276
5377 Argument cloneWithType (TypePtr new_type) const {
5478 return Argument (name_, new_type, N_, default_value_, kwarg_only_, alias_info_);
@@ -67,6 +91,7 @@ struct Argument {
6791 // is this only specifyable as a keyword argument?
6892 bool kwarg_only_;
6993 c10::optional<AliasInfo> alias_info_;
94+ bool is_inferred_type_;
7095};
7196
7297namespace detail {
@@ -182,7 +207,14 @@ struct FunctionSchema {
182207 is_varret ());
183208 }
184209
185- FunctionSchema cloneWithRemappedTypes (const std::function<TypePtr(TypePtr)> type_map) const ;
210+ std::string formatTypeMismatchMsg (
211+ const Argument& expected,
212+ const std::string& actual_type,
213+ c10::optional<size_t > position = c10::nullopt ,
214+ c10::optional<std::string> value = c10::nullopt ) const ;
215+
216+ FunctionSchema cloneWithRemappedTypes (
217+ const std::function<TypePtr(TypePtr)> type_map) const ;
186218
187219 // Check that inputs have the correct types and appends any missing default
188220 // values.
0 commit comments