-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Add prim::EnumName and prim::EnumValue ops #41965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
💊 CI failures summary and remediationsAs of commit 804329f (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 20 times. |
SplitInfinity
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very cool!
test/jit/test_enum.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| # TODO(gmagogsfm): Re-enanble hooks when serialization/deserialization | |
| # TODO(gmagogsfm): Re-enable hooks when serialization/deserialization |
torch/csrc/jit/ir/ir.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding an assert here to check that e's type is an enum and that enum_val_type is an enum type. Actually, I think you can pass in only e and derive enum_val_type from it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you choose to pass in enum_val_type explicitly instead of deriving it from e? Am I looking at an old version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly to follow convention and to save one 'cast'. But I guess the second reason is too minute and voided by having an assert there. Done.
test/jit/test_enum.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def enum_comp(x: Color) -> int: | |
| def enum_value(x: Color) -> int: |
test/jit/test_enum.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| scripted_enum_comp = torch.jit.script(enum_comp) | |
| self.assertEqual(scripted_enum_comp(Color.RED), Color.RED.value) | |
| self.assertEqual(scripted_enum_comp(Color.GREEN), Color.GREEN.value) | |
| scripted_enum_value = torch.jit.script(enum_value) | |
| self.assertEqual(scripted_enum_value(Color.RED), Color.RED.value) | |
| self.assertEqual(scripted_enum_value(Color.GREEN), Color.GREEN.value) |
test/jit/test_enum.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| scripted_enum_comp = torch.jit.script(enum_comp) | |
| self.assertEqual(scripted_enum_comp(Color.RED), Color.RED.name) | |
| self.assertEqual(scripted_enum_comp(Color.GREEN), Color.GREEN.name) | |
| scripted_enum_name = torch.jit.script(enum_name) | |
| self.assertEqual(scripted_enum_name(Color.RED), Color.RED.name) | |
| self.assertEqual(scripted_enum_name(Color.GREEN), Color.GREEN.name) |
test/jit/test_enum.py
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def enum_comp(x: Color) -> str: | |
| def enum_name(x: Color) -> str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The types of the IR values corresponding to enum and the return value are more narrow than AnyEnumType and Any. I think there is nothing we can do about AnyEnumType because those types are user-defined, but is there a restriction on underlying storage type? Can we use that to define more specific operators? E.g.
prim::EnumValue(AnyEnumType enum) -> int
prim::EnumValue(AnyEnumType enum) -> float
prim::EnumValue(AnyEnumType enum) -> bool (lol)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great point! Done.
aten/src/ATen/core/jit_type.h
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We're not including bool as a type here, but we are registering a kernel for it
torch/csrc/jit/ir/ir.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FYI, we have two types of assert's we use: TORCH_CHECK, and TORCH_INTERNAL_ASSERT. The first is for user errors and the second is for developer invariants. In this case, it would be a developer invariant. But you don't need to use either here because we have e->type()->expect<EnumType>(); which will throw if it's not the right type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, replaced with expect()
torch/csrc/jit/ir/ir.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove assert and use -type()->expect
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, thanks. Didn't know that. Done.
torch/csrc/jit/ir/ir.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't actually need this custom create logic, bc there is only one possible schema for prim::EnumName, so you can use the normal schema driven insert logic:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/ir.h#L1181
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, but it is actually nice to have a method to create it, more discoverable and we can enforce type checking here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is at least worth a discussion, because the duplication of the kernel has some binary implications. However, I think not-standard ops (unschematized ops) have some other complications around mobile and stuff so I think this is probably worth it. @ljk53 what do you think ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the "binary implication" you are referring to with dup kernels?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Each operator increases the size of the mobile build (if they're included in a mobile model). In this case, we could unify all of the prim::EnumValue operators but at the cost of having to make it a non-standard op where the schema doesnt express its true types. I think our current decision is good.
|
LGTM !! I just have one comment for the mobile folks and a couple nits. Might be good to hear in from @ljk53 before continuing, but it's a pretty minor decision either way so i think we're also fine to continue |
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!! This a kind of minor note, but right now when we call unify_types we unify to Optional[int] etc, but we didnt register a kernel for those options. I think it's fine if we don't handle optional, we can just throw an error in that case
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@gmagogsfm has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Thanks for the review. Do you mean this? I think it will become an error when constructing EnumType in IR, so we should be safe the error should be clear enough. Or do you mean something else? |
|
@gmagogsfm merged this pull request in 8e03c38. |
[2/N] Implement Enum JIT support
Add prim::EnumName and prim::EnumValue and their lowerings to support getting
nameandvalueattribute of Python enums.Supported:
Enum-typed function targuments
using Enum type and comparing them
Support getting name/value attrs of enums
TODO:
Add PyThon sugared value for Enum
Support Enum-typed return values
Support enum values of different types in same Enum class
Support serialization and deserialization