Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions aten/src/ATen/core/function_schema.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ struct OperatorName final {
std::string overload_name;
};

inline std::string toString(const OperatorName& opName) {
std::string result = opName.name;
if (opName.overload_name.size() != 0) {
result += "." + opName.overload_name;
}
return result;
}

struct FunctionSchema {
FunctionSchema(
std::string name,
Expand Down Expand Up @@ -237,9 +245,9 @@ struct FunctionSchema {
return false;
}

// can a function with this schema be substituted for a function of rhs's
// can a function with this schema be substituted for a function of rhs's
// schema and have the program typecheck?
// as_method - if true, treat this schema as a method and ignore
// as_method - if true, treat this schema as a method and ignore
// the first argument, which will be the object in both cases
bool isSubtypeOf(const FunctionSchema& rhs, bool as_method, std::ostream* why_not=nullptr) const;
};
Expand Down
130 changes: 62 additions & 68 deletions aten/src/ATen/core/op_registration/op_registration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,55 +43,65 @@ class RegisterOperators::OperatorRegistrar final {
c10::optional<RegistrationHandleRAII> unboxed_autograd_kernel_registration_handle_;
};

void RegisterOperators::checkSchemaAndRegisterOp_(const std::string& schemaOrNameStr, Options&& options) {
#if defined(CAFFE2_IS_XPLAT_BUILD)
throw std::logic_error("Tried to register operator " + schemaOrNameStr + ". We don't support registering c10 ops on mobile yet because the function schema parser isn't present in the mobile build.");
#else
either<OperatorName, FunctionSchema> schemaOrName = torch::jit::parseSchemaOrName(schemaOrNameStr);
if (schemaOrName.is_right()) {
// schema was explicitly specified. Check it matches the inferred one and register the op.

auto schema = std::move(schemaOrName).right();
TORCH_CHECK(
options.aliasAnalysisKind_ == AliasAnalysisKind::FROM_SCHEMA ||
!schema.hasAnyAliasInfo(),
"In operator registration: Tried to register operator ",
schemaOrNameStr,
" with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");

checkSchemaAndRegisterOp_(std::move(schema), std::move(options));
} else {
// schema wasn't explicitly specified. Take the inferred schema for registering the op.

FunctionSchema inferred_schema = inferSchemaFromKernels_(schemaOrNameStr, options);
OperatorName name = std::move(schemaOrName).left();
FunctionSchema inferred_schema_with_name(
std::move(name.name),
std::move(name.overload_name),
inferred_schema.arguments(),
inferred_schema.returns(),
inferred_schema.is_vararg(),
inferred_schema.is_varret()
);

checkNoDuplicateKernels_(inferred_schema_with_name, options);

// This would have unexpected behavior since an inferred schema will not
// have aliasing annotations.
TORCH_CHECK(
options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA,
"In operator registration: Tried to register operator ",
schemaOrNameStr,
" with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred.");

// Register all kernels with the schema we inferred
registerOp_(std::move(inferred_schema_with_name), std::move(options));
void RegisterOperators::checkSchemaAndRegisterOp_(Options&& options) {
TORCH_CHECK(options.schemaOrName_.has_value(), "In operator registration: Tried to register an operator without specifying a schema or operator name.");
if (options.schemaOrName_->is_right()) {
// schema was explicitly specified. Check it matches the inferred one and register the op.

const FunctionSchema& schema = options.schemaOrName_->right();
TORCH_CHECK(
options.aliasAnalysisKind_ == AliasAnalysisKind::FROM_SCHEMA ||
!schema.hasAnyAliasInfo(),
"In operator registration: Tried to register operator ",
options.schemaOrName_->right(),
" with aliasing information in the schema but without AliasAnalysisKind::FROM_SCHEMA.");

for (auto& kernel : options.kernels) {
if (nullptr != kernel.inferred_function_schema.get()) {
c10::optional<std::string> schema_difference = findSchemaDifferences(schema, *kernel.inferred_function_schema);
if (schema_difference.has_value()) {
TORCH_CHECK(false, "In operator registration: Specified function schema [", toString(schema), "] ",
"doesn't match inferred function schema [", toString(*kernel.inferred_function_schema), "]. ",
*schema_difference);
}
}
}
#endif

checkNoDuplicateKernels_(options);

registerOp_(std::move(options));
} else {
// schema wasn't explicitly specified. Take the inferred schema for registering the op.

OperatorName name = std::move(*options.schemaOrName_).left();
FunctionSchema inferred_schema = inferSchemaFromKernels_(name, options);

options.schemaOrName_ = c10::make_right<OperatorName, FunctionSchema>(
std::move(name.name),
std::move(name.overload_name),
inferred_schema.arguments(),
inferred_schema.returns(),
inferred_schema.is_vararg(),
inferred_schema.is_varret()
);

checkNoDuplicateKernels_(options);

// This would have unexpected behavior since an inferred schema will not
// have aliasing annotations.
TORCH_CHECK(
options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA,
"In operator registration: Tried to register operator ",
options.schemaOrName_->right(),
" with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred.");

// Register all kernels with the schema we inferred
registerOp_(std::move(options));
}
}

c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const std::string& opNameStr, const RegisterOperators::Options& options) {
TORCH_CHECK(options.kernels.size() > 0, "Cannot infer operator schema in registration of operator ", opNameStr, " because there is no kernel specified.");
c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const OperatorName& opName, const RegisterOperators::Options& options) {
TORCH_CHECK(options.kernels.size() > 0, "Cannot infer operator schema in registration of operator ", toString(opName), " because there is no kernel specified.");

c10::optional<FunctionSchema> inferred_schema = c10::nullopt;
for (const auto& kernel : options.kernels) {
Expand All @@ -108,44 +118,28 @@ c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const std::string
}
}
}
TORCH_CHECK(inferred_schema.has_value(), "Cannot infer operator schema for this kind of kernel in registration of operator ", opNameStr,". Please explicitly specify the operator schema or specify at least one kernel for which we can infer the schema.");
TORCH_CHECK(inferred_schema.has_value(), "Cannot infer operator schema for this kind of kernel in registration of operator ", toString(opName), ". Please explicitly specify the operator schema or specify at least one kernel for which we can infer the schema.");

return *inferred_schema;
}

