Skip to content

Commit dc6939e

Browse files
houseroadfacebook-github-bot
authored andcommitted
Add isBackwardCompatibleWith for Argument and FunctionSchema (#23409)
Summary: we intend to be conservative, and will relax the checks in future if necessary. So far, we consider the following three conditions as backward compatible: 1) two schemas are equal 2) two schemas have same number of arguments, and this schema's arguments are backward compatible with the corresponding ones in argument list of old_schema. 3) this schema has m argument, old_argument has n argument, m > n. the first n arguments of this schema are backward compatible with the corresponding arguments of old_schema. the remaning arguments must be either OptionalType or provide default values. Pull Request resolved: #23409 ghstack-source-id: 90111021 Test Plan: buck test //caffe2/test:function_schema Reviewed By: hl475 Differential Revision: D16505203 fbshipit-source-id: e4099537776a60e8945e5c3cd57fa861f3598a9b
1 parent 1563fdb commit dc6939e

File tree

4 files changed

+248
-16
lines changed

4 files changed

+248
-16
lines changed

aten/src/ATen/core/function_schema.h

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@ namespace c10 {
1313
// errors. These objects should be constructed from C10 schema once those
1414
// are available.
1515

16+
struct Argument;
17+
struct FunctionSchema;
18+
19+
namespace detail {
20+
inline bool defaultValueEquals_(
21+
const c10::optional<IValue>& lhs,
22+
const c10::optional<IValue>& rhs) {
23+
if (lhs.has_value()) {
24+
return rhs.has_value() && impl::shallowEquals(*lhs, *rhs);
25+
} else {
26+
return !rhs.has_value();
27+
}
28+
}
29+
} // namespace detail
30+
31+
bool operator==(const Argument& lhs, const Argument& rhs);
32+
1633
struct Argument {
1734
Argument(
1835
std::string name = "",
@@ -79,6 +96,15 @@ struct Argument {
7996
return Argument(name_, new_type, N_, default_value_, kwarg_only_, alias_info_);
8097
}
8198

99+
// this function check whether this Argument is backward compatible with
100+
// the old one. we consider the following cases are backward compatible:
101+
// 1) two arguments are equal
102+
// 2) this arg's type should be subtype of old
103+
// 3) this arg must provide the same default value if old arg has one,
104+
bool isBackwardCompatibleWith(
105+
const Argument& old,
106+
std::ostream* why_not=nullptr) const;
107+
82108
private:
83109
std::string name_;
84110
TypePtr type_;
@@ -95,16 +121,6 @@ struct Argument {
95121
bool is_inferred_type_;
96122
};
97123

98-
namespace detail {
99-
inline bool defaultValueEquals_(const c10::optional<IValue>& lhs, const c10::optional<IValue>& rhs) {
100-
if (lhs.has_value()) {
101-
return rhs.has_value() && impl::shallowEquals(*lhs, *rhs);
102-
} else {
103-
return !rhs.has_value();
104-
}
105-
}
106-
}
107-
108124
inline bool operator==(const Argument& lhs, const Argument& rhs) {
109125
return lhs.name() == rhs.name()
110126
&& *lhs.type() == *rhs.type()
@@ -114,6 +130,8 @@ inline bool operator==(const Argument& lhs, const Argument& rhs) {
114130
&& lhs.alias_info() == rhs.alias_info();
115131
}
116132

133+
bool operator==(const FunctionSchema& lhs, const FunctionSchema& rhs);
134+
117135
struct FunctionSchema {
118136
FunctionSchema(
119137
std::string name,
@@ -143,6 +161,22 @@ struct FunctionSchema {
143161
is_vararg,
144162
is_varret) {}
145163

164+
// check whether this schema is backward compatible with the old one.
165+
// the following conditions are considered as this schema is backward
166+
// compatible with old:
167+
// 1) two schemas are equal
168+
// 2) this schema has the same or more positional args than old,
169+
// and any positional arg in this schema is backward compatible
170+
// with the corresponding one in old schema, which could be an arg
171+
// or a kwarg, if it has, or it must provide a default value
172+
// 3) this schema has the same or more kwargs than old, and all the kwargs
173+
// in old schema can find the corresponding kwarg in this schema which
174+
// is backward compatible with the old kwarg, and the extra kwargs in
175+
// this schema must provide default values.
176+
bool isBackwardCompatibleWith(
177+
const FunctionSchema& old,
178+
std::ostream* why_not=nullptr) const;
179+
146180
private:
147181
OperatorName name_;
148182
std::vector<Argument> arguments_;

aten/src/ATen/core/function_schema_inl.h

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,30 @@ inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema)
5151
return out;
5252
}
5353

54+
inline bool Argument::isBackwardCompatibleWith(
55+
const Argument& old,
56+
std::ostream* why_not) const {
57+
const Argument* lhs = this;
58+
const Argument* rhs = &old;
59+
if (!(lhs->name() == rhs->name()
60+
&& lhs->N() == rhs->N()
61+
&& lhs->alias_info() == rhs->alias_info())) {
62+
return false;
63+
}
64+
if (lhs->kwarg_only() && !rhs->kwarg_only()) {
65+
return false;
66+
}
67+
if (!rhs->type()->isSubtypeOfExt(lhs->type(), why_not)) {
68+
return false;
69+
}
70+
if (rhs->default_value().has_value() &&
71+
!detail::defaultValueEquals_(lhs->default_value(),
72+
rhs->default_value())) {
73+
return false;
74+
}
75+
return true;
76+
}
77+
5478
inline std::string FunctionSchema::formatTypeMismatchMsg(
5579
const Argument& expected,
5680
const std::string& actual_type,
@@ -74,6 +98,90 @@ inline std::string FunctionSchema::formatTypeMismatchMsg(
7498
*this);
7599
}
76100

101+
inline bool FunctionSchema::isBackwardCompatibleWith(
102+
const FunctionSchema& old,
103+
std::ostream* why_not) const {
104+
if (!(name() == old.name()
105+
&& overload_name() == old.overload_name()
106+
// we are conservative on is_vararg and is_varret,
107+
// since they are only used by internal operators
108+
&& is_vararg() == old.is_vararg()
109+
&& is_varret() == old.is_varret()
110+
&& returns().size() == old.returns().size()
111+
&& arguments().size() >= old.arguments().size())) {
112+
return false;
113+
}
114+
for (size_t i = 0; i < returns().size(); ++i) {
115+
// functions are covariant in arguments but contravariant in returns
116+
if (!old.returns().at(i).isBackwardCompatibleWith(
117+
returns().at(i),
118+
why_not)) {
119+
return false;
120+
}
121+
}
122+
std::vector<const Argument*> args, old_args;
123+
std::map<std::string, const Argument*> kwargs, old_kwargs;
124+
auto split_func = [](const std::vector<Argument>& arguments,
125+
std::vector<const Argument*>* positionals,
126+
std::map<std::string, const Argument*>* nameds) {
127+
for (const Argument& arg : arguments) {
128+
if (!arg.kwarg_only()) {
129+
positionals->emplace_back(&arg);
130+
}
131+
nameds->emplace(arg.name(), &arg);
132+
}
133+
};
134+
// we split args into positional and keyward parts,
135+
split_func(arguments(), &args, &kwargs);
136+
split_func(old.arguments(), &old_args, &old_kwargs);
137+
if (old_args.size() > args.size()) {
138+
return false;
139+
}
140+
// make sure that all the old positional args have their corresponding
141+
// backward compatible positional args in this schema
142+
for (size_t i = 0; i < old_args.size(); ++i) {
143+
if (!args.at(i)->isBackwardCompatibleWith(
144+
*old_args.at(i),
145+
why_not)) {
146+
return false;
147+
}
148+
}
149+
// check the extra positional args in this schema either has corresponding
150+
// backward compatible keyward args since positional args also can be used as
151+
// a keyward arg, or provided default values
152+
for (size_t i = old_args.size(); i < args.size(); ++i) {
153+
if (!args.at(i)->default_value()) {
154+
auto it = old_kwargs.find(args.at(i)->name());
155+
if (it == old_kwargs.end() ||
156+
!args.at(i)->isBackwardCompatibleWith(
157+
*it->second,
158+
why_not)) {
159+
return false;
160+
}
161+
}
162+
}
163+
// make sure that all the keyword args in the old schema have their
164+
// corresponding backward compatible keyward args in this schema
165+
for (auto& kv : old_kwargs) {
166+
auto it = kwargs.find(kv.first);
167+
if (it == kwargs.end() ||
168+
!it->second->isBackwardCompatibleWith(
169+
*kv.second,
170+
why_not)) {
171+
return false;
172+
}
173+
kwargs.erase(it);
174+
}
175+
// check all the extra keyword args in this schema provide default values
176+
for (auto& kv : kwargs) {
177+
if (!kv.second->default_value()) {
178+
return false;
179+
}
180+
}
181+
182+
return true;
183+
}
184+
77185
inline void FunctionSchema::checkArg(
78186
const IValue& value,
79187
const Argument& argument,

test/test_function_schema.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import torch
44
from common_utils import TestCase, run_tests
5+
from torch._C import parse_schema
56

67

78
class TestFunctionSchema(TestCase):
@@ -10,8 +11,93 @@ def test_serialize_and_deserialize(self):
1011
# so far we have around 1700 registered schemas
1112
self.assertGreater(len(schemas), 1000)
1213
for schema in schemas:
13-
parsed_schema = torch._C.parse_schema(str(schema))
14+
parsed_schema = parse_schema(str(schema))
1415
self.assertEqual(parsed_schema, schema)
16+
self.assertTrue(parsed_schema.is_backward_compatible_with(schema))
17+
18+
def test_backward_compatible_args(self):
19+
old_schema = parse_schema('any(Tensor self, int dim) -> Tensor')
20+
new_schema = parse_schema('any(Tensor self, int? dim) -> Tensor')
21+
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
22+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
23+
new_schema = parse_schema('any(Tensor self, int dim=5) -> Tensor')
24+
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
25+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
26+
new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor')
27+
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
28+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
29+
30+
def test_backward_compatible_kwargs(self):
31+
old_schema = parse_schema('any(Tensor self, *, Tensor out) -> Tensor')
32+
new_schema = parse_schema('any(Tensor self, *, bool extra1=True, Tensor out, bool extra2=False) -> Tensor')
33+
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
34+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
35+
new_schema = parse_schema('any(Tensor self, Tensor out) -> Tensor')
36+
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
37+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
38+
39+
def test_backward_compatible_ret(self):
40+
old_schema = parse_schema('any(Tensor self) -> Tensor?')
41+
new_schema = parse_schema('any(Tensor self) -> Tensor')
42+
self.assertTrue(new_schema.is_backward_compatible_with(old_schema))
43+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
44+
45+
def test_backward_incompatible_name(self):
46+
old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor')
47+
new_schema = parse_schema('any_(Tensor self, int dim, bool keepdim=False) -> Tensor')
48+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
49+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
50+
51+
def test_backward_incompatible_vararg(self):
52+
old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor')
53+
new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False, ...) -> Tensor')
54+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
55+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
56+
57+
def test_backward_incompatible_returns(self):
58+
old_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor')
59+
new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> (Tensor, ...)')
60+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
61+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
62+
new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> int')
63+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
64+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
65+
new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor?')
66+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
67+
self.assertTrue(old_schema.is_backward_compatible_with(new_schema))
68+
new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)')
69+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
70+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
71+
new_schema = parse_schema('any(Tensor self, int dim, bool keepdim=False) -> Tensor out')
72+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
73+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
74+
75+
def test_backward_incompatible_args(self):
76+
old_schema = parse_schema('any(Tensor self, int[] dims, bool keepdim=False) -> Tensor')
77+
new_schema = parse_schema('any(Tensor s, int[] dims, bool keepdim=False) -> Tensor')
78+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
79+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
80+
new_schema = parse_schema('any(Tensor self, int[3] dims, bool keepdim=False) -> Tensor')
81+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
82+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
83+
new_schema = parse_schema('any(Tensor self, int[](a) dims, bool keepdim=False) -> Tensor')
84+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
85+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
86+
new_schema = parse_schema('any(Tensor self, int dims, bool keepdim=False) -> Tensor')
87+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
88+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
89+
new_schema = parse_schema('any(Tensor self, int[] dim, bool keepdim=False, bool? extra) -> Tensor')
90+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
91+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
92+
93+
def test_backward_incompatible_kwargs(self):
94+
old_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim=False) -> Tensor')
95+
new_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim) -> Tensor')
96+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
97+
self.assertTrue(old_schema.is_backward_compatible_with(new_schema))
98+
new_schema = parse_schema('any(Tensor self, int[] dims, *, bool keepdim=False, bool extra) -> Tensor')
99+
self.assertFalse(new_schema.is_backward_compatible_with(old_schema))
100+
self.assertFalse(old_schema.is_backward_compatible_with(new_schema))
15101

16102

17103
if __name__ == '__main__':

torch/csrc/jit/init.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,10 @@ void initJITBindings(PyObject* module) {
435435
"arguments", [](FunctionSchema& self) { return self.arguments(); })
436436
.def_property_readonly(
437437
"returns", [](FunctionSchema& self) { return self.returns(); })
438+
.def("is_backward_compatible_with",
439+
[](const FunctionSchema& self, const FunctionSchema& old_schema) {
440+
return self.isBackwardCompatibleWith(old_schema);
441+
})
438442
.def("__eq__", [](const FunctionSchema& self,
439443
const FunctionSchema& other) {
440444
return self == other;
@@ -453,11 +457,11 @@ void initJITBindings(PyObject* module) {
453457
return (self.N()) ? py::cast(*self.N()) : py::none();
454458
})
455459
.def_property_readonly("default_value", [](Argument& self) -> py::object {
456-
if (!self.default_value())
457-
return py::none();
458-
IValue v = *self.default_value();
459-
return toPyObject(std::move(v));
460-
});
460+
if (!self.default_value())
461+
return py::none();
462+
IValue v = *self.default_value();
463+
return toPyObject(std::move(v));
464+
});
461465
m.def(
462466
"_jit_get_all_schemas", []() {
463467
const std::vector<std::shared_ptr<Operator>>& operations = getAllOperators();

0 commit comments

Comments
 (0)