Add hooks for register_buffer/module/parameter #86148
Hi @antoniojkim, thanks for sending a PR for this!
Would this PR pass the linting check in CI without the proper formatting?
No. But we can merge the formatting PR quickly and so it will disappear from this PR.
2dd2179 to 546e93d
I've undone all the formatting changes
It turns out that the behaviour is slightly different, so I've reverted it back to how it was before. There are now multiple copies of the code that calls the buffer registration and module registration hooks.
@pytorchbot rebase
You don't have permissions to rebase this PR, only people with write permissions may rebase PRs.
@pytorchbot -h
PyTorchBot Help: Merge, Revert, Rebase, Label
546e93d to 6ad0084
6ad0084 to 127f0c9
albanD
left a comment
Looks pretty good!
We would need a couple of tests and an update to the signature, and then this will be good to go!
torch/nn/modules/module.py
Outdated
Could you name these with a "module" in the name to match the other hooks?
torch/nn/modules/module.py
Outdated
Shouldn't the buffer be passed in?
Also for consistency, could you pass the module first: hook(module, name, buffer) -> None or new buffer
Same for the others below
yep, good catch. That's my bad for just copying the docstring over and not replacing all instances 😅
albanD
left a comment
Sorry for the misunderstanding. I think we can do a small update for the signatures.
The calls will need to be updated as well to pass in the right arguments to the hooks.
torch/nn/modules/module.py
Outdated
This should still be hook(module, name, buffer) -> None or new buffer
torch/nn/modules/module.py
Outdated
I think this one should be hook(module, name, submodule) -> None or new submodule
torch/nn/modules/module.py
Outdated
This one should be hook(module, name, param) -> None or new param
5893b2d to 73d6ced
albanD
left a comment
Signatures look good!
I think we only need two things now:
- Update the call sites to pass in the right arguments
- Add some basic tests for each of these hooks
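
For readers following the review, here is a minimal sketch of what a call site honouring the agreed `hook(module, name, buffer) -> None or new buffer` signature might look like. It is plain Python and does not reproduce the code in `torch/nn/modules/module.py`; `ToyModule`, `_buffer_registration_hooks` and `register_buffer_registration_hook` are hypothetical names chosen only for illustration.

```python
from typing import Any, Callable, Dict, Optional

# Hypothetical hook type: hook(module, name, buffer) -> None or new buffer
BufferHook = Callable[["ToyModule", str, Any], Optional[Any]]

# Hypothetical global registry; an assumption, not the PR's actual storage.
_buffer_registration_hooks: Dict[int, BufferHook] = {}


def register_buffer_registration_hook(hook: BufferHook) -> int:
    """Store the hook globally and return the key it was stored under."""
    key = len(_buffer_registration_hooks)
    _buffer_registration_hooks[key] = hook
    return key


class ToyModule:
    """Stand-in for a module; this is not torch.nn.Module."""

    def __init__(self) -> None:
        self._buffers: Dict[str, Any] = {}

    def register_buffer(self, name: str, buffer: Any) -> None:
        # Call site: pass the module first, then the name, then the buffer.
        for hook in _buffer_registration_hooks.values():
            output = hook(self, name, buffer)
            # A hook that returns None leaves the buffer unchanged.
            if output is not None:
                buffer = output
        self._buffers[name] = buffer


seen = []


def log_hook(module, name, buffer):
    # Observes the registration and returns None implicitly.
    seen.append((name, buffer))


def doubling_hook(module, name, buffer):
    # Returns a replacement buffer.
    return [2 * x for x in buffer]


register_buffer_registration_hook(log_hook)
register_buffer_registration_hook(doubling_hook)

m = ToyModule()
m.register_buffer("scale", [1, 2])
```

A basic test for such a hook would then check both paths: that a hook returning `None` leaves the value alone, and that a hook returning a new value replaces it.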
945dff2 to 86fa769
86fa769 to 41607f6
albanD
left a comment
SGTM!
Thanks for taking the time to add full testing!
@pytorchbot merge -g
Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours).
Hey @antoniojkim.
There is a bug in the implementation of the registration hooks introduced in #86148 whereby, if the hook returns a tensor, the short-circuiting logic

```
value = hook(self, name, value) or value
```

raises an exception:

```
RuntimeError: Boolean value of Tensor with more than one value is ambiguous
```

This fixes the logic so that it only checks whether the returned value is `None` before overriding.

Fixes #85837

CC: @albanD @jbschlosser

Pull Request resolved: #87369
Approved by: https://github.com/albanD
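
The failure mode described in this commit message can be reproduced without PyTorch. The sketch below is an illustration under stated assumptions, not the actual patch: `AmbiguousBool` is a hypothetical stand-in for a multi-element tensor (its truth test raises), and the two `apply_hook_*` helpers are invented names that contrast the `or` pattern with an explicit `None` check.

```python
class AmbiguousBool:
    """Hypothetical stand-in for a multi-element tensor: truth testing raises."""

    def __init__(self, values):
        self.values = values

    def __bool__(self):
        raise RuntimeError("Boolean value with more than one value is ambiguous")


def apply_hook_with_or(hook, module, name, value):
    # Buggy pattern: `or` truth-tests the hook's return value, which raises
    # for objects whose __bool__ is ambiguous.
    return hook(module, name, value) or value


def apply_hook_with_none_check(hook, module, name, value):
    # Fixed pattern: only compare against None, never truth-test the result.
    output = hook(module, name, value)
    return value if output is None else output


def replacing_hook(module, name, value):
    # Returns a new value that should replace the original.
    return AmbiguousBool([0.0, 0.0])


def passthrough_hook(module, name, value):
    # Returns None, meaning "keep the original value".
    return None


original = AmbiguousBool([1.0, 2.0])

try:
    apply_hook_with_or(replacing_hook, None, "weight", original)
    or_raised = False
except RuntimeError:
    or_raised = True

replaced = apply_hook_with_none_check(replacing_hook, None, "weight", original)
kept = apply_hook_with_none_check(passthrough_hook, None, "weight", original)
```

With the `None` check, a hook can return any object as the replacement, including one that cannot be truth-tested, and returning `None` still means "leave the value as it was".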
As described in the issue, this PR adds hooks to be run when `register_parameter`, `register_buffer` and `register_module` are called.

Fixes #85837
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are