-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Allow Module forward-pre and forward hooks to take kwargs #89389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
closes #35643 This PR is mostly copied from #82042. Thanks @Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042, this PR adds a `with_kwargs` argument to `register_forward_pre_hook` and `register_forward_hook` methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hood, the hook is wrapped into `_ForwardHook` and `_ForwardPreHook` types to avoid backward compatibility issues. [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89389
Note: Links to docs will display an error until the docs builds have been completed. ❌ 5 FailuresAs of commit d940db0: The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
closes #35643 This PR is mostly copied from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042, this PR adds a `with_kwargs` argument to `register_forward_pre_hook` and `register_forward_hook` methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hood, the hook is wrapped into `_ForwardHook` and `_ForwardPreHook` types to avoid backward compatibility issues. cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel [ghstack-poisoned]
closes #35643 This PR is mostly copied from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042, this PR adds a `with_kwargs` argument to `register_forward_pre_hook` and `register_forward_hook` methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hood, the hook is wrapped into `_ForwardHook` and `_ForwardPreHook` types to avoid backward compatibility issues. cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel albanD mruberry jbschlosser walterddr kshitij12345 saketh-are [ghstack-poisoned]
|
Thaks @mrshenli! I've closed my PR. I like the changes you've made 👍 |
torch/nn/modules/module.py
Outdated
| self.module = weakref.ref(state["module"]) | ||
|
|
||
|
|
||
| # N.B.: This calss is NOT deriving from `_WrappedHook`, because pre- and post- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I had this same realisation. It does feel like they should be mergable, but I couldn't see a simple way to do so.
closes #35643 This PR is mostly copied from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042, this PR adds a `with_kwargs` argument to `register_forward_pre_hook` and `register_forward_hook` methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hood, the hook is wrapped into `_ForwardHook` and `_ForwardPreHook` types to avoid backward compatibility issues. cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel albanD mruberry jbschlosser walterddr kshitij12345 saketh-are [ghstack-poisoned]
closes #35643 This PR is mostly copied from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042, this PR adds a `with_kwargs` argument to `register_forward_pre_hook` and `register_forward_hook` methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hood, the hook is wrapped into `_ForwardHook` and `_ForwardPreHook` types to avoid backward compatibility issues. cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen leslie-fang-intel albanD mruberry jbschlosser walterddr kshitij12345 saketh-are [ghstack-poisoned]
|
@mrshenli has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
soulitzer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update, mostly looks good! Mainly just a comment about cleaning up the new dict on removal.
|
Thanks @awgu and @soulitzer for the review! Before updating this PR, will let CI run for a bit longer while waiting for comments from other reviewers |
| # Marks whether the corresponding _forward_hooks accept kwargs or not. All | ||
| # As JIT does not support Set[int], this dict is used as a set, where all | ||
| # hooks represented in this dict accept kwargs. | ||
| _forward_hooks_with_kwargs: Dict[int, bool] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes the handle.remove() method should remove the entry in the _forward_hooks_with_kwargs dict.
closes #35643 This PR is mostly borrowed from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042 this PR adds a with_kwargs argument to register_forward_pre_hook and register_forward_hook methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hook, this PR adds a `_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs` set to keep track of which hooks accept kwargs. Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111) cc albanD mruberry jbschlosser walterddr kshitij12345 saketh-are [ghstack-poisoned]
torch/utils/hooks.py
Outdated
|
|
||
| def __init__(self, hooks_dict: Any) -> None: | ||
| def __init__( | ||
| self, hooks_dict: Any, *, kwargs_dict: Dict[int, bool] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure what kwargs_dict means in this context. extra_dict sounds better.
Also the typing here should be the same as hooks_dict no?
Even better might be to just do *args and take as many dictionary as the user wants.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me do extra_dict for now. Can add arbitrary numbers of dict if necessary in followups.
torch/utils/hooks.py
Outdated
| self.kwargs_dict_ref = ( | ||
| None | ||
| if len(state) < 3 | ||
| else weakref.ref(dict() if state[2] is None else state[2]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why dict() nad not OrderedDict() like the one above?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we don't need order for this one. Let me change that to OrderedDict to keep things consistent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as with the other one tbh.
closes #35643 This PR is mostly borrowed from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042 this PR adds a with_kwargs argument to register_forward_pre_hook and register_forward_hook methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hook, this PR adds a `_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs` set to keep track of which hooks accept kwargs. Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111) cc albanD mruberry jbschlosser walterddr kshitij12345 saketh-are [ghstack-poisoned]
|
@mrshenli has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
closes #35643 This PR is mostly borrowed from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042 this PR adds a with_kwargs argument to register_forward_pre_hook and register_forward_hook methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hook, this PR adds a `_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs` set to keep track of which hooks accept kwargs. Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111) cc albanD mruberry jbschlosser walterddr kshitij12345 saketh-are [ghstack-poisoned]
|
@mrshenli has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
albanD
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small nit about error but SGTM otherwise.
Will let @soulitzer take a final look as well!
closes #35643 This PR is mostly borrowed from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042 this PR adds a with_kwargs argument to register_forward_pre_hook and register_forward_hook methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hook, this PR adds a `_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs` set to keep track of which hooks accept kwargs. Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111) cc albanD mruberry jbschlosser walterddr kshitij12345 saketh-are [ghstack-poisoned]
closes #35643 This PR is mostly borrowed from #82042. Thanks @Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042 this PR adds a with_kwargs argument to register_forward_pre_hook and register_forward_hook methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hook, this PR adds a `_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs` set to keep track of which hooks accept kwargs. ghstack-source-id: 1f40df4 Pull Request resolved: #89389
|
@mrshenli has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
soulitzer
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
|
Regarding the errors:
I couldn't reproduce the OOM. Based on the error message, it could be CUDA context from concurrent CI tests ate up CUDA memory? cc @huydhn
This is irrelevant and already on master. See https://github.com/pytorch/pytorch/actions/runs/3518138536/jobs/5896929592 |
|
One more failure, on Checked the test source code, it is op-only test, and doesn't touch module hooks at all. Also I couldn't reproduce this locally. pytorch/test/test_reductions.py Lines 373 to 380 in 1cfd385
|
|
4th failure: This is also irrelevant and on master. See https://github.com/pytorch/pytorch/actions/runs/3528028962/jobs/5917920193 |
|
One more, same as the 3rd one: |
|
@pytorchbot merge -f "all test failures are irrelevant, see comments above" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
) closes pytorch#35643 This PR is mostly borrowed from pytorch#82042. Thanks @Padarn for implementing the first version and debugging into the errors. Based on the discussion in pytorch#82042 this PR adds a with_kwargs argument to register_forward_pre_hook and register_forward_hook methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hook, this PR adds a `_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs` set to keep track of which hooks accept kwargs. Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111) Pull Request resolved: pytorch#89389 Approved by: https://github.com/soulitzer
Stack from ghstack (oldest at bottom):
closes #35643
This PR is mostly borrowed from #82042. Thanks @Padarn for implementing
the first version and debugging into the errors.
Based on the discussion in #82042 this PR adds a with_kwargs
argument to register_forward_pre_hook and register_forward_hook
methods. When the arg is set to true, the provided hook must accept
kwargs args. Under the hook, this PR adds a
_forward_pre_hooks_with_kwargsand a_forward_hook_with_kwargsset to keep track of which hooks accept kwargs.
Differential Revision: D41431111
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are