functorch.jvp support for autograd.Function #90077
Conversation
WIP, to be filled in [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90077
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure as of commit 018bcc1. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Not ready for review yet, getting some CI signal.
albanD
left a comment
Sounds pretty good!
torch/_functorch/pyfunctorch.py
Outdated
def lower(self):
    prev_fwd_grad_mode = self.prev_fwd_grad_mode()
    if not self.prev_fwd_grad_mode:
        return contextlib.nested(_set_fwd_grad_enabled(False), super().lower())
isn't contextlib.nested deprecated since python 2.7?
You're right, it's not in the Python 3 docs. I'll find an alternative
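(For context: contextlib.nested was deprecated in Python 2.7 and removed in Python 3. A minimal sketch of an equivalent built on contextlib.ExitStack, purely illustrative and not the fix that ultimately landed in this PR, could look like this.)

```python
import contextlib

@contextlib.contextmanager
def nested(*context_managers):
    # Rough stand-in for the removed contextlib.nested: enter the managers
    # in order, yield, then exit them in reverse order on the way out.
    with contextlib.ExitStack() as stack:
        for cm in context_managers:
            stack.enter_context(cm)
        yield

# Usage: with nested(cm_a, cm_b): ...
# In new code, `with cm_a, cm_b:` is usually preferable when the managers
# are known statically.
```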
soulitzer
left a comment
Cool!
This pull request has been merged in 4809e83.
Stack from ghstack:
This PR adds functorch.jvp support for autograd.Function. It does so by
adding a jvp rule for custom_function_call.
For a regular PyTorch operation (like at::sin), the VariableType kernel:
- re-dispatches to at::sin
- calls the jvp rule for at::sin

The jvp rule for custom_function_call mirrors that pattern. It constructs a
new autograd.Function (because the above logic already exists). Inside the
forward, it re-dispatches to custom_function_call. In the jvp rule, it calls
whatever jvp rule the user defined.

Since this logic is very close to custom_function_call_grad, I put them
together.
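Below is a hedged, self-contained sketch of the construction described above. It is illustrative pseudocode only: `custom_function_call_stub` stands in for the real custom_function_call operator, and names like `jvp_rule_for_custom_function_call` and `Generated` do not correspond to the landed implementation.

```python
import torch

def custom_function_call_stub(autograd_function, *operands):
    # Stand-in for the real custom_function_call re-dispatch, so that this
    # sketch is self-contained.
    return autograd_function.apply(*operands)

def jvp_rule_for_custom_function_call(autograd_function, *operands):
    # Build a fresh autograd.Function whose forward re-dispatches and whose
    # jvp defers to the user's jvp rule, mirroring what the VariableType
    # kernel does for a regular op like at::sin.
    class Generated(torch.autograd.Function):
        @staticmethod
        def forward(ctx, *args):
            return custom_function_call_stub(autograd_function, *args)

        @staticmethod
        def jvp(ctx, *tangents):
            # Defer to whatever jvp rule the user defined.
            return autograd_function.jvp(ctx, *tangents)

    return Generated.apply(*operands)
```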
Test Plan:
- added jvp rules to the autograd.Functions in autograd_function_db
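For illustration, here is a minimal sketch of the kind of usage this PR enables. MyScale is a hypothetical example (not one of the entries in autograd_function_db), written with a separate setup_context as required for functorch support of autograd.Function; torch.func.jvp is the newer spelling of functorch.jvp.

```python
import torch
import functorch  # torch.func.jvp is the newer spelling of the same transform

class MyScale(torch.autograd.Function):
    # Hypothetical example, not an actual autograd_function_db entry.
    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for this rule

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

    @staticmethod
    def jvp(ctx, x_tangent):
        # Forward-mode rule: the tangent of 2*x is 2*x_tangent.
        return x_tangent * 2

x = torch.randn(3)
t = torch.randn(3)
out, out_tangent = functorch.jvp(MyScale.apply, (x,), (t,))
# out == 2 * x and out_tangent == 2 * t
```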