[functorch] add functorch functional_call, update tests to test this #89213
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89213
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 497a034.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@zou3519 stack now starts here. It's ready for review!
zou3519 left a comment:
Code looks reasonable to me. I had some high-level questions (that I might have asked on the other PRs, sorry if they're redundant):
- should the new functional_call API be the same as the nn.utils.stateless.functional_call API, or is the plan that we'll deprecate nn.utils.stateless.functional_call immediately?
- are we exposing the new functional_call in functorch, or should it just be exposed in torch.func?
- should we have params_and_buffers_disable_autograd_tracking? Or should we just warn somewhere in the functional_call docs and show an example with .detach()? (a sketch of that pattern follows below)
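A minimal sketch of the .detach() pattern referred to in the last question, assuming functional_call and grad end up under torch.func as discussed later in this thread (the Linear module and shapes are illustrative):

import torch
import torch.nn as nn
from torch.func import functional_call, grad

mod = nn.Linear(1, 1)
# Detaching the parameters means grad only tracks the input x,
# not the module's parameters.
d = {k: v.detach() for k, v in mod.named_parameters()}
grad_wrt_x = grad(lambda x: functional_call(mod, d, (x,)).sum())(torch.randn(1, 1))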
def functional_call(
    module: 'torch.nn.Module',
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]],
    args: Union[Any, Tuple],
    kwargs: Dict[str, Any] = None,
):
- I assume the docs are coming later?
- did you want nn.utils.stateless.functional_call to also have this API, or does nn.utils.stateless.functional_call have a different API (and we deprecate it?)
Ah, thanks for the catch; I forgot to move the docs over.
Re: nn.utils.stateless.functional_call, I think we should deprecate it and ask users to use torch.func instead, since they are the same API. I'll add a PR to do that. Also, this change was technically BC-breaking, since previously parameters_and_buffers had to be a single dictionary.
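For illustration, a hedged sketch of the difference: the old nn.utils.stateless.functional_call style takes a single merged dict, while the version in this stack also accepts a tuple of dicts. BatchNorm1d is just an example module that has both parameters and buffers:

import torch
import torch.nn as nn
from torch.func import functional_call

bn = nn.BatchNorm1d(3)
x = torch.randn(4, 3)
params = dict(bn.named_parameters())
buffers = dict(bn.named_buffers())

# Old style: parameters and buffers merged into a single dictionary.
out_single = functional_call(bn, {**params, **buffers}, (x,))

# New behavior from this stack: a tuple of dicts (with distinct keys) also works,
# so parameters and buffers can be passed separately.
out_tuple = functional_call(bn, (params, buffers), (x,))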
Btw, torch.func exists now, so let's put functional_call there.
Done! Let me know if there's anything else!
zou3519 left a comment:
I'm a bit confused about the duplicate keys check, but other than that, this LGTM. I had some minor comments on nits and suggestions to improve the testing
docs/source/func.api.rst (outdated)

.. toctree::
   :maxdepth: 1

   batch_norm
Does this exist? I don't see it in this PR
Good catch! Moved it over, but it still refers to functorch.experimental.replace_all_batch_norm_modules_. I was going to add a PR at the top of the stack to fix this, but I'm happy to squash it in here if that's preferred.
torch/_functorch/functional_call.py (outdated)

>>> mod = nn.Linear(1, 1)
>>> d = {k: v.detach() for k, v in mod.named_parameters()}
>>> grad(lambda x: functional_call(mod, d, x), torch.randn((1, 1)))  # doesn't track grads for params
Nit: add some text explaining what is going on here, perhaps also with a compute_loss function, e.g. "applies the grad transform over the parameters of a model".
https://pytorch.org/functorch/stable/generated/functorch.make_functional.html#functorch.make_functional has a nice end-to-end example that we should follow.
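For reference, a hedged sketch of what such an end-to-end example could look like; compute_loss, the toy data, and the torch.func import location are illustrative assumptions, not part of this PR:

import torch
import torch.nn as nn
from torch.func import functional_call, grad

model = nn.Linear(3, 1)
x = torch.randn(8, 3)
targets = torch.randn(8, 1)

# Gradients are computed with respect to this detached parameter dict,
# not the module's registered parameters.
params = {k: v.detach() for k, v in model.named_parameters()}

def compute_loss(params, x, targets):
    # Run the module using the supplied parameters instead of its own.
    preds = functional_call(model, params, (x,))
    return nn.functional.mse_loss(preds, targets)

# "Applies the grad transform over the parameters of a model":
# grad(compute_loss) returns d(loss)/d(params) as a dict with the same keys.
param_grads = grad(compute_loss)(params, x, targets)
print({k: v.shape for k, v in param_grads.items()})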
    module: 'torch.nn.Module',
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], ...]],
    args: Union[Any, Tuple],
    kwargs: Dict[str, Any] = None,
Just to check: Did the weight tying update for stateless.functional_call happen already? If so, should functional_call have an option to toggle the weight tying behavior?
#90477 hasn't been approved yet. I can add it in there if this lands first
Sorry, totally missed that PR
zou3519 left a comment:
LGTM, some minor nits
What's happening?
-----------------
Batch Norm requires in-place updates to running_mean and running_var of the same size as the input.
Functorch does not support inplace update to a regular tensor that takes in a batched tensor (i.e.
nit: vmap? And link section from the UX limitations doc if possible (there is some sphinx syntax to link it)
I think the UX limitations doc hasn't been pulled over yet (probably my bad for not getting to that PR yet). Filed #91509 so I don't forget to do this in a follow-up.
How to fix
----------
All of these options assume that you don't need running stats. If you're using a module, this means
it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue
Are we not going to recommend replacing the batch norm with another normalization layer? That should be good for completeness so the user knows their options
Added language to options 1 and 2 on how to replace batch norm with GroupNorm. From Opacus' work, replacing the norm layer with a different layer is generally fragile, but the version functorch offers (replace_all_batch_norm_modules_), which turns off running-stat tracking, should work since it just updates the layer in place instead of trying to replace the whole layer.
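For completeness, a hedged sketch of that in-place replacement approach, assuming the functorch.experimental.replace_all_batch_norm_modules_ API (the small conv model and shapes are illustrative):

import torch
import torch.nn as nn
from functorch import vmap
from functorch.experimental import replace_all_batch_norm_modules_

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

# Patch every BatchNorm layer in place so it stops tracking running stats;
# the layers themselves are kept rather than swapped for GroupNorm.
replace_all_batch_norm_modules_(net)

# With the running-stat updates gone, a per-sample vmap over the batch works.
images = torch.randn(4, 3, 16, 16)
per_sample_out = vmap(lambda img: net(img.unsqueeze(0)).squeeze(0))(images)
print(per_sample_out.shape)  # torch.Size([4, 8, 14, 14])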
Rebased
This pull request has been merged in c5e5916.