void RegisterOperators::checkSchemaAndRegisterOp_(FunctionSchema schema, Options&& options) {
for (auto& kernel : options.kernels) {
if (nullptr != kernel.inferred_function_schema.get()) {
c10::optional<std::string> schema_difference = findSchemaDifferences(schema, *kernel.inferred_function_schema);
if (schema_difference.has_value()) {
TORCH_CHECK(false, "In operator registration: Specified function schema [", toString(schema), "] ",
"doesn't match inferred function schema [", toString(*kernel.inferred_function_schema), "]. ",
*schema_difference);
}
}
}

checkNoDuplicateKernels_(schema, options);

registerOp_(std::move(schema), std::move(options));
}

void RegisterOperators::checkNoDuplicateKernels_(const FunctionSchema& schema, const Options& options) {
void RegisterOperators::checkNoDuplicateKernels_(const Options& options) {
std::unordered_set<TensorTypeId> dispatch_keys;
bool has_catchall_kernel = false;

for (const auto& kernel : options.kernels) {
if (kernel.dispatch_key.has_value()) {
TORCH_CHECK(0 == dispatch_keys.count(*kernel.dispatch_key), "In operator registration: Tried to register multiple kernels with same dispatch key ", toString(*kernel.dispatch_key), " for operator schema ", toString(schema));
TORCH_CHECK(0 == dispatch_keys.count(*kernel.dispatch_key), "In operator registration: Tried to register multiple kernels with same dispatch key ", toString(*kernel.dispatch_key), " for operator schema ", toString(options.schemaOrName_->right()));
dispatch_keys.insert(*kernel.dispatch_key);
} else {
TORCH_CHECK(!has_catchall_kernel, "In operator registration: Tried to register multiple catch-all kernels for operator schema " + toString(schema));
TORCH_CHECK(!has_catchall_kernel, "In operator registration: Tried to register multiple catch-all kernels for operator schema " + toString(options.schemaOrName_->right()));
has_catchall_kernel = true;
}
}
}

void RegisterOperators::registerOp_(FunctionSchema&& schema, Options&& options) {
void RegisterOperators::registerOp_(Options&& options) {
FunctionSchema schema = std::move(*options.schemaOrName_).right();
OperatorName op_name = schema.operator_name();

auto operatorOptions = makeOperatorOptions_(options);
Expand Down
Loading