Skip to content

Commit 26d537d

Browse files
smessmer authored and facebook-github-bot committed
Remove unboxedAutogradKernel from c10 (#26130)
Summary: Pull Request resolved: #26130 Since we now just use TensorTypeId::VariableTensorId, there's no need to treat autograd kernels any differently. ghstack-source-id: 90130457 Test Plan: unit tests Differential Revision: D17353873 fbshipit-source-id: d4468506a5366bc5e7429144b090b3e78af9de62
1 parent 0e30e65 commit 26d537d

File tree

7 files changed

+31
-132
lines changed

7 files changed

+31
-132
lines changed

aten/src/ATen/core/dispatch/Dispatcher.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,6 @@ RegistrationHandleRAII Dispatcher::registerCatchallKernel(const OperatorHandle&
126126
return op.operatorIterator_->op.registerCatchallKernel(DispatchTableEntry{kernel_func, std::move(cache_creator_func), unboxed_kernel_func});
127127
}
128128

129-
RegistrationHandleRAII Dispatcher::registerUnboxedAutogradKernel(const OperatorHandle& op, void* unboxed_autograd_kernel) {
130-
// note: this doesn't need the mutex to protect the iterator because write operations on the list keep iterators intact.
131-
return op.operatorIterator_->op.registerUnboxedAutogradKernel(unboxed_autograd_kernel);
132-
}
133-
134129
void Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {
135130
std::lock_guard<std::mutex> lock(mutex_);
136131

aten/src/ATen/core/dispatch/Dispatcher.h

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,6 @@ class CAFFE2_API Dispatcher final {
9090
*/
9191
RegistrationHandleRAII registerCatchallKernel(const OperatorHandle& op, KernelFunction* kernel_func, KernelCacheCreatorFunction cache_creator_func, void* unboxed_kernel_func);
9292

93-
RegistrationHandleRAII registerUnboxedAutogradKernel(const OperatorHandle& op, void* unboxed_autograd_kernel);
94-
9593
/**
9694
* Perform a dynamic dispatch and get the kernel for an operator.
9795
*/
@@ -104,11 +102,6 @@ class CAFFE2_API Dispatcher final {
104102
// the (unboxed?) arguments the operator is to be called with.
105103
OpKernel lookup(const OperatorHandle& op, TensorTypeId dispatchKey) const;
106104

107-
// TODO Remove callUnboxedAutogradKernel() and instead figure out in a generic
108-
// callKernel() wrapper if the autograd or the regular kernel need to be called.
109-
template<class Result, class... Args>
110-
Result callUnboxedAutogradKernel(const OperatorHandle& op, Args... args) const;
111-
112105
/**
113106
* Add a listener that gets called whenever a new op is registered or an existing
114107
* op is deregistered. Immediately after registering, this listener gets called
@@ -183,14 +176,4 @@ inline OpKernel Dispatcher::lookup(const OperatorHandle& op, TensorTypeId dispat
183176
return op.operatorIterator_->op.lookupKernel(dispatchKey);
184177
}
185178

186-
template<class Result, class... Args>
187-
inline Result Dispatcher::callUnboxedAutogradKernel(const OperatorHandle& op, Args... args) const {
188-
void* unboxed_autograd_kernel = op.operatorIterator_->op.lookupUnboxedAutogradKernel();
189-
TORCH_CHECK(nullptr != unboxed_autograd_kernel, "Tried to call Dispatcher::callUnboxedAutogradKernel() for operator ", toString(op.schema()), " that doesn't have an autograd kernel.");
190-
191-
using OpSignature = Result (Args...);
192-
OpSignature* kernel = reinterpret_cast<OpSignature*>(unboxed_autograd_kernel);
193-
return (*kernel)(std::forward<Args>(args)...);
194-
}
195-
196179
} // namespace c10

aten/src/ATen/core/dispatch/OperatorEntry.cpp

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -95,41 +95,6 @@ void OperatorEntry::deregisterCatchallKernel_(std::list<DispatchTableEntry>::ite
9595
updateCatchallDispatchTable_();
9696
}
9797

98-
RegistrationHandleRAII OperatorEntry::registerUnboxedAutogradKernel(void* kernel_func) {
99-
std::unique_lock<std::mutex> lock(unboxedAutogradKernelsMutex_);
100-
101-
TORCH_INTERNAL_ASSERT(kernel_func != nullptr);
102-
103-
unboxedAutogradKernels_.push_front(kernel_func);
104-
std::list<void*>::iterator inserted = unboxedAutogradKernels_.begin();
105-
106-
updateCurrentUnboxedAutogradKernel_();
107-
108-
return RegistrationHandleRAII([this, inserted] {
109-
// list iterators stay valid even if the list changes,
110-
// so we can use the iterator to deregister the kernel from the list
111-
deregisterUnboxedAutogradKernel_(inserted);
112-
});
113-
}
114-
115-
void OperatorEntry::deregisterUnboxedAutogradKernel_(std::list<void*>::iterator kernel) {
116-
std::unique_lock<std::mutex> lock(unboxedAutogradKernelsMutex_);
117-
118-
unboxedAutogradKernels_.erase(kernel);
119-
120-
updateCurrentUnboxedAutogradKernel_();
121-
}
122-
123-
void OperatorEntry::updateCurrentUnboxedAutogradKernel_() {
124-
// precondition: unboxedAutogradKernelsMutex_ is locked
125-
126-
if (unboxedAutogradKernels_.empty()) {
127-
currentUnboxedAutogradKernel_ = nullptr;
128-
} else {
129-
currentUnboxedAutogradKernel_ = unboxedAutogradKernels_.front();
130-
}
131-
}
132-
13398
void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) {
13499
// precondition: kernelsMutex_ is locked
135100

aten/src/ATen/core/dispatch/OperatorEntry.h

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -95,25 +95,18 @@ class OperatorEntry final {
9595
});
9696
}
9797

98-
void* lookupUnboxedAutogradKernel() const {
99-
return currentUnboxedAutogradKernel_;
100-
}
101-
10298
void prepareForDeregistration();
10399

104100
RegistrationHandleRAII registerKernel(TensorTypeId dispatch_key, DispatchTableEntry kernel);
105101
RegistrationHandleRAII registerCatchallKernel(DispatchTableEntry kernel);
106102

107-
RegistrationHandleRAII registerUnboxedAutogradKernel(void* kernel_func);
108-
109103
const OperatorOptions& options() {
110104
return options_;
111105
}
112106

113107
private:
114108
void deregisterKernel_(TensorTypeId dispatch_key, std::list<DispatchTableEntry>::iterator kernel);
115109
void deregisterCatchallKernel_(std::list<DispatchTableEntry>::iterator kernel);
116-
void deregisterUnboxedAutogradKernel_(std::list<void*>::iterator kernel);
117110

118111
FunctionSchema schema_;
119112

@@ -155,33 +148,15 @@ class OperatorEntry final {
155148
ska::flat_hash_map<TensorTypeId, std::list<DispatchTableEntry>> kernels_;
156149
std::list<DispatchTableEntry> catchAllKernels_;
157150

158-
// unboxedAutogradKernels_ stores all autograd kernels registered for this op.
159-
// An autograd kernel has the same signature as the main op kernel and
160-
// internally re-dispatches to call the actual kernel.
161-
// Autograd kernels are unboxed currently. We are planning to move this
162-
// towards a system where ops register autograd wrappers (i.e. functions that
163-
// do some wrapping code and get a pointer to the actual kernel) instead of
164-
// autograd functions.
165-
// This is a list because, similar to kernels_, multiple libraries could
166-
// be loaded that register autograd kernels for the same op. The list is
167-
// ordered by registration time descendingly, i.e. newer registrations are
168-
// before older registrations and the list head is the autograd kernel
169-
// which is currently used.
170-
// See the comment for kernels_ above for an explanation for why we do this.
171-
std::list<void*> unboxedAutogradKernels_;
172-
std::atomic<void*> currentUnboxedAutogradKernel_;
173-
174151
// Some metadata about the operator
175152
OperatorOptions options_;
176153

177154
std::mutex kernelsMutex_; // protects kernels_
178-
std::mutex unboxedAutogradKernelsMutex_; // protects unboxedAutogradKernels_
179155

180156
// This function re-establishes the invariant that dispatchTable
181157
// contains the front element from the kernels list for a given dispatch key.
182158
void updateDispatchTable_(TensorTypeId dispatch_key);
183159
void updateCatchallDispatchTable_();
184-
void updateCurrentUnboxedAutogradKernel_();
185160
};
186161

187162
}

aten/src/ATen/core/op_registration/op_registration.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ static_assert(std::is_nothrow_move_assignable<c10::optional<RegistrationHandleRA
1212
// table deregisters it in the destructor.
1313
class RegisterOperators::OperatorRegistrar final {
1414
public:
15-
explicit OperatorRegistrar(FunctionSchema&& schema, OperatorOptions&& operatorOptions, c10::optional<TensorTypeId> dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator, void* unboxed_kernel, void* unboxed_autograd_kernel)
15+
explicit OperatorRegistrar(FunctionSchema&& schema, OperatorOptions&& operatorOptions, c10::optional<TensorTypeId> dispatch_key, KernelFunction* kernel, KernelCacheCreatorFunction&& cache_creator, void* unboxed_kernel)
1616
: op_(Dispatcher::singleton().registerSchema(std::move(schema), std::move(operatorOptions))), kernel_registration_handle_(c10::nullopt) {
1717
// cache creator can only be set if the kernel is also set
1818
TORCH_INTERNAL_ASSERT((kernel != nullptr || unboxed_kernel != nullptr) || !static_cast<bool>(cache_creator));
@@ -24,10 +24,6 @@ class RegisterOperators::OperatorRegistrar final {
2424
kernel_registration_handle_ = Dispatcher::singleton().registerCatchallKernel(op_.opHandle(), kernel, std::move(cache_creator), unboxed_kernel);
2525
}
2626
}
27-
28-
if (unboxed_autograd_kernel != nullptr) {
29-
unboxed_autograd_kernel_registration_handle_ = Dispatcher::singleton().registerUnboxedAutogradKernel(op_.opHandle(), unboxed_autograd_kernel);
30-
}
3127
}
3228

3329
OperatorRegistrar(OperatorRegistrar&& rhs) noexcept = default;
@@ -40,7 +36,6 @@ class RegisterOperators::OperatorRegistrar final {
4036
private:
4137
c10::SchemaRegistrationHandleRAII op_;
4238
c10::optional<RegistrationHandleRAII> kernel_registration_handle_;
43-
c10::optional<RegistrationHandleRAII> unboxed_autograd_kernel_registration_handle_;
4439
};
4540

4641
void RegisterOperators::checkSchemaAndRegisterOp_(Options&& options) {
@@ -150,10 +145,10 @@ void RegisterOperators::registerOp_(Options&& options) {
150145
auto operatorOptions = makeOperatorOptions_(options);
151146

152147
if (0 == options.kernels.size()) {
153-
registerSchemaOnly_(std::move(schema), std::move(operatorOptions), options.unboxedAutogradKernel_);
148+
registerSchemaOnly_(std::move(schema), std::move(operatorOptions));
154149
} else {
155150
for (auto& kernel : options.kernels) {
156-
registerSchemaAndKernel_(schema, std::move(kernel), std::move(operatorOptions), options.unboxedAutogradKernel_);
151+
registerSchemaAndKernel_(schema, std::move(kernel), std::move(operatorOptions));
157152
}
158153
}
159154

@@ -168,14 +163,14 @@ OperatorOptions RegisterOperators::makeOperatorOptions_(const RegisterOperators:
168163
return result;
169164
}
170165

171-
void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel, OperatorOptions&& operatorOptions, void* unboxedAutogradKernel) {
166+
void RegisterOperators::registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& kernel, OperatorOptions&& operatorOptions) {
172167
TORCH_INTERNAL_ASSERT((kernel.kernel_func != nullptr || kernel.unboxed_kernel_func != nullptr), "Kernel must be set");
173168

174-
registrars_.emplace_back(std::move(schema), std::move(operatorOptions), kernel.dispatch_key, kernel.kernel_func, std::move(kernel.cache_creator_func), kernel.unboxed_kernel_func, unboxedAutogradKernel);
169+
registrars_.emplace_back(std::move(schema), std::move(operatorOptions), kernel.dispatch_key, kernel.kernel_func, std::move(kernel.cache_creator_func), kernel.unboxed_kernel_func);
175170
}
176171

177-
void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& operatorOptions, void* unboxedAutogradKernel) {
178-
registrars_.emplace_back(std::move(schema), std::move(operatorOptions), c10::nullopt, nullptr, nullptr, nullptr, unboxedAutogradKernel);
172+
void RegisterOperators::registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& operatorOptions) {
173+
registrars_.emplace_back(std::move(schema), std::move(operatorOptions), c10::nullopt, nullptr, nullptr, nullptr);
179174
}
180175

181176
RegisterOperators::RegisterOperators() = default;

aten/src/ATen/core/op_registration/op_registration.h

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ class CAFFE2_API RegisterOperators final {
9696
TORCH_CHECK(!legacyATenSchema_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. You can only specify the schema once per operator registration.");
9797

9898
if (Options::op_is_still_on_aten_dispatcher_(schemaOrName.c_str())) {
99-
TORCH_CHECK(unboxedAutogradKernel_ == nullptr, "For legacy aten ops, the schema() call must happen before any kernel() calls. Operator was ", schemaOrName);
10099
TORCH_CHECK(kernels.size() == 0, "For legacy aten ops, the schema() call must happen before any kernel() calls. Operator was ", schemaOrName);
101100
legacyATenSchema_ = schemaOrName;
102101
} else {
@@ -353,24 +352,6 @@ class CAFFE2_API RegisterOperators final {
353352
return std::move(*this);
354353
}
355354

356-
template<class FuncType>
357-
Options&& impl_unboxedAutogradKernel(FuncType* kernel) && {
358-
static_assert(guts::is_function_type<FuncType>::value, "Wrong argument type for impl_unboxedAutogradKernel");
359-
360-
// TODO Infer and check schema
361-
TORCH_CHECK(kernel != nullptr, "Kernel function pointer cannot be nullptr");
362-
TORCH_CHECK(unboxedAutogradKernel_ == nullptr, "You can only call impl_unboxedAutogradKernel() once per operator registration.");
363-
if (legacyATenSchema_.has_value()) {
364-
// TODO Remove this once all ops are moved to c10.
365-
TORCH_INTERNAL_ASSERT(!schemaOrName_.has_value());
366-
at::globalATenDispatch().registerOp<FuncType>(TensorTypeId::VariableTensorId, legacyATenSchema_->c_str(), kernel);
367-
return std::move(*this);
368-
} else {
369-
unboxedAutogradKernel_ = reinterpret_cast<void*>(kernel);
370-
return std::move(*this);
371-
}
372-
}
373-
374355
private:
375356
static c10::OperatorName parse_operator_name_(const char* schema) {
376357
// TODO Remove this function once all aten ops are on c10
@@ -474,7 +455,6 @@ class CAFFE2_API RegisterOperators final {
474455

475456
std::vector<KernelRegistrationConfig> kernels;
476457
optional<AliasAnalysisKind> aliasAnalysisKind_;
477-
void* unboxedAutogradKernel_; // can be nullptr, not all kernels have this
478458
friend class RegisterOperators;
479459
};
480460

@@ -599,8 +579,8 @@ class CAFFE2_API RegisterOperators final {
599579
static c10::FunctionSchema inferSchemaFromKernels_(const OperatorName& opNameStr, const Options& options);
600580
void checkNoDuplicateKernels_(const Options& options);
601581
void registerOp_(Options&& options);
602-
void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config, OperatorOptions&& options, void* unboxedAutogradKernel);
603-
void registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& options, void* unboxedAutogradKernel);
582+
void registerSchemaAndKernel_(FunctionSchema schema, Options::KernelRegistrationConfig&& config, OperatorOptions&& options);
583+
void registerSchemaOnly_(FunctionSchema&& schema, OperatorOptions&& options);
604584
static OperatorOptions makeOperatorOptions_(const Options& options);
605585

606586
class OperatorRegistrar;

aten/src/ATen/core/op_registration/op_registration_test.cpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ using c10::RegisterOperators;
1919
using c10::OperatorKernel;
2020
using c10::Dispatcher;
2121
using c10::IValue;
22+
using c10::TensorTypeId;
2223
using at::Tensor;
2324

2425
namespace {
@@ -619,37 +620,42 @@ TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_the
619620
}, "Tried to register kernels for same operator that infer a different function schema");
620621
}
621622

622-
int64_t increment_kernel(int64_t a) {
623-
return a + 1;
623+
bool called_autograd = false;
624+
bool called_catchall = false;
625+
626+
void catchall_kernel(Tensor a) {
627+
called_catchall = true;
624628
}
625629

626-
int64_t decrement_kernel(int64_t a) {
627-
return a - 1;
630+
void autograd_kernel(Tensor a) {
631+
called_autograd = true;
628632
}
629633

630634
TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) {
631-
auto registrar = c10::RegisterOperators().op("_test::dummy(int dummy) -> int", c10::RegisterOperators::options()
632-
.impl_unboxedAutogradKernel(&increment_kernel));
635+
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
636+
.impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
633637

634638
auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
635639
ASSERT_TRUE(op.has_value());
636-
int64_t result = c10::Dispatcher::singleton().callUnboxedAutogradKernel<int64_t, int64_t>(*op, 4);
637-
EXPECT_EQ(5, result);
640+
641+
called_autograd = false;
642+
c10::Dispatcher::singleton().lookup(*op, TensorTypeId::VariableTensorId).callUnboxed<void, Tensor>(dummyTensor(TensorTypeId::VariableTensorId));
643+
EXPECT_TRUE(called_autograd);
638644
}
639645

640646
TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) {
641-
auto registrar = c10::RegisterOperators().op("_test::dummy(int dummy) -> int", c10::RegisterOperators::options()
642-
.catchAllKernel<decltype(decrement_kernel), &decrement_kernel>()
643-
.impl_unboxedAutogradKernel(&increment_kernel));
647+
auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
648+
.impl_unboxedOnlyCatchAllKernel<decltype(catchall_kernel), &catchall_kernel>()
649+
.impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
644650

645651
auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
646652
ASSERT_TRUE(op.has_value());
647-
int64_t result = c10::Dispatcher::singleton().callUnboxedAutogradKernel<int64_t, int64_t>(*op, 4);
648-
EXPECT_EQ(5, result);
649-
}
650653

651-
// TODO Test cases that adding multiple autograd kernels, removing some, and so on works
652-
// (similar to test cases above for regular kernels "_whenNewerAndThenOlderKernelDeletedAndOpCalled")
654+
called_catchall = called_autograd = false;
655+
c10::Dispatcher::singleton().lookup(*op, TensorTypeId::VariableTensorId).callUnboxed<void, Tensor>(dummyTensor(TensorTypeId::VariableTensorId));
656+
EXPECT_FALSE(called_catchall);
657+
EXPECT_TRUE(called_autograd);
658+
}
653659

654660
/**
655661
* This is used to check that a given type works correctly when passed as input

0 commit comments

Comments (0)