Update __torch_dispatch__ to return op overload instead of the opoverload packet function #72673
Conversation
test/test_autograd.py

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
-       if func == torch.ops.aten.alias:
+       if func.overload_packet == torch.ops.aten.alias:
shouldn't this just be func == torch.ops.aten.alias[''] (or whatever the syntax is for the alias overload)? Ditto for the split example later.
Sure, we could also change it to `if func == torch.ops.aten.alias.default:` (we use `default` as the attribute to access the `''` overload name).
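For illustration, a minimal sketch of what that accessor resolves to (not the test code itself; `_schema` is an internal attribute, read here only to show the mapping):

```
import torch

# After this PR, __torch_dispatch__ receives an OpOverload for an aten::alias call.
func = torch.ops.aten.alias.default

# 'default' is the Python attribute for the overload whose overload name is ''.
print(func._schema.overload_name)            # ''
print(func == torch.ops.aten.alias.default)  # True: matches only that overload
```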
Since it took a while for us to make this BC-breaking change, it seems to me that we should introduce some goop to smooth over the transition. The simplest thing we can do is overload the meaning of == on overloads so that they test equal to their overload packet (perhaps throwing a deprecation warning saying that this behavior will get removed in a later version).
We also need to publish guidance on when you should match against the overload_packet (this is still such a weird name to me haha) as opposed to the overload directly. Hot take: matching against the overload packet is probably the wrong thing to do in most cases.
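A sketch of the `==` shim being floated here (hypothetical; per the follow-up below, it was decided to drop the idea and just pull the plug). The classes are illustrative stand-ins, not the real ones in torch/_ops.py:

```
import warnings

class OpOverloadPacket:
    """Illustrative stand-in for torch._ops.OpOverloadPacket."""
    pass

class OpOverload:
    """Illustrative stand-in for torch._ops.OpOverload."""
    def __init__(self, overloadpacket):
        self.overloadpacket = overloadpacket

    def __eq__(self, other):
        # Transitional behavior: an overload compares equal to its own packet,
        # but warns that this behavior will be removed in a later version.
        if isinstance(other, OpOverloadPacket):
            warnings.warn(
                "comparing an OpOverload against an OpOverloadPacket is "
                "deprecated; compare overloads (or packets) directly",
                DeprecationWarning,
            )
            return self.overloadpacket is other
        return self is other

    # Keep instances hashable despite the custom __eq__.
    __hash__ = object.__hash__

packet = OpOverloadPacket()
overload = OpOverload(packet)
print(overload == packet)  # True, with a DeprecationWarning
```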
Discussed offline that it's OK to just pull the plug.
zou3519 left a comment:
I like this. Two main things that we should resolve:
- Are the test failures real?
- cc @Chillee -- AOTAutograd can be modified to work with this, right?
    kwargs.ptr(),
    "detach",
-   py::module::import("torch").attr("ops").attr("aten").attr("detach").ptr(),
+   py::module::import("torch").attr("ops").attr("aten").attr("detach").attr("default").ptr(),
huh, I'm surprised this works (granted, I didn't read the opoverload packet PR in detail haha). Does this mean that default is now a reserved keyword for overload names? Do we have tests that ensure that no one tries to name an overload default?
When I look at tools/codegen, I don't see any logic to that effect:
ezyang-mbp:pytorch-tmp ezyang$ git grep \"default tools/codegen
tools/codegen/dest/register_dispatch_key.py: // so this "default" kernel doesn't actually handle backends that don't support at::empty
ezyang-mbp:pytorch-tmp ezyang$ git grep \'default tools/codegen
tools/codegen/gen.py: arg['default'] = cpp_a.default
tools/codegen/gen.py: arg['default'] = pythonify_default(cpp.default_expr(a.default, a.type))
tools/codegen/gen.py: 'default': str(f.has_composite_kernel or has_autogenerated_composite_kernel(f))
@ezyang I added checks in the schema parser to disallow `default` as an overload name, which also disables it when registering through the RegisterOperators API: https://github.com/pytorch/pytorch/pull/72206/files
OK, I recommend also synchronizing the parser in model.py to do the same check; otherwise you only get the error at runtime when you load torch. Doesn't have to be in this PR.
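For reference, the guard being described amounts to something like the following (a hypothetical sketch; the real change lives in the schema parser touched by #72206):

```
RESERVED_OVERLOAD_NAMES = {"default"}

def check_overload_name(overload_name: str) -> None:
    # 'default' is reserved because Python exposes the unnamed ('') overload
    # as torch.ops.<namespace>.<op>.default.
    if overload_name in RESERVED_OVERLOAD_NAMES:
        raise ValueError(
            f"'{overload_name}' is not a valid overload name: it is reserved "
            "for the unnamed overload"
        )

check_overload_name("Tensor")     # fine
# check_overload_name("default")  # would raise ValueError
```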
This broadly looks good to me, although I see there are a number of failing tests. Most of my questions have to do with the prior PR.
torch/_ops.py

    @property
    def overload_name(self):
        return self._schema.overload_name

    def name(self):
does anyone have a name suggestion for this?
I'd actually suggest making this `__str__`, and having your current `__str__` become `__repr__`.
But usually `repr` should print something that can be used to construct the object, right? That's not true of what the current `__str__` returns.
This property is currently called qualified_op_name for the OpOverloadPacket btw.
OK, I changed it to `__str__` as you suggested and added angle brackets to the string returned by `__repr__`, following the official Python guidance: https://docs.python.org/3/library/functions.html#repr
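A rough sketch of the convention that was settled on (illustrative stand-in; the exact strings produced by torch/_ops.py may differ):

```
class OpOverload:
    """Illustrative stand-in for torch._ops.OpOverload."""
    def __init__(self, qualified_name: str, overload_name: str):
        self._qualified_name = qualified_name            # e.g. "aten::add"
        self._overload_name = overload_name or "default"

    def __str__(self):
        # Human-readable operator name, e.g. "aten.add.Tensor"
        return "{}.{}".format(
            self._qualified_name.replace("::", "."), self._overload_name
        )

    def __repr__(self):
        # Angle brackets mark a representation that cannot be eval'd back,
        # per https://docs.python.org/3/library/functions.html#repr
        return "<OpOverload(op='{}', overload='{}')>".format(
            self._qualified_name.replace("::", "."), self._overload_name
        )

print(str(OpOverload("aten::add", "Tensor")))   # aten.add.Tensor
print(repr(OpOverload("aten::add", "Tensor")))  # <OpOverload(op='aten.add', overload='Tensor')>
```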
torch/_ops.py

    self._op = op
    self._schema = schema
    self._overloadpacket = overloadpacket
    self.__name__ = 'default' if schema.overload_name is '' else schema.overload_name
@albanD does this look ok?
albanD left a comment:
Not sure about the torch function handling, looks good otherwise.
Updated PR description ("Update __torch_dispatch__ to return op overload instead of the opoverload packet function"):

1. `__torch_dispatch__` now passes the OpOverload instead of the OpOverloadPacket to Python.
2. FX can trace the overloads (OpOverload objects), like the `torch.ops.aten.add.Tensor` op, for example:

```
import torch
import torch.fx as fx

def f(x):
    return torch.ops.aten.add.Tensor(x, x)

print(fx.symbolic_trace(f).code)
```

Differential Revision: [D34627164](https://our.internmc.facebook.com/intern/diff/D34627164)
@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Update __torch_dispatch__ to return op overload instead of the opoverload packet function (#72673)

Summary: Pull Request resolved: pytorch/pytorch#72673
Test Plan: Imported from OSS
Reviewed By: mruberry
Differential Revision: D34627164
Pulled By: anjali411
fbshipit-source-id: 3cb6406a392d530bf9da36b4d8e0a62b30e6497e
(cherry picked from commit 65b85a0a67df4d0f16ac8964e2b685d478a610fb)
Stack from ghstack:

BC-breaking change

Update `func == torch.ops.aten.foo` to:
- `func.overloadpacket == torch.ops.aten.foo` (if you want to perform the logical equivalent of the old check)
- `func == torch.ops.aten.foo.my_overload` (if you want to add overload-specific behavior)

Differential Revision: D34627164
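In terms of the new API, the migration looks roughly like this (a minimal sketch assuming the spellings from the description above):

```
import torch

# After this PR, the `func` argument to __torch_dispatch__ is an OpOverload.
func = torch.ops.aten.add.Tensor

# The old-style check no longer matches, since func is no longer a packet:
print(func == torch.ops.aten.add)                 # False

# Logical equivalent of the old check (matches any aten::add overload):
print(func.overloadpacket == torch.ops.aten.add)  # True

# Overload-specific check:
print(func == torch.ops.aten.add.Tensor)          # True
```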