-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[JIT] Add support for default args in class types #42988
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: e3fc9c6 Pull Request resolved: #42988
💊 CI failures summary and remediationsAs of commit e522fb2 (more details on the Dr. CI page): 💚 💚 Looks good so far! There are no failures yet. 💚 💚 This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group. This comment has been revised 154 times. |
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: 8c3e916 Pull Request resolved: #42988
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: 9af3faf Pull Request resolved: #42988
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: c7fea05 Pull Request resolved: #42988
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: 085c698 Pull Request resolved: #42988
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
test/jit/test_class_type.py
Outdated
| return a.get_int() + a.get_list()[2] + a.get_tup()[1] | ||
|
|
||
| def some_defaults() -> int: | ||
| a: ClassWithDefaultArgs = ClassWithDefaultArgs(b=[5, 6, 7]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
a is both an attr name and a var name here, a bit confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noted, will change the object variable name to obj
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: 127dd00 Pull Request resolved: #42988
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
didnt do full review, just comment here
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think about the PR here ? https://github.com/pytorch/pytorch/pull/31344/files
I think it's more or less what I have here. However, one thing I don't see there is checks for mutable function defaults, which, judging by the failures on this PR, happen quite a bit. |
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: 87817f1 Pull Request resolved: #42988
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: 111eb48 Pull Request resolved: #42988
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. [ghstack-poisoned]
**Summary** This commit adds support for default args in methods of TorchScript classes. Default values are already represented in the JIT TreeView representation of function (in the `defaultValue` member of the `Param` subclass of `TreeView`); this commit adds code to parse out the default values and pack them into the `TreeView` created for the class definition. This behaviour is gated by a new argument `parse_defaults` that has been added to several functions in `frontend.py` to make sure that this new behaviour is triggered only when TorchScript classes are being compiled. The same code is used to compile functions and interfaces, and default arguments are not expected to be parsed during compilation of these entities. **Test Plan** This commit adds a unit test to `TestClassType` to test this feature. `python test/test_jit.py TestClassType.test_default_args` **Fixes** This commit fixes #42562. ghstack-source-id: b0398ec Pull Request resolved: #42988
|
Okay, I think it's ready. Major changes since last review:
|
Codecov Report
@@ Coverage Diff @@
## gh/splitinfinity/32/base #42988 +/- ##
=========================================================
Coverage 67.95% 67.96%
=========================================================
Files 384 384
Lines 49768 49785 +17
=========================================================
+ Hits 33820 33836 +16
- Misses 15948 15949 +1
Continue to review full report at Codecov.
|
|
Ping @eellison and @gmagogsfm There are a lot of internal failures but I sampled a handful and they seem to be ones where the argument is annotated as |
|
@SplitInfinity sure, i'll take a look today. It's a big PR... If it's not too much difficulty, it's probably worth adding special handling so that we can take an int default for a float argument. |
|
|
|
What's an example of the custom type one ? Yea probably just int as float is fine. |
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't do an in-depth review yet, wanted to sync up on some aspects of the approach.
So, as far as I understand, the reason we're switching a lot of the mechanics of how we get and generate default values is because we no longer have access to an object instance of a class, as we did with nn.Modules. (please point out if this is not true).
I want to point out that this doesn't actually prohibit us from getting the default values:
closed_over_var = 1
class A(object):
def forward(self, b=closed_over_var, c=4):
return a
print(A.forward.__defaults__)
# (1, 4)
This also has the nice property that we don't have to regress the scriptability of cases above with closed_over_var. The default value is bound at initialization time so there is no (semantic) reason why we couldn't support this case.
I think using the python value would help with the complexity/brittleness of this stuff
.value("ListLiteralKind", TK_LIST_LITERAL)
.value("DictLiteralKind", TK_DICT_LITERAL)
.value("TupleLiteralKind", TK_TUPLE_LITERAL)
.value("VarKind", TK_VAR)
Which atm wouldn't handle stuff like a default value'd enum, or other future things that get added.
One other question I had was - we do all of the stitching/evaluation of python defaults in the script_init.cpp file. Could we add a check for mutable values there ? Since this is only the python compilation path it shouldn't get affected by the broadcasting list defs in native_functions.yaml
I understand this is a difficult PR to land that is a really nice improvement in functionality. We can consider what aspects we consider necessary for fixing and which would just be nice to have, and then decide what to land with after.
| ) = instantiator.get_arg_return_types_from_interface(MyModuleInterface) | ||
| self.assertEqual(args_str, "tensor, number, word") | ||
| self.assertEqual(arg_types_str, "tensor: Tensor, number: int, word: str") | ||
| self.assertEqual(arg_types_str, "tensor: Tensor, number: int, word: str = \"default\"") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why was this change needed ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is some test that checks that a schema is parsed in a specific way and the test was written in a way that didn't include defaults (because they weren't being parsed).
|
|
||
| @torch.jit.script | ||
| def bool_fn(x, a=outer_c, flag=outer_flag): | ||
| def bool_fn(x, a: Tensor = torch.tensor(9), flag: Tensor = torch.tensor(False)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this change needed ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was needed due to issue with closing over variables used as default argument values.
| """ | ||
| @torch.jit.script | ||
| def foo(x=torch.ones(1)): | ||
| def foo(x: Tensor = torch.ones(1)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this change needed ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now that the defaults are being parsed from the AST, there's no way to infer types for the default argument expressions without making a small graph containing them and running it.
Yeah, I think we could pack all those defaults into a some map and pass it into
I think we could, but we would do the check after compilation. By doing the check in |
I don't think this is a compelling reason for the extra complexity and maintenance. How often do people use default values that we're optimizing for their failure case? It would be just as easy to iterate over the default values before compilation in the other approach.
Yea, it's true that this PR does add that functionality, which is really nice. It does regress other aspects of scripting though (closing over named constants and int as floats is the big one). I would prefer if we tried to improve that aspect of scripting in a separate PR. |
|
I just posted #45098 that addresses the same GitHub issue as this PR but with the same approach that we use now for methods of modules and functions (getting the Python objects for all of the default args, converting them to IValues and attaching them to the schemas of the class methods). Maybe it's better to land that instead of this one? |
|
F |
Stack from ghstack:
Summary
This commit adds support for default args in methods of TorchScript
classes. Default values are already represented in the JIT TreeView
representation of function (in the
defaultValuemember of theParamsubclass of
TreeView); this commit adds code to parse out the defaultvalues and pack them into the
TreeViewcreated for the class definition.This behaviour is gated by a new argument
parse_defaultsthat has beenadded to several functions in
frontend.pyto make sure that this newbehaviour is triggered only when TorchScript classes are being compiled.
The same code is used to compile functions and interfaces, and default
arguments are not expected to be parsed during compilation of these
entities.
Test Plan
This commit adds a unit test to
TestClassTypeto test this feature.python test/test_jit.py TestClassType.test_default_argsFixes
This commit fixes #42562.
Differential Revision: D23762013