Skip to content

Commit cadf836

Browse files
smessmerfacebook-github-bot
authored andcommitted
Allow overwriting catch-all kernels (#25947)
Summary: Pull Request resolved: #25947 Previously, the c10 dispatcher didn't allow having a catch-all kernel and backend specific kernels at the same time. This is also the long term goal. But to make the current XLA implementation work, we need to allow them to overwrite these ops with XLA variants. This diff changes that so that ops can have both, catchall and backend specific kernels, and will call into the catchall kernel if there is no more specific kernel registered. This is also the current behavior of globalATenDispatch. ghstack-source-id: 90049398 Test Plan: unit tests Differential Revision: D17293036 fbshipit-source-id: f2d5928e904c1dc9b6b89e9bb468debe48a4056c
1 parent b01520a commit cadf836

File tree

2 files changed

+35
-49
lines changed

2 files changed

+35
-49
lines changed

aten/src/ATen/core/dispatch/DispatchTable.h

Lines changed: 35 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ class KernelTable_ final {
103103
class DispatchTable final {
104104
public:
105105
DispatchTable(const FunctionSchema& schema)
106-
: kernels_(make_left<detail::KernelTable_, DispatchTableEntry>())
106+
: kernels_()
107+
, catchall_kernel_(c10::nullopt)
107108
, dispatch_strategy_(get_dispatch_strategy_(schema))
108109
, operator_name_(schema.name()) {}
109110

@@ -117,8 +118,7 @@ class DispatchTable final {
117118
const DispatchTableEntry& kernel) {
118119
TORCH_INTERNAL_ASSERT(dispatch_key != TensorTypeId::UndefinedTensorId);
119120
TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments.");
120-
TORCH_CHECK(kernels_.is_left(), "Tried to register a kernel with dispatch key ", toString(dispatch_key)," for operator ", operator_name_, ", which already has a catch-all kernel registered. An operator can only have either a catch-all kernel or kernels with dispatch keys.");
121-
kernels_.left().set(dispatch_key, kernel, operator_name_);
121+
kernels_.set(dispatch_key, kernel, operator_name_);
122122
}
123123

124124
/**
@@ -127,8 +127,7 @@ class DispatchTable final {
127127
* @param dispatch_key Dispatch key to unregister.
128128
*/
129129
void removeKernelIfExists(TensorTypeId dispatch_key) {
130-
TORCH_INTERNAL_ASSERT(kernels_.is_left(), "Tried to remove the kernel for dispatch key ", toString(dispatch_key), " for operator ", operator_name_, ", which only has a catch-all kernel.");
131-
kernels_.left().removeIfExists(dispatch_key, operator_name_);
130+
kernels_.removeIfExists(dispatch_key, operator_name_);
132131
}
133132

134133
/**
@@ -138,20 +137,18 @@ class DispatchTable final {
138137
* dispatch keys, not both.
139138
*/
140139
void setCatchallKernel(const DispatchTableEntry& kernel) {
141-
if (kernels_.is_right()) {
140+
if (catchall_kernel_.has_value()) {
142141
TORCH_WARN("Registered a catch-all kernel for operator ", operator_name_," that overwrote a previously registered catch-all kernel for the same operator.");
143-
} else {
144-
TORCH_CHECK(0 == kernels_.left().size(), "Tried to register a catch-all kernel for operator ", operator_name_, " which already has kernels with dispatch keys. An operator can only have either a catch-all kernel or kernels with dispatch keys.");
145142
}
146-
kernels_ = make_right<detail::KernelTable_, DispatchTableEntry>(kernel);
143+
catchall_kernel_ = kernel;
147144
}
148145

149146
/**
150147
* Remove the catch-all kernel.
151148
*/
152149
void removeCatchallKernel() {
153-
TORCH_INTERNAL_ASSERT(kernels_.is_right(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered.");
154-
kernels_ = make_left<detail::KernelTable_, DispatchTableEntry>();
150+
TORCH_INTERNAL_ASSERT(catchall_kernel_.has_value(), "Tried to remove the catch-all kernel for operator ", operator_name_," but there is no catch-all kernel registered.");
151+
catchall_kernel_ = c10::nullopt;
155152
}
156153

157154
/**
@@ -162,28 +159,28 @@ class DispatchTable final {
162159
* @return Kernel function pointing to the right kernel for the given arguments.
163160
*/
164161
const DispatchTableEntry& lookup(const Stack* stack) const {
165-
return lookup_([=] {
166-
TORCH_INTERNAL_ASSERT(dispatch_strategy_.is_valid_, "Operator ", operator_name_, " has an invalid dispatch key but kernels registered.");
162+
return lookup_([=] () -> c10::optional<TensorTypeId> {
163+
if (!dispatch_strategy_.is_valid_) {
164+
return c10::nullopt;
165+
}
167166
return dispatch_strategy_.get_dispatch_key(stack, operator_name_);
168167
});
169168
}
170169

171170
const DispatchTableEntry& lookup(TensorTypeId dispatchKey) const {
172-
return lookup_([=] {return dispatchKey;});
171+
return lookup_([=] () -> c10::optional<TensorTypeId> { return dispatchKey;});
173172
}
174173

175174
bool isEmpty() const {
176-
return kernels_.map<bool>(
177-
[] (const detail::KernelTable_& table) {return 0 == table.size();},
178-
[] (const DispatchTableEntry&) {return false;}
179-
);
175+
return !catchall_kernel_.has_value() && kernels_.size() == 0;
180176
}
181177

182178
std::string listAllDispatchKeys() const {
183-
return kernels_.map<std::string>(
184-
[] (const detail::KernelTable_& table) {return table.list_all_dispatch_keys();},
185-
[] (const DispatchTableEntry&) {return "CATCH-ALL";}
186-
);
179+
std::string result = kernels_.list_all_dispatch_keys();
180+
if (catchall_kernel_.has_value()) {
181+
result += ", CATCH-ALL";
182+
}
183+
return result;
187184
}
188185

189186
private:
@@ -243,30 +240,27 @@ class DispatchTable final {
243240

244241
template<class GetDispatchKeyFunc>
245242
const DispatchTableEntry& lookup_(const GetDispatchKeyFunc& getDispatchKey) const {
246-
return kernels_.map<const DispatchTableEntry&>(
247-
[&] (const detail::KernelTable_& table) -> const DispatchTableEntry& {
248-
// We have a dispatch table. Find the correct kernel for the inputs and return it.
249-
TensorTypeId dispatch_key = getDispatchKey();
250-
auto found = table.lookup(dispatch_key);
243+
c10::optional<TensorTypeId> dispatch_key = getDispatchKey();
244+
if (dispatch_key.has_value()) {
245+
const auto* found = kernels_.lookup(*dispatch_key);
251246

252-
TORCH_CHECK(nullptr != found, "Didn't find kernel to dispatch to for operator '", operator_name_,
253-
"'. Tried to look up kernel for dispatch key '", toString(dispatch_key),
254-
"'. Registered dispatch keys are: ", listAllDispatchKeys());
247+
if (nullptr != found) {
248+
return *found;
249+
}
250+
}
255251

256-
return *found;
257-
},
258-
[] (const DispatchTableEntry& entry) -> const DispatchTableEntry& {
259-
// We have a catch-all kernel. Just return it.
260-
return entry;
252+
if (catchall_kernel_.has_value()) {
253+
return *catchall_kernel_;
261254
}
262-
);
255+
256+
const std::string dispatch_key_str = dispatch_key.has_value() ? toString(*dispatch_key) : "None";
257+
TORCH_CHECK(false, "Didn't find kernel to dispatch to for operator '", operator_name_,
258+
"'. Tried to look up kernel for dispatch key '", dispatch_key_str,
259+
"'. Registered dispatch keys are: ", listAllDispatchKeys());
263260
}
264261

265-
// kernels_ either contains a dispatch table or
266-
// a single catch-all kernel that is called for every backend
267-
// The empty state (i.e. no kernels registered) is represented
268-
// as an empty table.
269-
either<detail::KernelTable_, DispatchTableEntry> kernels_;
262+
detail::KernelTable_ kernels_;
263+
c10::optional<DispatchTableEntry> catchall_kernel_;
270264
DispatchStrategy dispatch_strategy_;
271265
std::string operator_name_;
272266
};

aten/src/ATen/core/op_registration/op_registration_test.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -852,10 +852,6 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
852852
c10::List<std::string>(), [] (const c10::List<std::string>& v) {EXPECT_EQ(0, v.size());},
853853
c10::List<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());},
854854
"(str[] a) -> str[]");
855-
testArgTypes<c10::List<Tensor>>::test(
856-
c10::List<Tensor>({}), [] (const c10::List<Tensor>& v) {EXPECT_EQ(0, v.size());},
857-
c10::List<Tensor>({}), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<at::Tensor>>().size());},
858-
"(Tensor[] a) -> Tensor[]");
859855

860856

861857
// list types (with non-empty list)
@@ -906,10 +902,6 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
906902
std::vector<std::string>(), [] (const std::vector<std::string>& v) {EXPECT_EQ(0, v.size());},
907903
std::vector<std::string>(), [] (const IValue& v) {EXPECT_EQ(0, v.toGenericListRef().size());},
908904
"(str[] a) -> str[]");
909-
testArgTypes<std::vector<Tensor>>::test<TestLegacyAPI>(
910-
std::vector<Tensor>({}), [] (const std::vector<Tensor>& v) {EXPECT_EQ(0, v.size());},
911-
std::vector<Tensor>({}), [] (const IValue& v) {EXPECT_EQ(0, v.to<c10::List<at::Tensor>>().size());},
912-
"(Tensor[] a) -> Tensor[]");
913905

914906

915907
// deprecated list types (with non-empty list)

0 commit comments

Comments
 (0)