Skip to content

Commit b01520a

Browse files
smessmerfacebook-github-bot
authored andcommitted
Make schema part of RegisterOperators::Options (#26114)
Summary: Pull Request resolved: #26114 With this diff, the operator schema or name can be specified as part of the options objects: ``` static auto registry = torch::RegisterOperators() .op(torch::RegisterOperators::options().schema("my_op").kernel(&kernel)) .op(...); ``` This does not break backwards compatibility, all old APIs are kept as shorthands. This (a) makes the API more consistent, accumulating all options into the options objects and not treating schema special anymore, and (b) this is required for allowing the c10 dispatcher to forward registration calls to ATenDispatch for ops that are still on that dispatcher, see plan in #24132 ghstack-source-id: 90049402 Test Plan: unit tests Differential Revision: D17350383 fbshipit-source-id: cbb8f33a52dccb2a4522753e7b5ac8ba35b908fd
1 parent 0ea5978 commit b01520a

File tree

5 files changed

+206
-96
lines changed

5 files changed

+206
-96
lines changed

aten/src/ATen/core/function_schema.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ struct OperatorName final {
118118
std::string overload_name;
119119
};
120120

121+
inline std::string toString(const OperatorName& opName) {
122+
std::string result = opName.name;
123+
if (opName.overload_name.size() != 0) {
124+
result += "." + opName.overload_name;
125+
}
126+
return result;
127+
}
128+
121129
struct FunctionSchema {
122130
FunctionSchema(
123131
std::string name,
@@ -237,9 +245,9 @@ struct FunctionSchema {
237245
return false;
238246
}
239247

240-
// can a function with this schema be substituted for a function of rhs's
248+
// can a function with this schema be substituted for a function of rhs's
241249
// schema and have the program typecheck?
242-
// as_method - if true, treat this schema as a method and ignore
250+
// as_method - if true, treat this schema as a method and ignore
243251
// the first argument, which will be the object in both cases
244252
bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const;
245253
};

aten/src/ATen/core/op_registration/op_registration.cpp

Lines changed: 62 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -43,55 +43,65 @@ class RegisterOperators::OperatorRegistrar final {
4343
c10::optional<RegistrationHandleRAII> unboxed_autograd_kernel_registration_handle_;
4444
};
4545

46-
void RegisterOperators::checkSchemaAndRegisterOp_(const std::string& schemaOrNameStr, Options&& options) {
47-
#if defined(CAFFE2_IS_XPLAT_BUILD)
48-
throw std::logic_error("Tried to register operator " + schemaOrNameStr + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
49-
#else
50-
either<OperatorName, FunctionSchema> schemaOrName = torch::jit::parseSchemaOrName(schemaOrNameStr);
51-
if (schemaOrName.is_right()) {
52-
// schema was explicitly specified. Check it matches the inferred one and register the op.
53-
54-
auto schema = std::move(schemaOrName).right();
55-
TORCH_CHECK(
56-
options.aliasAnalysisKind_ == AliasAnalysisKind::FROM_SCHEMA ||
57-
!schema.hasAnyAliasInfo(),
58-
"In operator registration: Tried to register operator ",
59-
schemaOrNameStr,
60-
" with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");
61-
62-
checkSchemaAndRegisterOp_(std::move(schema), std::move(options));
63-
} else {
64-
// schema wasn't explicitly specified. Take the inferred schema for registering the op.
65-
66-
FunctionSchema inferred_schema = inferSchemaFromKernels_(schemaOrNameStr, options);
67-
OperatorName name = std::move(schemaOrName).left();
68-
FunctionSchema inferred_schema_with_name(
69-
std::move(name.name),
70-
std::move(name.overload_name),
71-
inferred_schema.arguments(),
72-
inferred_schema.returns(),
73-
inferred_schema.is_vararg(),
74-
inferred_schema.is_varret()
75-
);
76-
77-
checkNoDuplicateKernels_(inferred_schema_with_name, options);
78-
79-
// This would have unexpected behavior since an inferred schema will not
80-
// have aliasing annotations.
81-
TORCH_CHECK(
82-
options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA,
83-
"In operator registration: Tried to register operator ",
84-
schemaOrNameStr,
85-
" with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred.");
86-
87-
// Register all kernels with the schema we inferred
88-
registerOp_(std::move(inferred_schema_with_name), std::move(options));
46+
void RegisterOperators::checkSchemaAndRegisterOp_(Options&& options) {
47+
TORCH_CHECK(options.schemaOrName_.has_value(), "In operator registration: Tried to register an operator without specifying a schema or operator name.");
48+
if (options.schemaOrName_->is_right()) {
49+
// schema was explicitly specified. Check it matches the inferred one and register the op.
50+
51+
const FunctionSchema& schema = options.schemaOrName_->right();
52+
TORCH_CHECK(
53+
options.aliasAnalysisKind_ == AliasAnalysisKind::FROM_SCHEMA ||
54+
!schema.hasAnyAliasInfo(),
55+
"In operator registration: Tried to register operator ",
56+
options.schemaOrName_->right(),
57+
" with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");
58+
59+
for (auto& kernel : options.kernels) {
60+
if (nullptr != kernel.inferred_function_schema.get()) {
61+
c10::optional<std::string> schema_difference = findSchemaDifferences(schema, *kernel.inferred_function_schema);
62+
if (schema_difference.has_value()) {
63+
TORCH_CHECK(false, "In operator registration: Specified function schema [", toString(schema), "] ",
64+
"doesn't match inferred function schema [", toString(*kernel.inferred_function_schema), "]. ",
65+
*schema_difference);
66+
}
67+
}
8968
}
90-
#endif
69+
70+
checkNoDuplicateKernels_(options);
71+
72+
registerOp_(std::move(options));
73+
} else {
74+
// schema wasn't explicitly specified. Take the inferred schema for registering the op.
75+
76+
OperatorName name = std::move(*options.schemaOrName_).left();
77+
FunctionSchema inferred_schema = inferSchemaFromKernels_(name, options);
78+
79+
options.schemaOrName_ = c10::make_right<OperatorName, FunctionSchema>(
80+
std::move(name.name),
81+
std::move(name.overload_name),
82+
inferred_schema.arguments(),
83+
inferred_schema.returns(),
84+
inferred_schema.is_vararg(),
85+
inferred_schema.is_varret()
86+
);
87+
88+
checkNoDuplicateKernels_(options);
89+
90+
// This would have unexpected behavior since an inferred schema will not
91+
// have aliasing annotations.
92+
TORCH_CHECK(
93+
options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA,
94+
"In operator registration: Tried to register operator ",
95+
options.schemaOrName_->right(),
96+
" with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred.");
97+
98+
// Register all kernels with the schema we inferred
99+
registerOp_(std::move(options));
100+
}
91101
}
92102

93-
c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const std::string& opNameStr, const RegisterOperators::Options& options) {
94-
TORCH_CHECK(options.kernels.size() > 0, "Cannot infer operator schema in registration of operator ", opNameStr, " because there is no kernel specified.");
103+
c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const OperatorName& opName, const RegisterOperators::Options& options) {
104+
TORCH_CHECK(options.kernels.size() > 0, "Cannot infer operator schema in registration of operator ", toString(opName), " because there is no kernel specified.");
95105

96106
c10::optional<FunctionSchema> inferred_schema = c10::nullopt;
97107
for (const auto& kernel : options.kernels) {
@@ -108,44 +118,28 @@ c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const std::string
108118
}
109119
}
110120
}
111-
TORCH_CHECK(inferred_schema.has_value(), "Cannot infer operator schema for this kind of kernel in registration of operator ", opNameStr,". Please explicitly specify the operator schema or specify at least one kernel for which we can infer the schema.");
121+
TORCH_CHECK(inferred_schema.has_value(), "Cannot infer operator schema for this kind of kernel in registration of operator ", toString(opName), ". Please explicitly specify the operator schema or specify at least one kernel for which we can infer the schema.");
112122

113123
return *inferred_schema;
114124
}
115125

116-
void RegisterOperators::checkSchemaAndRegisterOp_(FunctionSchema schema, Options&& options) {
117-
for (auto& kernel : options.kernels) {
118-
if (nullptr != kernel.inferred_function_schema.get()) {
119-
c10::optional<std::string> schema_difference = findSchemaDifferences(schema, *kernel.inferred_function_schema);
120-
if (schema_difference.has_value()) {
121-
TORCH_CHECK(false, "In operator registration: Specified function schema [", toString(schema), "] ",
122-
"doesn't match inferred function schema [", toString(*kernel.inferred_function_schema), "]. ",
123-
*schema_difference);
124-
}
125-
}
126-
}
127-
128-
checkNoDuplicateKernels_(schema, options);
129-
130-
registerOp_(std::move(schema), std::move(options));
131-
}
132-
133-
void RegisterOperators::checkNoDuplicateKernels_(const FunctionSchema& schema, const Options& options) {
126+
void RegisterOperators::checkNoDuplicateKernels_(const Options& options) {
134127
std::unordered_set<TensorTypeId> dispatch_keys;
135128
bool has_catchall_kernel = false;
136129

137130
for (const auto& kernel : options.kernels) {
138131
if (kernel.dispatch_key.has_value()) {
139-
TORCH_CHECK(0 == dispatch_keys.count(*kernel.dispatch_key), "In operator registration: Tried to register multiple kernels with same dispatch key ", toString(*kernel.dispatch_key), " for operator schema ", toString(schema));
132+
TORCH_CHECK(0 == dispatch_keys.count(*kernel.dispatch_key), "In operator registration: Tried to register multiple kernels with same dispatch key ", toString(*kernel.dispatch_key), " for operator schema ", toString(options.schemaOrName_->right()));
140133
dispatch_keys.insert(*kernel.dispatch_key);
141134
} else {
142-
TORCH_CHECK(!has_catchall_kernel, "In operator registration: Tried to register multiple catch-all kernels for operator schema " + toString(schema));
135+
TORCH_CHECK(!has_catchall_kernel, "In operator registration: Tried to register multiple catch-all kernels for operator schema " + toString(options.schemaOrName_->right()));
143136
has_catchall_kernel = true;
144137
}
145138
}
146139
}
147140

148-
void RegisterOperators::registerOp_(FunctionSchema&& schema, Options&& options) {
141+
void RegisterOperators::registerOp_(Options&& options) {
142+
FunctionSchema schema = std::move(*options.schemaOrName_).right();
149143
OperatorName op_name = schema.operator_name();
150144

151145
auto operatorOptions = makeOperatorOptions_(options);

0 commit comments

Comments
 (0)