-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[JIT] Improve class type annotation inference #45940
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
**Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. [ghstack-poisoned]
**Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. ghstack-source-id: d90df83 Pull Request resolved: #45940
torch/jit/annotations.py
Outdated
| return EnumType(_qualified_name(ann), get_enum_value_type(ann, loc), list(ann)) | ||
| if inspect.isclass(ann): | ||
| if hasattr(ann, "__torch_script_class__"): | ||
| if hasattr(ann, "__torch_script_class__") and len(ann.mro()) <= 2: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Won't this break torch.no_grad ? I think it would be more intuitive to switch to storing map of {python class pointer, jit class} instead of annotating the class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have a map from qualified_name -> cls (Python class, not JIT type) in _state.py for all classes that have truly been scripted; what if we changed this to
| if hasattr(ann, "__torch_script_class__") and len(ann.mro()) <= 2: | |
| if torch.jit._state._get_script_class(_qualified_name(ann)): |
In the example that motivated this fix, A would be in _state._script_classes, but B would not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea sgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although I realize now that inheritance is prohibited if the class is scripted directly, but allowed if it is scripted implicitly (e.g. through try_ann_to_type) and doesn't use any features of inheritance (like super()). This seems inconsistent and awkward, but as you pointed out, no_grad relies on this. 😞
**Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. This commit fixes this by looking up the qualified name of a class in torch.jit._state._script_class in order to ascertain whether it has already been scripted (instead of looking for a `__torch_script_class__` attribute on the class object. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. [ghstack-poisoned]
**Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. This commit fixes this by looking up the qualified name of a class in torch.jit._state._script_class in order to ascertain whether it has already been scripted (instead of looking for a `__torch_script_class__` attribute on the class object. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. ghstack-source-id: 90cfb48 Pull Request resolved: #45940
Codecov Report
@@ Coverage Diff @@
## gh/splitinfinity/58/base #45940 +/- ##
============================================================
+ Coverage 68.25% 68.27% +0.02%
============================================================
Files 410 410
Lines 53611 53248 -363
============================================================
- Hits 36593 36357 -236
+ Misses 17018 16891 -127
Continue to review full report at Codecov.
|
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, two other invocations we need to update i think
**Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. This commit fixes this by looking up the qualified name of a class in torch.jit._state._script_class in order to ascertain whether it has already been scripted (instead of looking for a `__torch_script_class__` attribute on the class object. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. [ghstack-poisoned]
**Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. This commit fixes this by looking up the qualified name of a class in torch.jit._state._script_class in order to ascertain whether it has already been scripted (instead of looking for a `__torch_script_class__` attribute on the class object. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. ghstack-source-id: 6fe19a4 Pull Request resolved: #45940
💊 CI failures summary and remediationsAs of commit f401180 (more details on the Dr. CI page):
1 failure not recognized by patterns:
This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 1 times. |
|
The test failure is a mypy type check failure in FX. |
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
@SplitInfinity merged this pull request in 75bf5f2. |
1 similar comment
|
@SplitInfinity merged this pull request in 75bf5f2. |
**Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. This commit fixes this by looking up the qualified name of a class in torch.jit._state._script_class in order to ascertain whether it has already been scripted (instead of looking for a `__torch_script_class__` attribute on the class object. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. ghstack-source-id: 6fe19a4 Pull Request resolved: #45940
**Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. This commit fixes this by looking up the qualified name of a class in torch.jit._state._script_class in order to ascertain whether it has already been scripted (instead of looking for a `__torch_script_class__` attribute on the class object. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. ghstack-source-id: 6fe19a4 Pull Request resolved: #45940
Stack from ghstack:
Summary
In
try_ann_to_type, if an annotation has an attribute named__torch_script_class__, it is assumed to be a TorchScript class thathas already been scripted. However, if it is a class that extends
another class, this code path causes a crash because it looks up the
JIT type for the class by name in the compilation unit. This JIT type
obviously cannot exist because inheritance is not supported.
This commit fixes this by looking up the qualified name of a class
in torch.jit._state._script_class in order to ascertain whether it has
already been scripted (instead of looking for a
__torch_script_class__attribute on the class object.
Test Plan
This commit adds a unit test consisting of the code sample from the
issue that reported this problem.
Fixes
This commit fixes #45860.
Differential Revision: D24310027