-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[jit] reuse stubs_fn whenever possible to create submodule #43872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This PR allows the recursive scripting to have a separate submodule_stubs_fn to create its submodule with specific user provided rules. Fixes #43729 [ghstack-poisoned]
Codecov Report
@@ Coverage Diff @@
## gh/wanchaol/125/base #43872 +/- ##
=======================================================
Coverage ? 67.85%
=======================================================
Files ? 384
Lines ? 49917
Branches ? 0
=======================================================
Hits ? 33871
Misses ? 16046
Partials ? 0 Continue to review full report at Codecov.
|
This PR allows the recursive scripting to have a separate submodule_stubs_fn to create its submodule with specific user provided rules. Fixes #43729 Differential Revision: [D23430176](https://our.internmc.facebook.com/intern/diff/D23430176) [ghstack-poisoned]
💊 CI failures summary and remediationsAs of commit 652f786 (more details on the Dr. CI page): ✅ None of the CI failures appear to be your fault 💚
🚧 1 fixed upstream failure:These were probably caused by upstream breakages that were already fixed.
Please rebase on the
|
torch/jit/_recursive.py
Outdated
| concrete_type._create_methods(defs, rcbs, defaults) | ||
|
|
||
| def create_script_module(nn_module, stubs_fn, share_types=True): | ||
| def create_script_module(nn_module, stubs_fn, submodule_stubs_fn=None, share_types=True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add this extra argument? Is there ever going to be a case where we don't use all the functions that are marked with _jit_internal.FunctionModifiers.EXPORT ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Before this PR, the rule is: submodule will be compiled using the default inference rule, which only contains the forward method + export methods, so export methods will always be compiled before/after this PR. This argument is added because we don't want tracing to compile the forward method, we want the tracing to compile the export methods, but still trace the forward calls
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this what stubs_fn is already for ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
stubs_fn is only used for the current module compilation, when we are doing recursive compilation for the current module's submodule, we are using infer_methods_to_compile as the stubs_fn, which disallow us from compiling the submodule with different rules. I was trying to also use the stubs_fn as the rule for submodule compilation to replace infer_methods_to_compile, but it failed because we need to maintain the legacy ScriptModule api where submodule don't have cls._methods available. See
Line 203 in 7816d53
| ] = torch.jit._recursive.create_script_module(self, make_stubs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copied from chat:
we are creating ScriptModule using make_stubs function, and the current ScriptModule initialization use cls._methods as stubs_fn infer rule, but that ScriptModule's submodule is compiled using recursive compilation, if we just use stubs_fn as the submodule method infer rule, the submodule compilation also needs cls._methods, which is not available as cls._methods only available in the current ScriptModule (we construct cls._methods in ScriptMeta.__init__
eellison
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR looks good, trying to understand current code organization
|
Can we clarify the context here? I see that this is about something involving mixing tracing and scripting, but what specific use case are we trying to make work? The added test case covers functionality that should already exist today (when tracing a module, we should attempt to compile any |
@suo commented with a detailed code snippet in the issue #43729 (comment) |
This PR allows the recursive scripting to have a separate submodule_stubs_fn to create its submodule with specific user provided rules. Fixes #43729 Differential Revision: [D23430176](https://our.internmc.facebook.com/intern/diff/D23430176) [ghstack-poisoned]
This PR allows the recursive scripting to have a separate submodule_stubs_fn to create its submodule with specific user provided rules. Fixes #43729 Differential Revision: [D23430176](https://our.internmc.facebook.com/intern/diff/D23430176) [ghstack-poisoned]
… allow reuse rule for submodule" This PR allows the recursive scripting to reuse stubs_fn to create its submodule. Fixes #43729 Differential Revision: [D23430176](https://our.internmc.facebook.com/intern/diff/D23430176) [ghstack-poisoned]
… allow reuse rule for submodule" when compiling submodule, we always use the default method inference rule, which includes `forward` and methods it calls, this is also the case for tracing. For tracing we don't want to compile the forward by default for submodule. This PR allows the recursive scripting to reuse stubs_fn to create its submodule instead of alwaysing using the default rule. Fixes #43729 Differential Revision: [D23430176](https://our.internmc.facebook.com/intern/diff/D23430176) [ghstack-poisoned]
…e rule for submodule when compiling submodule, we always use the default method inference rule, which includes `forward` and methods it calls, this is also the case for tracing. For tracing we don't want to compile the forward by default for submodule. This PR allows the recursive scripting to reuse stubs_fn to create its submodule instead of alwaysing using the default rule. Fixes #43729 ghstack-source-id: 2201f7b Pull Request resolved: #43872
torch/jit/_recursive.py
Outdated
| return concrete_type | ||
|
|
||
| def create_script_module(nn_module, stubs_fn, share_types=True): | ||
| def create_script_module(nn_module, stubs_fn, share_types=True, reuse_stubs_fn=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As far as I can tell, reuse_stubs_fn is always True right? Because either:
- You are tracing, and you want to keep tracing the forwards of the other modules
- You are scripting, and you are using the default
stubs_fn.
Either way, the stubs_fn is the same within a given recursive run, right? If that's the case, can we just remove this option and ensure that we always re-use the same stubs fn?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While this is almost true and it was the methodology that I was trying, but unfortunately it fails, because we have one exception: the Legacy ScriptModule compilation, if you look at here
Line 203 in 7816d53
| ] = torch.jit._recursive.create_script_module(self, make_stubs) |
make_stubs rule to compile the legacy ScriptModule as entry point, but inside the recursive run, we are using the default infer_methods_to_compile rule for submodule compilation. I think this is because cls._methods.items() is not available in ScriptModule's submodule bc that might be a plain nn.Module and don't have cls._methods crafted upon construction.
In fact, if I always use the same stubs_fn, it will crash on some tests and complain something like this:
File "/home/wanchaol/pytorch/torch/jit/_recursive.py", line 364, in create_script_module_impl
method_stubs = stubs_fn(nn_module)
File "/home/wanchaol/pytorch/torch/jit/_script.py", line 202, in make_stubs
return [v for k, v in sorted(cls._methods.items())]
AttributeError: type object 'Over' has no attribute '_methods'
I think we should either: 1. use the approach I proposed here, so that we don't reuse_stubs_fn in legacyScriptModule 2. disable recursive compilation in legacy ScriptModule API, only allow user to construct their submodule with all modules inheriting from torch.jit.ScriptModule (or say, torch.jit.ScriptModule should only contain submodule with inheriting from torch.jit.ScriptModule.)
I would prefer the first option bc it does not BC break the behavior that we already exposed to the user.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In general, I would prefer to add more to our legacy path shims, rather than change the current behavior just because the legacy path has some constraints. That way we can keep the current behavior clean and add the complexity in the legacy path.
In this case, that would mean: change the custom make_stubs used in the the legacy ScriptModule compilation to fork its behavior, so it would look like:
def make_stubs:
if I am compiled the legacy way:
look in _methods to make stubs
else:
use the regular `infer_methods_to_compile`
when compiling submodule, we always use the default method inference rule, which includes `forward` and methods it calls, this is also the case for tracing. For tracing we don't want to compile the forward by default for submodule. This PR allows the recursive scripting to reuse stubs_fn to create its submodule instead of alwaysing using the default rule. So three cases here: 1. tracing: only script exported methods, still trace forward and its recursive calls 2. scripting: always reuse stubs_fn 3. legacy ScriptModule: recursive compile use default infer_methods rule, if the module is a scriptModule, find methods from self._methods Fixes #43729 Differential Revision: [D23430176](https://our.internmc.facebook.com/intern/diff/D23430176) [ghstack-poisoned]
when compiling submodule, we always use the default method inference rule, which includes `forward` and methods it calls, this is also the case for tracing. For tracing we don't want to compile the forward by default for submodule. This PR allows the recursive scripting to reuse stubs_fn to create its submodule instead of alwaysing using the default rule. So three cases here: 1. tracing: only script exported methods, still trace forward and its recursive calls 2. scripting: always reuse stubs_fn 3. legacy ScriptModule: recursive compile use default infer_methods rule, if the module is a scriptModule, find methods from self._methods Fixes #43729 Differential Revision: [D23430176](https://our.internmc.facebook.com/intern/diff/D23430176) [ghstack-poisoned]
when compiling submodule, we always use the default method inference rule, which includes `forward` and methods it calls, this is also the case for tracing. For tracing we don't want to compile the forward by default for submodule. This PR allows the recursive scripting to reuse stubs_fn to create its submodule instead of alwaysing using the default rule. So three cases here: 1. tracing: only script exported methods, still trace forward and its recursive calls 2. scripting: always reuse stubs_fn 3. legacy ScriptModule: recursive compile use default infer_methods rule, if the module is a scriptModule, find methods from self._methods Fixes #43729 ghstack-source-id: 73f52f8 Pull Request resolved: #43872
suo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm!
Stack from ghstack:
when compiling submodule, we always use the default method inference rule, which includes
forwardand methods it calls, this is also the case for tracing. For tracing we don't want to compile the forward by default for submodule. This PR allows the recursive scripting to reuse stubs_fn to create its submodule instead of alwaysing using the default rule. So three cases here:recursive calls
rule, if the module is a scriptModule, find methods from
self._methods
Fixes #43729
Differential Revision: D23430176