Skip to content

Commit 0e30e65

Browse files
smessmer authored and facebook-github-bot committed
Call aten ops through c10 dispatcher (#23668)
Summary: Pull Request resolved: #23668 - The eager mode frontend now calls operators that are defined in native_functions.yaml with `use_c10_dispatcher: True` through the c10 dispatcher rather than through globalATenDispatch(). - These operators aren't registered with globalATenDispatch anymore, only with c10 now. - Backend extensions calling globalATenDispatch().registerOp() to add their own kernels still work; this function will forward the registration to the c10 dispatcher for them. ghstack-source-id: 90130455 Test Plan: benchmarks at https://docs.google.com/document/d/1gpzKZcFf1JJameY1vKxF7Cloul9s6D8HKIK2_Pp1hFo/edit# Differential Revision: D16603133 fbshipit-source-id: 991f17b355e9c78c5e86fee4fa381df7ab98ac82
1 parent e86d99a commit 0e30e65

File tree

11 files changed

+1108
-757
lines changed

11 files changed

+1108
-757
lines changed

aten/src/ATen/core/TensorMethods.h

Lines changed: 973 additions & 650 deletions
Large diffs are not rendered by default.

aten/src/ATen/core/dispatch/DispatchTable.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,12 @@ class DispatchTable final {
117117
TensorTypeId dispatch_key,
118118
const DispatchTableEntry& kernel) {
119119
TORCH_INTERNAL_ASSERT(dispatch_key != TensorTypeId::UndefinedTensorId);
120-
TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments.");
120+
// The following assertion is disabled because we're codegenerating
121+
// autograd kernels for operators without tensor arguments even though
122+
// they are never called. These, however, register kernels for
123+
// VariableTensorId.
124+
// TODO Stop generating these kernels and re-enable this assertion here.
125+
//TORCH_CHECK(dispatch_strategy_.is_valid_, "Tried to register a kernel with dispatch key ", toString(dispatch_key), " for operator ", operator_name_, " that doesn't have tensor arguments.");
121126
kernels_.set(dispatch_key, kernel, operator_name_);
122127
}
123128

aten/src/ATen/core/dispatch/OperatorEntry.cpp

Lines changed: 15 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ namespace {
2020
OperatorEntry::OperatorEntry(FunctionSchema&& schema, OperatorOptions&& options)
2121
: schema_(std::move(schema))
2222
, dispatchTable_(schema_)
23-
, kernels_(make_left<ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, std::list<DispatchTableEntry>>())
23+
, kernels_()
24+
, catchAllKernels_()
2425
, options_(std::move(options)) {
2526
}
2627

@@ -30,18 +31,16 @@ void OperatorEntry::prepareForDeregistration() {
3031
TORCH_INTERNAL_ASSERT(false, "Tried to deregister op schema for an operator that still has kernels registered. The operator schema is ", toString(schema_), ". Registered kernels for dispatch keys: ", dispatchTable.listAllDispatchKeys());
3132
}
3233
});
33-
TORCH_INTERNAL_ASSERT(kernels_.is_left(), "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have a catch-all kernel. The operator schema is ", toString(schema_));
34-
TORCH_INTERNAL_ASSERT(kernels_.left().size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have kernels for dispatch keys ", listAllDispatchKeys(kernels_.left()), ". The operator schema is ", toString(schema_));
34+
TORCH_INTERNAL_ASSERT(kernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have kernels for dispatch keys ", listAllDispatchKeys(kernels_), ". The operator schema is ", toString(schema_));
35+
TORCH_INTERNAL_ASSERT(catchAllKernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have catch-all kernel. The operator schema is ", toString(schema_));
3536
}
3637

3738
RegistrationHandleRAII OperatorEntry::registerKernel(TensorTypeId dispatch_key, DispatchTableEntry kernel) {
3839
std::unique_lock<std::mutex> lock(kernelsMutex_);
3940

40-
TORCH_CHECK(kernels_.is_left(), "Tried to register a kernel with dispatch key ", toString(dispatch_key)," for an operator which already has a catch-all kernel registered. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is ", toString(schema_));
41-
4241
// Add the kernel to the kernels list,
4342
// possibly creating the list if this is the first kernel.
44-
auto& k = kernels_.left()[dispatch_key];
43+
auto& k = kernels_[dispatch_key];
4544
k.push_front(kernel);
4645
std::list<DispatchTableEntry>::iterator inserted = k.begin();
4746
// update the dispatch table, i.e. re-establish the invariant
@@ -58,16 +57,10 @@ RegistrationHandleRAII OperatorEntry::registerKernel(TensorTypeId dispatch_key,
5857
RegistrationHandleRAII OperatorEntry::registerCatchallKernel(DispatchTableEntry kernel) {
5958
std::unique_lock<std::mutex> lock(kernelsMutex_);
6059

61-
if (kernels_.is_left()) {
62-
TORCH_CHECK(0 == kernels_.left().size(), "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys ", listAllDispatchKeys(kernels_.left()), ". An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is ", toString(schema_));
63-
kernels_ = make_right<ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, std::list<DispatchTableEntry>>();
64-
}
65-
6660
// Add the kernel to the kernels list,
6761
// possibly creating the list if this is the first kernel.
68-
auto& k = kernels_.right();
69-
k.push_front(kernel);
70-
std::list<DispatchTableEntry>::iterator inserted = k.begin();
62+
catchAllKernels_.push_front(kernel);
63+
std::list<DispatchTableEntry>::iterator inserted = catchAllKernels_.begin();
7164
// update the dispatch table, i.e. re-establish the invariant
7265
// that the dispatch table points to the newest kernel
7366
updateCatchallDispatchTable_();
@@ -82,16 +75,13 @@ RegistrationHandleRAII OperatorEntry::registerCatchallKernel(DispatchTableEntry
8275
void OperatorEntry::deregisterKernel_(TensorTypeId dispatch_key, std::list<DispatchTableEntry>::iterator kernel) {
8376
std::unique_lock<std::mutex> lock(kernelsMutex_);
8477

85-
TORCH_CHECK(kernels_.is_left(), "Tried deregister a kernel for dispatch key ", toString(dispatch_key), " for an operator that only has a catch-all kernel. The operator schema is ", toString(schema_));
86-
87-
auto& kernels = kernels_.left();
88-
auto found = kernels.find(dispatch_key);
89-
TORCH_INTERNAL_ASSERT(found != kernels.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator schema is ", toString(schema_));
78+
auto found = kernels_.find(dispatch_key);
79+
TORCH_INTERNAL_ASSERT(found != kernels_.end(), "Tried to deregister a kernel for dispatch key ", toString(dispatch_key), " but there are no kernels registered for this dispatch key. The operator schema is ", toString(schema_));
9080
auto& k = found->second;
9181
k.erase(kernel);
9282
if (k.empty()) {
9383
// the invariant says we don't want empty lists but instead remove the list from the map
94-
kernels.erase(found);
84+
kernels_.erase(found);
9585
}
9686

9787
updateDispatchTable_(dispatch_key);
@@ -100,14 +90,7 @@ void OperatorEntry::deregisterKernel_(TensorTypeId dispatch_key, std::list<Dispa
10090
void OperatorEntry::deregisterCatchallKernel_(std::list<DispatchTableEntry>::iterator kernel) {
10191
std::unique_lock<std::mutex> lock(kernelsMutex_);
10292

103-
TORCH_CHECK(kernels_.is_right(), "Tried to deregister a catch-all kernel for an operator that doesn't have a catch-all kernel registered. The operator schema is ", toString(schema_));
104-
105-
auto& k = kernels_.right();
106-
k.erase(kernel);
107-
if (k.empty()) {
108-
// the invariant says that the empty state is represented with is_left()
109-
kernels_ = make_left<ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, std::list<DispatchTableEntry>>();
110-
}
93+
catchAllKernels_.erase(kernel);
11194

11295
updateCatchallDispatchTable_();
11396
}
@@ -150,12 +133,9 @@ void OperatorEntry::updateCurrentUnboxedAutogradKernel_() {
150133
void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) {
151134
// precondition: kernelsMutex_ is locked
152135

153-
TORCH_INTERNAL_ASSERT(kernels_.is_left(), "Can't update the dispatch table a dispatch key ", toString(dispatch_key), " because the operator only has catch-all kernels. The operator schema is ", toString(schema_));
154-
155-
auto& kernels = kernels_.left();
156-
auto k = kernels.find(dispatch_key);
136+
auto k = kernels_.find(dispatch_key);
157137

158-
if (k == kernels.end()) {
138+
if (k == kernels_.end()) {
159139
dispatchTable_.write([&] (DispatchTable& dispatchTable) {
160140
dispatchTable.removeKernelIfExists(dispatch_key);
161141
});
@@ -169,13 +149,13 @@ void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) {
169149
void OperatorEntry::updateCatchallDispatchTable_() {
170150
// precondition: kernelsMutex_ is locked
171151

172-
if (kernels_.is_left()) {
152+
if (catchAllKernels_.size() == 0) {
173153
dispatchTable_.write([&] (DispatchTable& dispatchTable) {
174154
dispatchTable.removeCatchallKernel();
175155
});
176156
} else {
177157
dispatchTable_.write([&] (DispatchTable& dispatchTable) {
178-
dispatchTable.setCatchallKernel(kernels_.right().front());
158+
dispatchTable.setCatchallKernel(catchAllKernels_.front());
179159
});
180160
}
181161
}

aten/src/ATen/core/dispatch/OperatorEntry.h

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class CAFFE2_API OpKernel final {
5454

5555
private:
5656
explicit OpKernel(KernelFunction* kernel, const KernelCacheCreatorFunction& cache_creator, void* unboxed_kernel)
57-
: kernel_(kernel), cache_(cache_creator ? cache_creator() : nullptr), unboxed_kernel_(unboxed_kernel) {}
57+
: kernel_(kernel), cache_(cache_creator ? cache_creator() : c10::guts::make_unique<c10::KernelCache>()), unboxed_kernel_(unboxed_kernel) {}
5858
friend class impl::OperatorEntry;
5959

6060
// All of these fields may be nullptr, but at least one of
@@ -120,14 +120,8 @@ class OperatorEntry final {
120120
// The dispatchTable stores the current kernel for each dispatch key
121121
LeftRight<DispatchTable> dispatchTable_;
122122

123-
// kernels_ is either:
124-
// left: a kernel map listing mapping from a dispatch key to a list of all
125-
// kernels for that operator, or it is
126-
// right: a list of all catch-all kernels registered for this operator.
127-
// An operator can only have either dispatched kernels or catch-all kernels,
128-
// not both.
129-
// In both cases, the list of kernels stores all registered kernels for the
130-
// corresponding dispatch key (or for catch-all).
123+
// kernels_ stores all registered kernels for the corresponding dispatch key
124+
// and catchAllKernels_ stores the catch-all kernels.
131125
// If an operator library gets loaded that overwrites an already existing kernel,
132126
// both kernels will be in that list but only the newer one will be in
133127
// dispatchTable. If any of the kernels go away (say the library gets
@@ -139,15 +133,13 @@ class OperatorEntry final {
139133
// kernels is a larger data structure and accessed quite infrequently
140134
// while dispatchTable is accessed often and should be kept small to fit
141135
// into CPU caches.
142-
// Invariants (assuming kernels_.is_left()):
143-
// - dispatchTable[dispatch_key] == kernels_.left()[dispatch_key].front()
136+
// Invariants:
137+
// - dispatchTable[dispatch_key] == kernels_[dispatch_key].front()
144138
// - dispatchTable[dispatch_key] does not exist if and only if
145-
// kernels_.left()[dispatch_key] does not exist
146-
// - If kernels_.left()[dispatch_key] exists, then it has elements.
139+
// kernels_[dispatch_key] does not exist
140+
// - If kernels_[dispatch_key] exists, then it has elements.
147141
// It is never an empty list.
148-
// Analogous invariants for kernels_.is_right().
149-
// The empty state (i.e. no kernels registered) is represented as an empty
150-
// map with kernels_.is_left().
142+
// Analogous invariants for catchAllKernels_.
151143
//
152144
// Why do we do that?
153145
// -----
@@ -160,10 +152,8 @@ class OperatorEntry final {
160152
// re-ececuted and then only allow one kernel here, i.e. error if a kernel
161153
// is already registered, but that's a lot of effort to implement and
162154
// currently not high-pri.
163-
c10::either<
164-
ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>>, // dispatched kernels
165-
std::list<DispatchTableEntry> // catch-all kernels
166-
> kernels_;
155+
ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>> kernels_;
156+
std::list<DispatchTableEntry> catchAllKernels_;
167157

168158
// unboxedAutogradKernels_ stores all autograd kernels registered for this op.
169159
// An autograd kernel has the same signature as the main op kernel and

aten/src/ATen/core/op_registration/op_registration.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -397,14 +397,6 @@ class CAFFE2_API RegisterOperators final {
397397
static bool op_is_still_on_aten_dispatcher_(const char* schema_string) {
398398
// TODO Remove this function once all aten ops are on c10
399399
const auto op_name = parse_operator_name_(schema_string);
400-
if (at::aten_ops_already_moved_to_c10().count(op_name) != 0) {
401-
// For now, even if an op is in aten_ops_already_moved_to_c10, it is still
402-
// not actually moved to c10. It is still on globalATenDispatch.
403-
// TODO This is be removed in a diff stacked on top, then this
404-
// function will only return true iff the op is in
405-
// aten_ops_not_moved_to_c10_yet
406-
return true;
407-
}
408400
return at::aten_ops_not_moved_to_c10_yet().count(op_name) != 0;
409401
}
410402

@@ -432,13 +424,21 @@ class CAFFE2_API RegisterOperators final {
432424

433425
template<class KernelFunctor, class... ConstructorParameters>
434426
Options&& kernelFunctorUnboxedOnly(c10::optional<TensorTypeId>&& dispatch_key, ConstructorParameters&&... constructorParameters) && {
427+
// Setting cache_creator to nullptr so calling the kernel doesn't need to call it, which would be expensive.
428+
// Since the dispatcher static_cast's cache objects into our functor type to call their operator(), this nullptr
429+
// will cause it to create and static_cast an invalid cache object, which is technically illegal in the C++ standard,
430+
// but it works as long as operator() does not access any functor members.
431+
// Exception: Backend extensions use runtime function pointers and store these in the functor as members,
432+
// so we need a cache if sizeof...(ConstructorParameters) != 0
433+
auto cache_creator =
434+
(sizeof...(ConstructorParameters) == 0)
435+
? KernelCacheCreatorFunction(nullptr)
436+
: detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...);
437+
435438
return std::move(*this).kernel(
436439
std::move(dispatch_key),
437440
nullptr,
438-
// setting cache creator to nullptr so calling the kernel doesn't need to call it, which would be expensive
439-
// This, however, only works if there are no constructor parameters (i.e. no runtime function pointer)
440-
// Backend extensions use runtime function pointers, so we need a cache if sizeof...(ConstructorParameters) != 0
441-
(sizeof...(ConstructorParameters) == 0) ? KernelCacheCreatorFunction(nullptr) : detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
441+
std::move(cache_creator),
442442
reinterpret_cast<void*>(&detail::wrap_kernel_functor_unboxed<KernelFunctor>::call),
443443
detail::FunctionSchemaInferer<KernelFunctor>()()
444444
);

aten/src/ATen/core/op_registration/op_registration_test.cpp

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -108,22 +108,23 @@ TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCalls
108108
EXPECT_TRUE(called);
109109
}
110110

111-
TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernel_thenFails) {
112-
bool called = false;
113-
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
114-
expectThrows<c10::Error>([&] {
115-
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::TensorTypeId::CPUTensorId, &called));
116-
}, "for an operator which already has a catch-all kernel registered");
117-
}
118-
119-
TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernelInSameOpCall_thenFails) {
120-
bool called = false;
121-
expectThrows<c10::Error>([&] {
122-
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
123-
.catchAllKernel<MockKernel>(&called)
124-
.kernel<MockKernel>(c10::TensorTypeId::CPUTensorId, &called));
125-
}, "for an operator which already has a catch-all kernel registered");
126-
}
111+
// TODO Rewrite (since this is now allowed) and reenable
112+
// TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernel_thenFails) {
113+
// bool called = false;
114+
// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
115+
// expectThrows<c10::Error>([&] {
116+
// c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::TensorTypeId::CPUTensorId, &called));
117+
// }, "for an operator which already has a catch-all kernel registered");
118+
// }
119+
120+
// TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenRegisteringDispatchedKernelInSameOpCall_thenFails) {
121+
// bool called = false;
122+
// expectThrows<c10::Error>([&] {
123+
// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
124+
// .catchAllKernel<MockKernel>(&called)
125+
// .kernel<MockKernel>(c10::TensorTypeId::CPUTensorId, &called));
126+
// }, "for an operator which already has a catch-all kernel registered");
127+
// }
127128

128129
TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegisteringCatchallKernelAndCallingOp_thenCallsCatchallKernel) {
129130
bool called = false;
@@ -140,22 +141,23 @@ TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegiste
140141
EXPECT_TRUE(called);
141142
}
142143

143-
TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernel_thenFails) {
144-
bool called = false;
145-
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::TensorTypeId::CPUTensorId, &called));
146-
expectThrows<c10::Error>([&] {
147-
c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
148-
}, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy");
149-
}
150-
151-
TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernelInSameOpCall_thenFails) {
152-
bool called = false;
153-
expectThrows<c10::Error>([&] {
154-
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
155-
.kernel<MockKernel>(c10::TensorTypeId::CPUTensorId, &called)
156-
.catchAllKernel<MockKernel>(&called));
157-
}, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy");
158-
}
144+
// TODO Rewrite (since this is now allowed) and reenable
145+
// TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernel_thenFails) {
146+
// bool called = false;
147+
// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel<MockKernel>(c10::TensorTypeId::CPUTensorId, &called));
148+
// expectThrows<c10::Error>([&] {
149+
// c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel<MockKernel>(&called));
150+
// }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy");
151+
// }
152+
//
153+
// TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernelInSameOpCall_thenFails) {
154+
// bool called = false;
155+
// expectThrows<c10::Error>([&] {
156+
// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
157+
// .kernel<MockKernel>(c10::TensorTypeId::CPUTensorId, &called)
158+
// .catchAllKernel<MockKernel>(&called));
159+
// }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy");
160+
// }
159161

160162
TEST(OperatorRegistrationTest, givenOpWithCatchallKernelOutOfScope_whenRegisteringDispatchedKernelAndCallingOp_thenCallsCatchallKernel) {
161163
bool called = false;

0 commit comments

Comments
 (0)