
Conversation

zou3519 (Contributor) commented Dec 20, 2022

Stack from ghstack:

Support for jvp is very similar to support for backward():

  • We need to vmap over a version of the original autograd.Function's jvp
    method that does not take ctx as input (jvp_no_context below).
  • On the outputs, we need to reductify to ensure each output tangent has
    the same shape as its corresponding output. This reductify does not
    carry the extra reduction semantics, because PyTorch forward-mode AD
    requires the output tangent to have exactly the same shape as the
    output.
  • setup_context needs to tell us the bdims of the saved_tensors
    (necessary for vmapping over jvp_no_context), as well as the output
    shapes (necessary for reductify). A sketch of these pieces follows
    the list.
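
In rough Python, the first two pieces look something like the sketch below. jvp_no_context and reductify_jvp are illustrative names and signatures, not the actual functorch internals:

```python
import torch

def jvp_no_context(jvp_fn, saved_tensors, *tangents):
    # A ctx-free version of autograd.Function.jvp: the saved tensors are
    # passed in explicitly instead of being read off of `ctx`, so that
    # vmap can map over them like any other input.
    class _Ctx:
        pass
    ctx = _Ctx()
    ctx.saved_tensors = tuple(saved_tensors)
    return jvp_fn(ctx, *tangents)

def reductify_jvp(output_tangent, tangent_bdim, batch_size, output_shape):
    # Make the output tangent's shape match the output exactly. Unlike
    # the backward() case, there is no summing over mismatched dims: if
    # the vmapped jvp produced an unbatched tangent for a batched output,
    # broadcast the tangent across the batch instead.
    if tangent_bdim is None:
        return output_tangent.expand(batch_size, *output_shape)
    # Move the batch dim to the front so it lines up with the output.
    return output_tangent.movedim(tangent_bdim, 0)
```

The saved-tensor bdims recorded by setup_context become the in_dims for vmapping jvp_no_context, and the recorded output shapes feed reductify_jvp.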

Test Plan:

  • Added jvp support to the *GenVmapAutogradFunction
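
For a sense of what those tests exercise, here is a hand-rolled example in the same spirit (MySquare is made up for illustration and is not one of the actual test classes):

```python
import torch
from torch.func import jvp, vmap

class MySquare(torch.autograd.Function):
    generate_vmap_rule = True  # ask functorch to derive the vmap rule

    @staticmethod
    def forward(x):
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        # The bdims of these saved tensors are what the generated vmap
        # rule needs in order to vmap over jvp.
        ctx.save_for_forward(x)

    @staticmethod
    def jvp(ctx, x_t):
        x, = ctx.saved_tensors
        return 2 * x * x_t

x = torch.randn(3, 5)
t = torch.randn(3, 5)
# Compose vmap with forward-mode AD over the custom Function.
out, out_tangent = vmap(lambda xi, ti: jvp(MySquare.apply, (xi,), (ti,)))(x, t)
```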

zou3519 added a commit that referenced this pull request Dec 20, 2022
ghstack-source-id: ed2abc4
Pull Request resolved: #91211

pytorch-bot bot commented Dec 20, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91211

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 3fef822:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zou3519 added the release notes: torch.func label (release notes category for torch.vmap or torch.func.* APIs) Dec 21, 2022
zou3519 requested review from samdow and soulitzer December 21, 2022 15:04
soulitzer (Contributor) left a comment


LGTM

zou3519 added the ciflow/trunk label (Trigger trunk jobs on your pull request) Dec 22, 2022
zou3519 added a commit that referenced this pull request Dec 27, 2022
ghstack-source-id: 1bb23bd
Pull Request resolved: #91211
zou3519 added a commit that referenced this pull request Dec 27, 2022
ghstack-source-id: e09940f
Pull Request resolved: #91211
zou3519 (Contributor, Author) commented Dec 27, 2022

@pytorchbot merge -f "check was cancelled, idk why"

pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

facebook-github-bot deleted the gh/zou3519/594/head branch June 8, 2023 19:33

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
release notes: torch.func (release notes category for torch.vmap or torch.func.* APIs)

4 participants