
Conversation

zou3519 (Contributor) commented Dec 2, 2022

Stack from ghstack:

This PR adds functorch.jvp support for autograd.Function. It does so by
adding a jvp rule for custom_function_call.

For a regular PyTorch operation (like at::sin), the VariableType kernel:

  • re-dispatches to at::sin
  • calls the jvp rule for at::sin
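
For example (a quick illustrative check, not taken from this PR's tests), functorch.jvp of torch.sin should produce the primal output plus cos(x) times the tangent:

    import torch
    import functorch

    x = torch.randn(3)
    t = torch.randn(3)
    # jvp returns the primal output and the forward-mode derivative along t
    out, out_tangent = functorch.jvp(torch.sin, (x,), (t,))
    assert torch.allclose(out, torch.sin(x))
    assert torch.allclose(out_tangent, torch.cos(x) * t)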

The jvp rule for custom_function_call follows the same pattern. It constructs a
new autograd.Function (so the logic above can be reused). Inside its forward,
it re-dispatches to custom_function_call; in its jvp rule, it calls the
user-defined jvp rule of the original autograd.Function.

Since this logic is very close to that of custom_function_call_grad, I put the
two together.
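
To make the user-facing effect concrete, here is a minimal sketch (MyTimesTwo is a made-up example, not one of the entries in autograd_function_db, and it assumes the setup_context-style autograd.Function that this stack requires for functorch support):

    import torch
    import functorch  # the same transform is exposed as torch.func.jvp in newer releases

    class MyTimesTwo(torch.autograd.Function):
        @staticmethod
        def forward(x):
            return x * 2

        @staticmethod
        def setup_context(ctx, inputs, output):
            pass  # nothing to save; the op is linear

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output * 2

        @staticmethod
        def jvp(ctx, x_tangent):
            # forward-mode rule: the tangent of 2*x is 2*x_tangent
            return x_tangent * 2

    x = torch.randn(3)
    t = torch.randn(3)
    # With this PR, functorch.jvp routes through custom_function_call's jvp rule,
    # which in turn invokes MyTimesTwo.jvp.
    out, out_tangent = functorch.jvp(MyTimesTwo.apply, (x,), (t,))
    assert torch.allclose(out_tangent, 2 * t)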

Test Plan:

  • added jvp rules to the autograd.Function in autograd_function_db

WIP, to be filled in

pytorch-bot bot commented Dec 2, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90077

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 018bcc1:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zou3519 added a commit that referenced this pull request Dec 2, 2022
WIP, to be filled in

ghstack-source-id: 6ecdf65
Pull Request resolved: #90077
zou3519 (Contributor, Author) commented Dec 2, 2022

Not ready for review yet, getting some CI signal.

zou3519 added the release notes: torch.func label (release notes category for torch.vmap or torch.func.* APIs) Dec 2, 2022
zou3519 added a commit that referenced this pull request Dec 6, 2022
ghstack-source-id: 026f85e
Pull Request resolved: #90077
zou3519 requested review from samdow and soulitzer, December 6, 2022 14:50
albanD (Collaborator) left a comment

Sounds pretty good!

def lower(self):
    prev_fwd_grad_mode = self.prev_fwd_grad_mode()
    if not self.prev_fwd_grad_mode:
        return contextlib.nested(_set_fwd_grad_enabled(False), super().lower())
Collaborator

Isn't contextlib.nested deprecated since Python 2.7?

Contributor Author

You're right, it's not in the Python 3 docs. I'll find an alternative.
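
One common replacement (a hedged sketch, not necessarily the fix that landed here) is to stack the context managers with contextlib.ExitStack, e.g. a small drop-in helper:

    import contextlib

    @contextlib.contextmanager
    def nested(*managers):
        # Stand-in for the removed contextlib.nested: enter every manager on an
        # ExitStack and yield the entered values as a tuple.
        with contextlib.ExitStack() as stack:
            yield tuple(stack.enter_context(m) for m in managers)

The lower() snippet above could then return nested(_set_fwd_grad_enabled(False), super().lower()) unchanged, or enter the two managers directly on an ExitStack.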

soulitzer (Contributor) left a comment

Cool!

facebook-github-bot commented

This pull request has been merged in 4809e83.


Labels

ciflow/trunk · Merged · release notes: torch.func
