-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[BC-BREAKING] Replace Type dispatch with ATenDispatch #21320
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3ec6148
799330a
897abf3
5c15cdb
75b66cd
3e0aaad
7472cdc
11ccf9a
2babd58
cef4290
33fb37e
579dd3f
6840d9b
f495d97
98fb2ee
f1da595
d098e00
afcf125
d7f144d
bd29609
629371b
97ed318
c168c0f
a1e8382
cf1523d
68f6d9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| #include <ATen/core/ATenDispatch.h> | ||
|
|
||
| namespace at { | ||
|
|
||
| ATenDispatch & globalATenDispatch() { | ||
| static ATenDispatch singleton; | ||
| return singleton; | ||
| } | ||
|
|
||
| } // namespace at | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,106 @@ | ||
| #pragma once | ||
li-roy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| #include <c10/core/Backend.h> | ||
| #include <unordered_map> | ||
| #include <c10/util/C++17.h> | ||
| #include <memory> | ||
| #include <mutex> | ||
|
|
||
| // This dispatch class serves as a replacement for our previous dispatch | ||
| // mechanism, in which all functions were members of a Type class. A derived | ||
| // class existed for each backend (and Variable), and the vtable was used to | ||
| // dispatch to the correct implementation. This class is to be replaced by | ||
| // the c10 dispatcher when it supports all argument and return types. | ||
| // This implementation opts to store implementations in a table of void*. | ||
|
|
||
| namespace at { | ||
|
|
||
| // ATenOpTable stores the implementations for each backend, in addition to | ||
| // an implementation for variables. | ||
| class CAFFE2_API ATenOpTable { | ||
| public: | ||
| ATenOpTable(std::string schema) | ||
| : schema_(std::move(schema)) {} | ||
|
|
||
| template<class FuncType> | ||
| FuncType* getOp(Backend backend, bool is_variable) const { | ||
| if (is_variable) { | ||
| return reinterpret_cast<FuncType*>(getVariableOp()); | ||
| } | ||
| return reinterpret_cast<FuncType*>(getBaseOp(backend)); | ||
| } | ||
| private: | ||
| void registerOp(Backend backend, void* fn) { | ||
| TORCH_CHECK(function_table_[static_cast<int64_t>(backend)] == nullptr, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Internal asserts here? (I guess if you're going to expose directly to extensions an internal assert is not appropriate, but for internal use, this isn't a public API right)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no type checking in the registration API right now; as long as we don't have any type checking I have a pretty strong preference of NOT allowing external parties to poke this API. I don't know what the situation is with XLA and this diff at the moment. (The reason for this preference is that if XLA uses this directly, and we start changing the types of functions, they'll start segfaulting. Ick!) |
||
| "Attempting to register variable function for schema ", schema_, | ||
| " and backend ", toString(backend), | ||
| " but there is already a function registered"); | ||
| function_table_[static_cast<int64_t>(backend)] = fn; | ||
| } | ||
|
|
||
| void registerVariableOp(void* fn) { | ||
| TORCH_CHECK(variable_function_ == nullptr, | ||
| "Attempting to register variable function for schema ", schema_, | ||
| " but there is already a function registered"); | ||
| variable_function_ = fn; | ||
| } | ||
|
|
||
| void* getBaseOp(Backend backend) const { | ||
| if (function_table_[static_cast<int64_t>(backend)] == nullptr) { | ||
| TORCH_CHECK(function_table_[static_cast<int64_t>(Backend::Undefined)] != nullptr, | ||
| "No function is registered for schema ", schema_, " on backend ", toString(backend)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Re error message: not only is no function not registered, but there is no undefined implementation registered either. Actually, why are we dispatching to |
||
| return function_table_[static_cast<int64_t>(Backend::Undefined)]; | ||
| } | ||
| return function_table_[static_cast<int64_t>(backend)]; | ||
| } | ||
|
|
||
| void* getVariableOp() const { | ||
| TORCH_CHECK(variable_function_ != nullptr, | ||
| "No variable function registered for ", schema_); | ||
| return variable_function_; | ||
| } | ||
|
|
||
| friend class ATenDispatch; | ||
|
|
||
| std::string schema_; | ||
| void* function_table_[static_cast<int64_t>(Backend::NumOptions)] = {nullptr}; | ||
| void* variable_function_ = nullptr; | ||
| }; | ||
|
|
||
| class CAFFE2_API ATenDispatch { | ||
| public: | ||
| template<class FuncType> | ||
| ATenDispatch& registerOp(Backend backend, const char* schema, FuncType* fn) { | ||
| std::lock_guard<std::mutex> lock(mutex_); | ||
| if (op_tables_.find(schema) == op_tables_.end()) { | ||
| op_tables_.insert(std::make_pair(schema, ATenOpTable(schema))); | ||
| } | ||
| op_tables_.at(schema).registerOp(backend, reinterpret_cast<void*>(fn)); | ||
| return *this; | ||
| } | ||
|
|
||
| template <class FuncType> | ||
| ATenDispatch& registerVariableOp(const char* schema, FuncType* fn) { | ||
| std::lock_guard<std::mutex> lock(mutex_); | ||
| if (op_tables_.find(schema) == op_tables_.end()) { | ||
| op_tables_.insert(std::make_pair(schema, ATenOpTable(schema))); | ||
| } | ||
| op_tables_.at(schema).registerVariableOp(reinterpret_cast<void*>(fn)); | ||
| return *this; | ||
| } | ||
|
|
||
| const ATenOpTable* getOpTable(const char* schema) const { | ||
| auto iter = op_tables_.find(schema); | ||
| TORCH_CHECK(iter != op_tables_.end(), | ||
| "No functions are registered for schema ", schema); | ||
| return &iter->second; | ||
| } | ||
|
|
||
| private: | ||
| std::unordered_map<std::string, ATenOpTable> op_tables_; | ||
| std::mutex mutex_; | ||
| }; | ||
|
|
||
| CAFFE2_API ATenDispatch& globalATenDispatch(); | ||
|
|
||
| } // namespace at | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,8 +59,7 @@ class CAFFE2_API LegacyTypeDispatch { | |
| } | ||
| Type * getNonVariableTypeOpt(Backend p, ScalarType s) { | ||
| if (p != Backend::Undefined) { | ||
| initForDeviceType(backendToDeviceType(p)); | ||
| initForScalarType(s); | ||
| initForBackend(p); | ||
| } | ||
| auto type = getNonVariableTypeRaw(p, s); | ||
|
|
||
|
|
@@ -103,10 +102,11 @@ class CAFFE2_API LegacyTypeDispatch { | |
| type_registry[static_cast<int>(b)] = std::move(t); | ||
| detail::getVariableHooks().registerVariableTypeFor(this, b); | ||
| } | ||
| private: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why'd this become public?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For now, I'm calling this directly from Functions.h for all factory functions. |
||
| void initForDeviceType(DeviceType p) { | ||
| void initForBackend(Backend b) { | ||
| auto p = backendToDeviceType(b); | ||
| static std::once_flag cpu_once; | ||
| static std::once_flag cuda_once; | ||
| static std::once_flag complex_once; | ||
| if (p == DeviceType::CPU) { | ||
| std::call_once(cpu_once, [] { | ||
| getLegacyDeviceTypeInit().initCPU(); | ||
|
|
@@ -120,17 +120,13 @@ class CAFFE2_API LegacyTypeDispatch { | |
| getLegacyDeviceTypeInit().initHIP(); | ||
| }); | ||
| } | ||
| } | ||
| void initForScalarType(ScalarType s) { | ||
| static std::once_flag once; | ||
| // Only complex may need initialization | ||
| if (isComplexType(s)) { | ||
| std::call_once(once, [] { | ||
| if (b == Backend::ComplexCPU || b == Backend::ComplexCUDA) { | ||
| std::call_once(complex_once, [] { | ||
| getLegacyDeviceTypeInit().initComplex(); | ||
| }); | ||
| } | ||
| } | ||
|
|
||
| private: | ||
| // NB: type_registry has nullptr for all CUDA backends until | ||
| // CUDA initialization has occurred | ||
| TypeUniquePtr type_registry | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registering to the global table is not thread safe, right? Gotta be careful: library loads can happen in different threads.