Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 35 additions & 41 deletions aten/src/ATen/core/dispatch/DispatchTable.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ class KernelTable_ final {
class DispatchTable final {
public:
DispatchTable(const FunctionSchema& schema)
: kernels_(make_left<detail::KernelTable_, DispatchTableEntry>())
: kernels_()
, catchall_kernel_(c10::nullopt)
, dispatch_strategy_(get_dispatch_strategy_(schema))
, operator_name_(schema.name()) {}

Expand All @@ -117,8 +118,7 @@ class DispatchTable final {
const DispatchTableEntry& kernel) {
TORCH_INTERNAL_ASSERT(dispatch_key != TensorTypeId::UndefinedTensorId);
TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments.");
TORCH_CHECK(kernels_.is_left(), "Tried to register a kernel with dispatch key ", toString(dispatch_key)," for operator ", operator_name_, ", which already has a catch-all kernel registered. An operator can only have either a catch-all kernel or kernels with dispatch keys.");
kernels_.left().set(dispatch_key, kernel, operator_name_);
kernels_.set(dispatch_key, kernel, operator_name_);
}

/**
Expand All @@ -127,8 +127,7 @@ class DispatchTable final {
* @param dispatch_key Dispatch key to unregister.
*/
void removeKernelIfExists(TensorTypeId dispatch_key) {
TORCH_INTERNAL_ASSERT(kernels_.is_left(), "Tried to remove the kernel for dispatch key ", toString(dispatch_key), " for operator ", operator_name_, ", which only has a catch-all kernel.");
kernels_.left().removeIfExists(dispatch_key, operator_name_);
kernels_.removeIfExists(dispatch_key, operator_name_);
}

/**
Expand All @@ -138,20 +137,18 @@ class DispatchTable final {
* dispatch keys, not both.
*/
void setCatchallKernel(const DispatchTableEntry& kernel) {
if (kernels_.is_right()) {
if (catchall_kernel_.has_value()) {
TORCH_WARN("Registered a catch-all kernel for operator ", operator_name_," that overwrote a previously registered catch-all kernel for the same operator.");
} else {
TORCH_CHECK(0 == kernels_.left().size(), "Tried to register a catch-all kernel for operator ", operator_name_, " which already has kernels with dispatch keys. An operator can only have either a catch-all kernel or kernels with dispatch keys.");
}
kernels_ = make_right<detail::KernelTable_, DispatchTableEntry>(kernel);
catchall_kernel_ = kernel;
}

/**
* Remove the catch-all kernel.
*/
void removeCatchallKernel() {
TORCH_INTERNAL_ASSERT(kernels_.is_right(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered.");
kernels_ = make_left<detail::KernelTable_, DispatchTableEntry>();
TORCH_INTERNAL_ASSERT(catchall_kernel_.has_value(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered.");
catchall_kernel_ = c10::nullopt;
}

/**
Expand All @@ -162,28 +159,28 @@ class DispatchTable final {
* @return Kernel function pointing to the right kernel for the given arguments.
*/
const DispatchTableEntry& lookup(const Stack* stack) const {
return lookup_([=] {
TORCH_INTERNAL_ASSERT(dispatch_strategy_.is_valid_, "Operator ", operator_name_, " has an invalid dispatch key but kernels registered.");
return lookup_([=] () -> c10::optional<TensorTypeId> {
if (!dispatch_strategy_.is_valid_) {
return c10::nullopt;
}
return dispatch_strategy_.get_dispatch_key(stack, operator_name_);
});
}

const DispatchTableEntry& lookup(TensorTypeId dispatchKey) const {
return lookup_([=] {return dispatchKey;});
return lookup_([=] () -> c10::optional<TensorTypeId> { return dispatchKey;});
}

bool isEmpty() const {
return kernels_.map<bool>(
[] (const detail::KernelTable_& table) {return 0 == table.size();},
[] (const DispatchTableEntry&) {return false;}
);
return !catchall_kernel_.has_value() && kernels_.size() == 0;
}

std::string listAllDispatchKeys() const {
return kernels_.map<std::string>(
[] (const detail::KernelTable_& table) {return table.list_all_dispatch_keys();},
[] (const DispatchTableEntry&) {return "CATCH-ALL";}
);
std::string result = kernels_.list_all_dispatch_keys();
if (catchall_kernel_.has_value()) {
result += ", CATCH-ALL";
}
return result;
}

private:
Expand Down Expand Up @@ -243,30 +240,27 @@ class DispatchTable final {

template<class GetDispatchKeyFunc>
const DispatchTableEntry& lookup_(const GetDispatchKeyFunc& getDispatchKey) const {
return kernels_.map<const DispatchTableEntry&>(
[&] (const detail::KernelTable_& table) -> const DispatchTableEntry& {
// We have a dispatch table. Find the correct kernel for the inputs and return it.
TensorTypeId dispatch_key = getDispatchKey();
auto found = table.lookup(dispatch_key);
c10::optional<TensorTypeId> dispatch_key = getDispatchKey();
if (dispatch_key.has_value()) {
const auto* found = kernels_.lookup(*dispatch_key);

TORCH_CHECK(nullptr != found, "Didn't find kernel to dispatch to for operator '", operator_name_,
"'. Tried to look up kernel for dispatch key '", toString(dispatch_key),
"'. Registered dispatch keys are: ", listAllDispatchKeys());
if (nullptr != found) {
return *found;
}
}

return *found;
},
[] (const DispatchTableEntry& entry) -> const DispatchTableEntry& {
// We have a catch-all kernel. Just return it.
return entry;
if (catchall_kernel_.has_value()) {
return *catchall_kernel_;
}
);

const std::string dispatch_key_str = dispatch_key.has_value() ? toString(*dispatch_key) : "None";
TORCH_CHECK(false, "Didn't find kernel to dispatch to for operator '", operator_name_,
"'. Tried to look up kernel for dispatch key '", dispatch_key_str,
"'. Registered dispatch keys are: ", listAllDispatchKeys());
}

// kernels_ either contains a dispatch table or
// a single catch-all kernel that is called for every backend
// The empty state (i.e. no kernels registered) is represented
// as an empty table.
either<detail::KernelTable_, DispatchTableEntry> kernels_;
detail::KernelTable_ kernels_;
c10::optional<DispatchTableEntry> catchall_kernel_;
DispatchStrategy dispatch_strategy_;
std::string operator_name_;
};
Expand Down
8 changes: 0 additions & 8 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -852,10 +852,6 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
c10::List<std::string>(), [] (const c10::List<std::string>& v) {EXPECT_EQ(0, v.size());},
c10::List<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());},
"(str[] a) -> str[]");
testArgTypes<c10::List<Tensor>>::test(
c10::List<Tensor>({}), [] (const c10::List<Tensor>& v) {EXPECT_EQ(0, v.size());},
c10::List<Tensor>({}), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<at::Tensor>>().size());},
"(Tensor[] a) -> Tensor[]");


// list types (with non-empty list)
Expand Down Expand Up @@ -906,10 +902,6 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
std::vector<std::string>(), [] (const std::vector<std::string>& v) {EXPECT_EQ(0, v.size());},
std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());},
"(str[] a) -> str[]");
testArgTypes<std::vector<Tensor>>::test<TestLegacyAPI>(
std::vector<Tensor>({}), [] (const std::vector<Tensor>& v) {EXPECT_EQ(0, v.size());},
std::vector<Tensor>({}), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<at::Tensor>>().size());},
"(Tensor[] a) -> Tensor[]");


// deprecated list types (with non-empty list)
Expand Down