[BC-BREAKING] Replace Type dispatch with ATenDispatch #21320
Conversation
Replace Type dispatch with ATenDispatch gh-metadata: pytorch pytorch 21320 gh/li-roy/25/head
Do you want initial review now or wait until it's not WIP?

@ezyang Sorry, this is ready for review; I forgot to change the title.

cc @ailzhang What's the plan for coordinating this diff with XLA?

UBSAN seems to have caught an error:
    type_registry[static_cast<int>(b)] = std::move(t);
    detail::getVariableHooks().registerVariableTypeFor(this, b);
  }
 private:
Why'd this become public?
For now, I'm calling this directly from Functions.h for all factory functions.
  }
}

// example
What's blocking removing Type.h entirely?
A few functions like unsafeTensorFromTH are still on Type and haven't found a new home yet.
}

- Tensor & VariableType::detach_(Tensor & self) const {
+ Tensor & VariableType::detach_(Tensor & (*_op)(Tensor &), Tensor & self) {
This is pretty weird, especially since you never actually call _op here. Are you sure this is the pattern you want for "Variable-only" operations?
I guess this is fine.
template<class FuncType>
ATenDispatch& registerOp(Backend backend, const char* schema, FuncType* fn) {
  auto id = getSchemaId(schema);
  function_table[static_cast<int64_t>(backend)][id] = reinterpret_cast<void*>(fn);
You don't need a reinterpret_cast to cast a pointer into a void pointer. (It's only needed the other way.)
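A small sketch of the distinction being made here (names are illustrative, not from the PR). Note that the reviewer's point holds for object pointers, which convert to void* implicitly; function pointers are the exception and do require reinterpret_cast in both directions, which may be why the cast appears in this registration path.

```cpp
#include <cassert>

static int add_one(int x) { return x + 1; }

// Type-erasing an object pointer: implicit conversion to void* suffices;
// only the conversion back needs a cast (static_cast is enough).
inline int read_through_void(int* p) {
  void* erased = p;                     // no cast needed
  return *static_cast<int*>(erased);    // cast needed coming back
}

// Type-erasing a function pointer: reinterpret_cast is required both ways
// (the conversion is only conditionally supported by the standard, but
// works on all mainstream platforms).
inline int call_through_void(int x) {
  void* erased = reinterpret_cast<void*>(&add_one);
  auto fn = reinterpret_cast<int (*)(int)>(erased);
  return fn(x);
}
```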
So... no error checking, I guess?
if (is_variable) {
  if (wrapper_table[id] == nullptr) {
    AT_ERROR("No autograd wrapper is registered for ", name, ". Please report a bug to PyTorch.");
If this is really a bug, you should use TORCH_INTERNAL_ASSERT.
int64_t getSchemaId(std::string schema) {
  static std::unordered_map<std::string, int64_t> schema_to_id = {
    ${schema_to_id_pairs}
  };
@zdevito Do you know if a static variable like this takes a long time to compile? Also, even though this isn't perf-sensitive code, we should really avoid using std::unordered_map.
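One common alternative to a function-local static std::unordered_map is a sorted constant array searched with std::lower_bound: no hashing, no heap allocation, and no dynamic-initialization cost beyond the array itself. A minimal sketch, assuming a hypothetical hand-written schema list (in the real code, ${schema_to_id_pairs} is filled in by codegen):

```cpp
#include <algorithm>
#include <cassert>
#include <cstring>

// Hypothetical schema table, kept sorted by name so binary search works.
struct SchemaEntry { const char* name; long id; };

static constexpr SchemaEntry kSchemas[] = {
  {"aten::add", 0},
  {"aten::detach_", 1},
  {"aten::mul", 2},
};

inline long getSchemaId(const char* schema) {
  // O(log n) lookup over a sorted static array instead of a hash map.
  auto it = std::lower_bound(
      std::begin(kSchemas), std::end(kSchemas), schema,
      [](const SchemaEntry& e, const char* s) {
        return std::strcmp(e.name, s) < 0;
      });
  assert(it != std::end(kSchemas) && std::strcmp(it->name, schema) == 0);
  return it->id;
}
```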
aten/src/ATen/core/ATenDispatch.h
class CAFFE2_API ATenOpTable {
 public:
  ATenOpTable(std::string schema)
    : schema_(schema) {}
nit: std::move(schema)
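The nit: since the constructor takes std::string by value, moving it into the member avoids a second copy. A minimal illustration (OpTable is a stand-in for ATenOpTable):

```cpp
#include <cassert>
#include <string>
#include <utility>

class OpTable {
 public:
  // Pass-by-value plus std::move: the caller's argument is copied (or
  // moved, for a temporary) exactly once into the parameter, then moved
  // into the member. Without std::move, schema_ would be a second copy.
  explicit OpTable(std::string schema) : schema_(std::move(schema)) {}

  const std::string& schema() const { return schema_; }

 private:
  std::string schema_;
};
```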
  }
 private:
  void registerOp(Backend backend, void* fn) {
    TORCH_CHECK(function_table_[static_cast<int64_t>(backend)] == nullptr,
Internal asserts here? (I guess if you're going to expose directly to extensions an internal assert is not appropriate, but for internal use, this isn't a public API right)
There's no type checking in the registration API right now; as long as we don't have any type checking, I have a pretty strong preference for NOT allowing external parties to poke this API. I don't know what the situation is with XLA and this diff at the moment. (The reason for this preference is that if XLA uses this directly, and we start changing the types of functions, they'll start segfaulting. Ick!)
void* getBaseOp(Backend backend) const {
  if (function_table_[static_cast<int64_t>(backend)] == nullptr) {
    TORCH_CHECK(function_table_[static_cast<int64_t>(Backend::Undefined)] != nullptr,
      "No function is registered for schema ", schema_, " on backend ", toString(backend));
Re the error message: it's not just that no function is registered for this backend; there's no Undefined implementation registered either.
Actually, why are we dispatching to Undefined at all, at this point? Can't we just error out straight up at this point?
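For reference, the fallback semantics being questioned can be sketched like this (a hedged mock-up, not the PR's actual code; names like OpSlots are invented): the concrete backend slot is tried first, and the Undefined slot acts as a catch-all.

```cpp
#include <cassert>

// Illustrative backend enum; the real one has many more entries.
enum class Backend { CPU = 0, CUDA = 1, Undefined = 2, NumBackends = 3 };

struct OpSlots {
  void* fns[static_cast<int>(Backend::NumBackends)] = {nullptr, nullptr, nullptr};

  void* lookup(Backend b) const {
    void* fn = fns[static_cast<int>(b)];
    if (fn != nullptr) return fn;
    // Fallback: an op registered only for Undefined serves every backend.
    // The comment above asks whether this should instead be a hard error.
    return fns[static_cast<int>(Backend::Undefined)];  // may still be null
  }
};
```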
namespace at {

ATenDispatch & globalATenDispatch() {
Registering to the global table is not thread safe, right? Gotta be careful: library loads can happen in different threads.
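One way to address this (a sketch under assumed names; Registry and registerOp here are illustrative, not PyTorch's API): guard the mutable table with a mutex, and rely on the fact that initialization of a function-local static is itself thread-safe since C++11.

```cpp
#include <cassert>
#include <mutex>
#include <string>
#include <unordered_map>

class Registry {
 public:
  void registerOp(const std::string& name, void* fn) {
    // Serialize writers: library loads on different threads can both
    // reach here, so mutation of the table must be guarded.
    std::lock_guard<std::mutex> guard(mutex_);
    table_[name] = fn;
  }

  void* find(const std::string& name) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = table_.find(name);
    return it == table_.end() ? nullptr : it->second;
  }

 private:
  std::mutex mutex_;
  std::unordered_map<std::string, void*> table_;
};

inline Registry& globalRegistry() {
  // Function-local static: construction is thread-safe in C++11 and later.
  static Registry r;
  return r;
}
```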
ezyang left a comment:
I'm not going to block landing this on account of lack of type checking, but if XLA is going to use this interface I would be veeeeery careful. But I don't know what the XLA plan is at the moment.
The current path to XLA is that there is a generated include file which provides a matching between the C++ signature and the magic string PyTorch generates for table-based registrations.
Summary: Pull Request resolved: pytorch/pytorch#21320
ghimport-source-id: cc18f746a1c74df858cb0f6d8b7d4de4315683c7
Test Plan: Imported from OSS
Differential Revision: D15637222
Pulled By: li-roy
fbshipit-source-id: fcfaea0b5480ab966175341cce92e3aa0be7e3cb
Stack from ghstack:
Replace Type-object-based dispatch with a table of void*. This is BC-breaking because it will break C++ extensions that extend Type.
This gives us a registration system with an API similar to c10 dispatch. Backend extensions can be folded in to use this API directly rather than having their own dispatch systems. This also saves us a dispatch: Type dispatch was doing a dynamic dispatch plus a vtable dispatch, while this table does only a single dynamic dispatch.
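The two schemes can be sketched side by side (a minimal illustration with invented names, not the PR's actual code): the old path resolves a Type object and then makes a virtual call through its vtable, while the new path is one indexed load from a table of function pointers.

```cpp
#include <cassert>

// Old style: dispatch through a Type object's vtable. Reaching the
// implementation costs a dynamic Type lookup plus a virtual call.
struct Type {
  virtual ~Type() = default;
  virtual int add(int a, int b) const = 0;
};

struct CPUType final : Type {
  int add(int a, int b) const override { return a + b; }
};

// New style: a flat table of function pointers keyed by backend index.
// Dispatch is a single indexed load and an indirect call; no vtable.
using AddFn = int (*)(int, int);

static int cpu_add(int a, int b) { return a + b; }

static AddFn add_table[1] = {&cpu_add};

inline int dispatch_add(int backend, int a, int b) {
  return add_table[backend](a, b);
}
```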
As a follow up to this change, we can clean up a lot of code involving Type.
Outline of this change:
Benchmarks are a bit noisy because we're timing such a small op, but I ran benchmarks on a null op, with and without variable unwrapping.
Before, without variable unwrapping: ~365ns
After, without variable unwrapping: ~330ns
Before, with variable unwrapping: ~450ns
After, with variable unwrapping: ~420ns
Differential Revision: D15637222