Conversation

@samdow samdow commented Nov 17, 2022

pytorch-bot bot commented Nov 17, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89213

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 497a034:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

samdow pushed a commit that referenced this pull request Nov 17, 2022
samdow pushed a commit that referenced this pull request Nov 18, 2022
@samdow samdow changed the title from "[functorch] update make_functional tests to also test using functional_call" to "[functorch] add functorch functional_call, update tests to test this" on Dec 16, 2022
@samdow samdow requested a review from zou3519 December 16, 2022 15:51

samdow commented Dec 16, 2022

@zou3519 stack now starts here. It's ready for review!

@zou3519 zou3519 left a comment

Code looks reasonable to me. I had some high-level questions (that I might have asked on the other PRs, sorry if they're redundant):

  • should the new functional_call API be the same as the nn.utils.stateless.functional_call API, or is the plan that we'll deprecate nn.utils.stateless.functional_call immediately?
  • are we exposing the new functional_call in functorch, or should it just be exposed in torch.func?
  • should we have params_and_buffers_disable_autograd_tracking? Or should we just warn somewhere in the functional_call docs and show an example with .detach()?
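For the third bullet, a minimal sketch of the .detach() approach, assuming the torch.func-style import that this stack eventually landed on (the .sum() is only there to give grad a scalar output):

import torch
import torch.nn as nn
from torch.func import functional_call, grad

mod = nn.Linear(1, 1)
x = torch.randn(1, 1)

# Detach the parameters so the grad transform does not track them.
frozen = {k: v.detach() for k, v in mod.named_parameters()}

# Gradient with respect to the input x only; the detached parameters act as constants.
dx = grad(lambda x: functional_call(mod, frozen, x).sum())(x)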

Comment on lines 13 to 18
def functional_call(
    module: 'torch.nn.Module',
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor]]],
    args: Union[Any, Tuple],
    kwargs: Dict[str, Any] = None,
):
Contributor

  • I assume the docs are coming later?
  • did you want nn.utils.stateless.functional_call to also have this API, or does nn.utils.stateless.functional_call have a different API (and we deprecate it?)

Contributor Author

Ahh, thanks for the catch! I forgot to move the docs over.

Re: nn.utils.stateless.functional_call, I think we should deprecate it and ask users to use torch.func instead, since they have the same API. I'll add a PR to do that. Also, it would technically be BC-breaking to change the existing one in place, since previously parameters_and_buffers had to be a single dictionary.
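For reference, a minimal sketch of the two call styles in question, assuming the signature from this PR (a single merged dict is what nn.utils.stateless accepted; the tuple-of-dicts form is the new addition):

import torch
import torch.nn as nn
from torch.func import functional_call

mod = nn.Linear(3, 1)
x = torch.randn(2, 3)

# Old style: one merged dict of parameters and buffers.
merged = {**dict(mod.named_parameters()), **dict(mod.named_buffers())}
out = functional_call(mod, merged, (x,))

# New style: a tuple of dicts with distinct keys, e.g. parameters and buffers kept separate.
params = dict(mod.named_parameters())
buffers = dict(mod.named_buffers())
out = functional_call(mod, (params, buffers), (x,))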


zou3519 commented Dec 20, 2022

Btw, torch.func exists now, so let's put functional_call there


samdow commented Dec 21, 2022

> Btw, torch.func exists now, so let's put functional_call there

Done! Let me know if there's anything else!

@zou3519 zou3519 left a comment

I'm a bit confused about the duplicate keys check, but other than that, this LGTM. I had some minor comments on nits and suggestions to improve the testing

.. toctree::
   :maxdepth: 1

   batch_norm
Contributor

Does this exist? I don't see it in this PR

Contributor Author

Good catch! Moved it over, but it still refers to functorch.experimental.replace_all_batch_norm_modules_. I was going to add a PR at the top to fix this, but happy to squash it in here if that's preferred.

Comment on lines 55 to 57
>>> mod = nn.Linear(1, 1)
>>> d = {k: v.detach() for k, v in mod.named_parameters()}
>>> grad(lambda x: functional_call(mod, d, x))(torch.randn((1, 1)))  # doesn't track grads for params
Contributor

Nit: some text explanation of what is going on here, perhaps also with a compute_loss function. "applies the grad transform over the parameters of a model".

https://pytorch.org/functorch/stable/generated/functorch.make_functional.html#functorch.make_functional has a nice end-to-end example that we should follow.
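Something along these lines, as a rough sketch of the requested end-to-end example (compute_loss and the tensor shapes are illustrative):

import torch
import torch.nn as nn
from torch.func import functional_call, grad

model = nn.Linear(3, 3)
x = torch.randn(4, 3)
targets = torch.randn(4, 3)

# Pull the parameters out as a plain dict; grad returns one gradient per entry.
params = {k: v.detach() for k, v in model.named_parameters()}

def compute_loss(params, x, targets):
    preds = functional_call(model, params, (x,))
    return nn.functional.mse_loss(preds, targets)

# Applies the grad transform over the parameters of the model.
grads = grad(compute_loss)(params, x, targets)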

Contributor Author

Added an end-to-end example and moved this down to be part of the note on not tracking gradients. Here's what the ending looks like, rendered:

[Screenshot of the rendered docs section, 2022-12-27]

Comment on lines +11 to +14
    module: 'torch.nn.Module',
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], ...]],
    args: Union[Any, Tuple],
    kwargs: Dict[str, Any] = None,
Contributor

Just to check: Did the weight tying update for stateless.functional_call happen already? If so, should functional_call have an option to toggle the weight tying behavior?
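For context, a minimal sketch of the tied-weights case being asked about. The TiedLinear module is illustrative, and whether overriding one tied name also applies to its twin is exactly what the weight-tying update is meant to settle, so treat this as intended semantics rather than the API of this PR:

import torch
import torch.nn as nn
from torch.func import functional_call

class TiedLinear(nn.Module):
    # Two linear layers that share a single weight Parameter.
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(2, 2, bias=False)
        self.b = nn.Linear(2, 2, bias=False)
        self.b.weight = self.a.weight  # tie the weights

    def forward(self, x):
        return self.b(self.a(x))

mod = TiedLinear()
x = torch.randn(1, 2)

# Override "a.weight" only; with weight tying respected, "b.weight" is expected to
# pick up the same override instead of keeping the module's original value.
new_weight = torch.randn(2, 2)
out = functional_call(mod, {"a.weight": new_weight}, (x,))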

Contributor Author

#90477 hasn't been approved yet. I can add it in there if this lands first

Contributor

Sorry, totally missed that PR

@zou3519 zou3519 left a comment

LGTM, some minor nits

What's happening?
-----------------
Batch Norm requires in-place updates to running_mean and running_var of the same size as the input.
Functorch does not support inplace update to a regular tensor that takes in a batched tensor (i.e.
Contributor

nit: vmap? And link section from the UX limitations doc if possible (there is some sphinx syntax to link it)

Contributor Author

I think the UX limitations doc hasn't been pulled over yet (probably my bad for not looking at a PR yet). Filed #91509 so I don't forget to do this in a follow-up.

Comment on lines 11 to 15
How to fix
----------
All of these options assume that you don't need running stats. If you're using a module this means
that it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue
Contributor

Are we not going to recommend replacing the batch norm with another normalization layer? That should be good for completeness so the user knows their options

Contributor Author

Added language to options 1 and 2 on how to replace batch norm with GroupNorm. From Opacus' work, replacing the norm layer with a different layer is generally fragile, but the version functorch offers (replace_all_batch_norm_modules_), which just turns off tracking of running stats, should work since it updates the layer in place instead of trying to replace the whole layer.
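A quick sketch of the two routes mentioned above (the GroupNorm group count is arbitrary):

import torch.nn as nn
from functorch.experimental import replace_all_batch_norm_modules_

# Option 1: build the model with GroupNorm instead of BatchNorm, so there are no running stats.
net_groupnorm = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(4, 8), nn.ReLU())

# Option 2: keep the architecture, but patch every batch norm layer in place so it stops
# tracking running stats; this is what replace_all_batch_norm_modules_ does.
net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
replace_all_batch_norm_modules_(net)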

@pytorchmergebot
Collaborator

Rebased gh/samdow/46/orig onto refs/remotes/origin/viable/strict because #88850 was rebased, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/89213)

@facebook-github-bot
Contributor

This pull request has been merged in c5e5916.

@facebook-github-bot facebook-github-bot deleted the gh/samdow/48/head branch June 8, 2023 18:43