Update functorch supported autograd.Function to allow mark_dirty #91222
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91222
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 15e8773: This comment was automatically generated by Dr. CI and updates every 15 minutes.
test/functorch/test_ops.py (Outdated)
xfail('nn.functional.rrelu'),  # in-place test errors out with no formula implemented
xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
xfail('NumpyExpMarkDirtyAutogradFunction'),  # TODO: calling in-place operation that would mutate a captured Tensor
This errors for a different reason now, need to investigate.
@zou3519 I think I figured out what the issue was with this, but I'm not sure what the solution is.
In C++ functorch:
- we create a dual tensor whose tangent is captured (so it has an immutable wrapper)
- when we call into exp_, checkForInvalidMutationOnCaptures does not care about the tangent having an immutable wrapper, because that is hidden by the dual tensor's wrapper, which isn't immutable
- before we call into VariableType, we first exclude dynamicLayerFront
- the forward-AD formula for in-place ops does tangent.copy_(new_tangent) (sketched below, just before the repro). Because we already excluded dynamicLayerFront, we just go into VariableType again (which is basically a no-op, since neither tensor has a tangent this time around)
In pyfunctorch:
- we still create that immutable wrapper for the tangent
- when we call process, we do not exclude dynamicLayerFront
- process constructs the single-layer autograd Function and calls apply (which calls into forward, then jvp)
- after forward is done (with no problems), jvp is performed, which does tangent.mul_(output). At this point, JvpTransform is still at the top of the stack.
- since we did not exclude this time, we go into dynamicLayerFront, which now errors in checkForInvalidMutationOnCaptures, because we're performing an in-place op on the tangent, which has the immutable wrapper.
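Here is a minimal plain-PyTorch sketch (no functorch) of the in-place forward-AD tangent update both flows run into; the comments reflect my reading of the formulas rather than the exact internals:

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.rand(3)
tangent = torch.rand(3)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    dual.exp_()  # in-place op on a dual tensor
    _, new_tangent = fwAD.unpack_dual(dual)
    # The forward-AD rule for exp_ updates the stored tangent in place (roughly
    # tangent.copy_(tangent * result)). Outside functorch that is unremarkable;
    # under a functorch transform the tangent can carry an immutable wrapper,
    # which is what checkForInvalidMutationOnCaptures complains about in pyfunctorch.
```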
Click for repro:

```python
from functorch import vmap, jvp
import torch
import numpy as np

def to_numpy(tensor):
    return tensor.cpu().numpy()

class NumpyMul(torch.autograd.Function):
    @staticmethod
    def forward(x, y):
        return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        ctx.save_for_backward(*inputs)
        ctx.save_for_forward(*inputs)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        gx = None
        if ctx.needs_input_grad[0]:
            gx = NumpyMul.apply(grad_output, y)
        gy = None
        if ctx.needs_input_grad[1]:
            gy = NumpyMul.apply(grad_output, x)
        return gx, gy

    @staticmethod
    def vmap(info, in_dims, x, y):
        x_bdim, y_bdim = in_dims
        x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        result = NumpyMul.apply(x, y)
        result = result.movedim(-1, 0)
        return result, 0

    @staticmethod
    def jvp(ctx, x_tangent, y_tangent):
        x, y = ctx.saved_tensors
        return x_tangent * y + y_tangent * x

class NumpyExp_(torch.autograd.Function):
    @staticmethod
    def forward(x):
        x_np = to_numpy(x)
        np.exp(x_np, x_np)
        return x

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        x, = inputs
        ctx.mark_dirty(x)
        ctx.save_for_backward(outputs)
        ctx.save_for_forward(outputs)

    @staticmethod
    def backward(ctx, grad_output):
        output, = ctx.saved_tensors
        return NumpyMul.apply(grad_output, output)

    @staticmethod
    def vmap(info, in_dims, x):
        NumpyExp_.apply(x)
        return x, in_dims[0]

    @staticmethod
    def jvp(ctx, x_tangent):
        output, = ctx.saved_tensors
        x_tangent.mul_(output)
        return x_tangent

def fn(x):
    # return torch.exp_(x) <-- does not error
    return NumpyExp_.apply(x)

a = torch.rand(4,)
b = torch.rand(4,)

with torch.autograd.function._set_autograd_function_extension_enabled(True):
    jvp(fn, (a,), (b,))
```
"jvp is performed, which does tangent.mul_(output)"
Where is the mul_ in the code?
In the jvp that NumpyExp_ defines:
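For reference, the relevant method from the repro above:

```python
@staticmethod
def jvp(ctx, x_tangent):
    output, = ctx.saved_tensors
    x_tangent.mul_(output)  # <- the in-place mul_ on the tangent
    return x_tangent
```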
Maybe one solution could be:
- We could just say that it is okay for us to go through process again. It is basically a no-op, since the stack is the same. Technically it does more checks, but maybe that is fine and we actually want those checks?
- Currently, in the creation of a dual tensor, the primal and tangent are not explicitly wrapped; instead we rely on them getting automatically lifted. If we manually wrap the tangent (and primal) instead, this error should no longer trigger even if we go through process an extra time. Since the tangent is something the user passed in themselves, we should be okay with mutating it, and should not mark it with the immutable wrapper.
Alternate solution (doesn't work):
- I also tried excluding manually in PyFunctorch's process to mimic the C++ version, but ran into an issue with unwrapped_count > 0 INTERNAL ASSERT FAILED in the dead tensor wrapper fallback, and I'm not sure what that means yet. (What does this mean?)
I'm still processing what is going on, but let me reply to your questions:
"I also tried excluding manually in PyFunctorch's process to mimic the C++ version but ran into an issue with unwrapped_count > 0 INTERNAL ASSERT FAILED in the dead tensor wrapper fallback and not sure what that means yet. (What does this mean?)"
There's an invariant that a Tensor with the FuncTorchTensorWrapper dispatch key must be a TensorWrapper. Given that we hit the dead_tensor_fallback, at least one of the inputs must be a TensorWrapper. The assertion is complaining that none of the inputs are TensorWrappers.
The thing I am struggling with a bit right now is: does the in-place mutation check even make sense for forward-mode AD?
- If it does, then it sounds like C++ functorch is wrong, because it bypasses it.
- If it doesn't, then to what extent can we just get rid of it from both C++ and Python functorch?
Claim: the input-mutation check makes sense for forward-mode AD. We want to prevent a situation where the dual tensor is created on the wrong TensorWrapper.
There are two cases here:
Case 1: captured value mutated in-place. If we have:

```python
y = torch.tensor(1.)

def f(x):
    y.copy_(x)
    return x + y

jvp(f, (x,), (t,))
```

Then the dual should be created on the wrapped version of y, not y itself. The in-place error checks should ideally raise an error in this situation.
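For concreteness, a runnable completion of that snippet (the values of x and t below are mine, just example scalars):

```python
import torch
from functorch import jvp

y = torch.tensor(1.)

def f(x):
    y.copy_(x)   # mutates a tensor captured from outside the transformed function
    return x + y

x = torch.tensor(2.)
t = torch.tensor(1.)
# Per the claim above, the in-place checks should ideally raise here, since the
# dual should be created on the wrapped version of y rather than on y itself.
jvp(f, (x,), (t,))
```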
Case 2: tangent tensor mutated in-place (which is what is happening in this PR).

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.tensor(2.)
y = torch.tensor(3.)

with fwAD.dual_level():
    x_dual = fwAD.make_dual(x, y)
    y.copy_(x_dual)
    x, x_tangent = fwAD.unpack_dual(x_dual)
```

If we ran the functorch.jvp equivalent of the above, it's important that the tangent of x is a TensorWrapper, because it ends up getting its own tangent value.
Solution?
Given the above, I like one of the solutions you proposed above, which is:
Currently in the creation of a dual tensor, the primal and tangent are not explicitly wrapped, instead we rely on them to get automatically lifted. If we manually wrap tangent (and primal) instead, this error should no longer trigger even if we go through process an extra time. Since tangent is something the user passed in themselves, we should be okay with mutating it, and not mark it with the immutable wrapper.
functorch.jvp should wrap the primal and the tangent before calling make_dual. The end state is that we get TensorWrapper(primal) that has a tangent which is TensorWrapper(tangent).
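In rough pseudocode, that end state might look like this (wrap_at_level is a hypothetical stand-in for functorch's internal wrapping step, not a real API):

```python
import torch
import torch.autograd.forward_ad as fwAD

def wrap_at_level(t, level):
    # Placeholder only: real functorch would return a mutable TensorWrapper
    # for the given interpreter level here.
    return t

def make_jvp_duals(primals, tangents, level):
    # Must be called inside fwAD.dual_level(). Both primal and tangent are wrapped
    # explicitly, so the tangent's wrapper is mutable and in-place jvp formulas
    # (e.g. tangent.mul_(...)) no longer trip the captured-tensor mutation check.
    return tuple(
        fwAD.make_dual(wrap_at_level(p, level), wrap_at_level(t, level))
        for p, t in zip(primals, tangents)
    )
```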
Thoughts? Also, thank you for the detailed analysis; it saved me from stepping through the code in gdb.
Case 2 is actually a bug in PyTorch forward mode AD. Even if we make primal and tangent both have TensorWrapper at the same level, the tangent's TensorWrapper should itself never have a tangent. Normally we'd error if we're setting a tangent that itself has a tangent, but we're getting around that check with an in-place lol.
I think that, morally, the tangent should not be wrapped at the same level as the primal (I see the tangent as metadata that lives on the primal's wrapper, so in a sense it should be subordinate to the primal). The tangent is being wrapped today because we are computing with it while JVP is active; in theory we are only computing with plain tensors at that point, so (if the forward/backward AD kernels were separate) we should be able to exclude the Autograd key and properly unwrap and pop JVP off the stack before computing forward grads.
That being said, I still think that it is a good idea to manually wrap tangent at the same level as primal today to indicate that it is a tensor explicitly passed in so that its AD metadata isn't immutable.
zou3519 left a comment:
LGTM, some minor comments. I assume we are punting the handling of the TODO to the future (but feel free to dig into it more if you're interested)
# skip because this is flaky depending on what the max_norm is!
skip('nn.functional.embedding', ''),
skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
xfail('NumpyExpMarkDirtyAutogradFunction'),  # vmap: inplace into a regular tensor
Just to check, this is not "calling in-place operation that would mutate a captured Tensor", right?
Yup, leaving this for a follow-up for now.
@pytorchbot merge -g
Merge failed. Reason: Not merging any PRs at the moment because there is a merge blocking https://github.com/pytorch/pytorch/labels/ci:%20sev issue open at:
Details for Dev Infra team: Raised by workflow job
# def setup_context(ctx, outputs, x):
#     y = outputs
# def setup_context(ctx, inputs, output):
#     y = output
Why the rename from outputs -> output? Is it a single output now? Or are they unpacked?
It has always been a single output; I'm just updating the name to reflect that. Since we return what the user returned from forward as-is, it can sometimes be a tuple, depending on what the user returns.
In an earlier version of this PR I made it always pass in a tuple for consistency, but after discussion here (#91222 (comment)), I decided to revert that change.
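For illustration, a small sketch in the same style as the repro above: output is exactly what forward returned, so it is a tuple only when forward returns a tuple (the SquareAndCube example is mine, not from the PR):

```python
import torch

class SquareAndCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # forward returns a tuple, so setup_context receives that tuple as `output`
        return x ** 2, x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        sq, cube = output  # exactly what forward returned
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_sq, grad_cube):
        x, = ctx.saved_tensors
        return 2 * x * grad_sq + 3 * x ** 2 * grad_cube
```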
# @staticmethod
# def setup_context(ctx, outputs, x):
#     y = outputs
# def setup_context(ctx, inputs, output):
Did you actually swap the order? That wasn't reflected in the tests.
This is just an outdated comment; this is now the correct order.
(see the python bindings)
@pytorchbot rebase
@pytorchbot successfully started a rebase job. Check the current status here.
Rebase failed due to Command
Raised by https://github.com/pytorch/pytorch/actions/runs/3790618497
@pytorchbot merge -g
Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Fixes #90225
Uses what was originally in #89860