
Commit 1b2ee4d

soulitzer authored and pytorchmergebot committed
Update functorch supported autograd.Function to allow mark_dirty (#91222)
Fixes #90225. Uses what was originally in #89860.

Pull Request resolved: #91222
Approved by: https://github.com/zou3519
1 parent ca39c5b commit 1b2ee4d
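
To ground the change, here is a minimal sketch (illustrative names, not code from this commit) of the kind of autograd.Function the commit now supports under functorch transforms: a new-style Function whose forward mutates an input in-place, declares this via ctx.mark_dirty inside setup_context(ctx, inputs, output), and is then differentiated through functorch's grad. Before this commit, the functorch path replaced ctx.mark_dirty with an NYI error (removed in the torch/_functorch/autograd_function.py diff below).

```python
import torch
from functorch import grad  # functorch transform; torch.func.grad in newer releases

class InplaceExp(torch.autograd.Function):
    # Illustrative only: mutates its input in-place, so it must call ctx.mark_dirty.
    @staticmethod
    def forward(x):
        x.exp_()              # in-place: the returned tensor is the same object as the input
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        # New-style signature used throughout this commit: (ctx, inputs, output).
        x, = inputs
        ctx.mark_dirty(x)     # required because x was modified in-place
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result   # d/dx exp(x) = exp(x)

# With this commit, differentiating the in-place Function under a functorch
# transform is expected to work (clone() keeps the original leaf unmodified):
x = torch.randn(3)
gx = grad(lambda t: InplaceExp.apply(t.clone()).sum())(x)
```
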

File tree

7 files changed: +102, -80 lines


test/functorch/test_eager_transforms.py

Lines changed: 11 additions & 11 deletions
@@ -989,7 +989,7 @@ def forward(x, y):
         return x, y
 
     @staticmethod
-    def setup_context(ctx, inputs, outputs):
+    def setup_context(ctx, inputs, output):
         ctx.set_materialize_grads(False)
 
     @staticmethod
@@ -1016,7 +1016,7 @@ def forward(x, y):
         return x * y
 
     @staticmethod
-    def setup_context(ctx, inputs, outputs):
+    def setup_context(ctx, inputs, output):
         return
 
     @staticmethod
@@ -1039,8 +1039,8 @@ def forward(input):
         return torch.tensor(input_np ** 3, device=input.device), input_np
 
     @staticmethod
-    def setup_context(ctx, inputs, outputs):
-        ctx.input_np = outputs[1]
+    def setup_context(ctx, inputs, output):
+        ctx.input_np = output[1]
         ctx.device = inputs[0].device
 
     @staticmethod
@@ -1097,7 +1097,7 @@ def forward(x):
         return x.clone()
 
     @staticmethod
-    def setup_context(ctx, inputs, outputs):
+    def setup_context(ctx, inputs, output):
         return
 
     @staticmethod
@@ -1125,8 +1125,8 @@ def forward(input):
         return torch.tensor(input_np ** 3, device=input.device), dinput
 
     @staticmethod
-    def setup_context(ctx, outputs, input):
-        ctx.save_for_backward(input, outputs[1])
+    def setup_context(ctx, inputs, output):
+        ctx.save_for_backward(inputs, output[1])
 
     @staticmethod
     def backward(ctx, grad_output, grad_saved):
@@ -1173,7 +1173,7 @@ def forward(input):
         pass
 
     @staticmethod
-    def setup_context(ctx, outputs, input):
+    def setup_context(ctx, inputs, output):
         pass
 
     @staticmethod
@@ -1199,7 +1199,7 @@ def forward(input):
         pass
 
     @staticmethod
-    def setup_context(ctx, outputs, input):
+    def setup_context(ctx, inputs, output):
         pass
 
     @staticmethod
@@ -1224,7 +1224,7 @@ def forward(input):
         pass
 
     @staticmethod
-    def setup_context(ctx, outputs, x, y):
+    def setup_context(ctx, inputs, output):
         pass
 
     @staticmethod
@@ -1249,7 +1249,7 @@ def forward(input):
         return input
 
     @staticmethod
-    def setup_context(ctx, outputs, input):
+    def setup_context(ctx, inputs, output):
         pass
 
     @staticmethod

test/functorch/test_ops.py

Lines changed: 3 additions & 11 deletions
@@ -319,7 +319,6 @@ def is_inplace(op, variant):
 
 vjp_fail = {
     xfail('tensor_split'),  # data_ptr composite compliance
-    xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
 }
 
 aliasing_ops = {
@@ -462,7 +461,7 @@ def wrapped_fn(*args, **kwargs):
     xfail('nn.functional._scaled_dot_product_attention', device_type='cuda'),
 
     xfail('nn.functional.rrelu'),  # in-place test errors out with no formula implemented
-    xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
+    xfail('NumpyExpMarkDirtyAutogradFunction'),  # TODO: https://github.com/pytorch/pytorch/issues/91280
 
     # --- Non-Contiguous Failures! ---
     # This is expected to fail as the operator
@@ -966,6 +965,7 @@ def test_vmapvjp(self, device, dtype, op):
     # skip because this is flaky depending on what the max_norm is!
     skip('nn.functional.embedding', ''),
     skip('to'),  # RuntimeError: required rank 4 tensor to use channels_last format
+    xfail('NumpyExpMarkDirtyAutogradFunction'),  # vmap: inplace into a regular tensor
     # ----------------------------------------------------------------------
 
     # ---------------------------- BUGS ------------------------------------
@@ -1003,7 +1003,6 @@ def test_vmapvjp(self, device, dtype, op):
     xfail("_native_batch_norm_legit"),
 
     xfail('nn.functional.prelu'),
-    xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     # ----------------------------------------------------------------------
 }
 
@@ -1475,6 +1474,7 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
 
     # Not actually a problem
     xfail('NumpyCubeNotComposableAutogradFunction'),  # not composable
+    xfail('NumpyExpMarkDirtyAutogradFunction'),  # vmap: inplace into a regular tensor
 
     # Potential bugs/errors
     xfail('as_strided'),  # AssertionError: Tensor-likes are not close!
@@ -1948,7 +1948,6 @@ def f(x):
     @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
     @skipOps('TestOperators', 'test_vmapvjpvmap', {
         xfail('NumpyCubeNotComposableAutogradFunction'),  # Not composable
-        xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     })
     def test_vmapvjpvmap(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -1993,7 +1992,6 @@ def inner(primals, cotangents):
     @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
     @skipOps('TestOperators', 'test_vjpvmapvmap', {
         xfail('NumpyCubeNotComposableAutogradFunction'),  # Not composable
-        xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     })
     def test_vjpvmapvmap(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -2032,7 +2030,6 @@ def test_vjpvmapvmap(self, device, dtype, op):
     @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
     @skipOps('TestOperators', 'test_vjpvjpvmap', {
         xfail('NumpyCubeNotComposableAutogradFunction'),  # Not composable
-        xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     })
     def test_vjpvjpvmap(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -2063,7 +2060,6 @@ def test_vjpvjpvmap(self, device, dtype, op):
     @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
     @skipOps('TestOperators', 'test_jvpvmap', {
         xfail('NumpyCubeNotComposableAutogradFunction'),  # Not composable
-        xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     })
     def test_jvpvmap(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -2092,7 +2088,6 @@ def test_jvpvmap(self, device, dtype, op):
     @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
     @skipOps('TestOperators', 'test_jvpvmapvmap', {
         xfail('NumpyCubeNotComposableAutogradFunction'),  # Not composable
-        xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     })
     def test_jvpvmapvmap(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -2127,7 +2122,6 @@ def test_jvpvmapvmap(self, device, dtype, op):
     @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
     @skipOps('TestOperators', 'test_vmapjvpvmap', {
         xfail('NumpyCubeNotComposableAutogradFunction'),  # Not composable
-        xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     })
     def test_vmapjvpvmap(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -2163,7 +2157,6 @@ def test_vmapjvpvmap(self, device, dtype, op):
     @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
     @skipOps('TestOperators', 'test_jvpjvpvmap', {
         xfail('NumpyCubeNotComposableAutogradFunction'),  # Not composable
-        xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     })
     def test_jvpjvpvmap(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -2193,7 +2186,6 @@ def test_jvpjvpvmap(self, device, dtype, op):
     @ops(autograd_function_db, allowed_dtypes=(torch.float32,))
     @skipOps('TestOperators', 'test_jvpvjpvmap', {
         xfail('NumpyCubeNotComposableAutogradFunction'),  # Not composable
-        xfail('NumpyExpMarkDirtyAutogradFunction'),  # https://github.com/pytorch/pytorch/issues/90225
     })
     def test_jvpvjpvmap(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
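
The new xfail entries above are annotated "vmap: inplace into a regular tensor". As a hedged illustration of that vmap limitation (example names and shapes are mine, not from this PR): vmap refuses to perform an in-place write of a per-sample, batched value into an ordinary unbatched tensor, which is the situation the mark-dirty Function runs into under these vmap compositions.

```python
import torch
from functorch import vmap

buf = torch.zeros(3)           # ordinary tensor shared across the whole batch

def f(xi):
    buf.copy_(xi)              # in-place write of a batched value into an unbatched tensor
    return buf * 2

x = torch.randn(5, 3)
# vmap(f)(x)   # expected to raise a RuntimeError: vmap cannot perform this in-place copy_
```
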

test/test_autograd.py

Lines changed: 4 additions & 4 deletions
@@ -552,7 +552,7 @@ def forward(x):
         return x ** 2
 
     @staticmethod
-    def setup_context(ctx, inputs, outputs):
+    def setup_context(ctx, inputs, output):
         x, = inputs
         ctx.save_for_backward(x)
 
@@ -576,9 +576,9 @@ def forward(x):
         return x ** 2, two_x
 
     @staticmethod
-    def setup_context(ctx, inputs, outputs):
+    def setup_context(ctx, inputs, output):
         x, = inputs
-        _, two_x = outputs
+        _, two_x = output
         ctx.two_x = two_x
 
     @staticmethod
@@ -599,7 +599,7 @@ def forward(x, shape, scale_forward, scale_backward):
         return x.reshape(shape) * scale_forward
 
     @staticmethod
-    def setup_context(ctx, inputs, outputs):
+    def setup_context(ctx, inputs, output):
         x, shape, scale_forward, scale_backward = inputs
         ctx.scale_backward = scale_backward
         ctx.x_shape = x.shape

torch/_C/_functorch.pyi

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
 def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
 def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
 def current_level() -> int: ...
+def _add_batch_dim(tensor: Tensor, bdim: int, level: int) -> Tensor: ...
 
 def set_autograd_function_allowed(allowed: bool) -> None: ...
 def get_autograd_function_allowed() -> bool: ...

torch/_functorch/autograd_function.py

Lines changed: 58 additions & 29 deletions
@@ -13,7 +13,9 @@
     unwrap_batched,
     vmap,
     restore_vmap,
+    _add_batch_dim,
 )
+from torch._functorch.vmap import _broadcast_to_and_flatten
 from torch.autograd.forward_ad import _set_fwd_grad_enabled
 from typing import Any, NamedTuple, Tuple
 
@@ -101,16 +103,20 @@ def forward(*operands):
         # the transform. _SingleLevelFunction will turn off both fwd and bwd
         # gradient computation and we need to turn it back on here.
         with torch.enable_grad(), _set_fwd_grad_enabled(True), interpreter.lower():
-            output = custom_function_call(autograd_function, *unwrapped_operands)
+            unwrapped_output = custom_function_call(autograd_function, *unwrapped_operands)
 
-        return pytree.tree_map_only(
-            torch.Tensor,
-            lambda x: _wrap_for_grad(x, level),
-            output)
+        # See NOTE [mark_dirty object identity check]
+        def wrap_fn(output):
+            return _wrap_for_grad(output, level)
+
+        return wrap_outputs_maintaining_identity(
+            unwrapped_output,
+            unwrapped_operands,
+            operands,
+            wrap_fn)
 
-    def setup_context(ctx, outputs, *operands):
-        ctx.mark_dirty = mark_dirty_error
-        return autograd_function.setup_context(ctx, outputs, *operands)
+    def setup_context(ctx, inputs, output):
+        return autograd_function.setup_context(ctx, inputs, output)
 
     # backward is only used if the transform is TransformType.Grad
     def backward(ctx, *grads):
@@ -139,24 +145,39 @@ def jvp(ctx, *tangents):
     )
     return Generated
 
+# NOTE [mark_dirty object identity check]
+# autograd.Function's ctx.mark_dirty expect a returned input
+# to have the same object identity as the input.
+# Mode-only functorch will greatly simplify this logic.
+def wrap_outputs_maintaining_identity(outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=None):
+    flat_unwrapped_inputs, _ = pytree.tree_flatten(unwrapped_inputs)
+    flat_orig_inputs, _ = pytree.tree_flatten(orig_inputs)
+
+    unwrapped_input_to_orig_input = {
+        id(unwrapped): orig
+        for unwrapped, orig in zip(flat_unwrapped_inputs, flat_orig_inputs)
+    }
+
+    flat_outputs, spec = pytree.tree_flatten(outputs)
+    result = []
+
+    if out_dims is not None:
+        flat_out_dims = _broadcast_to_and_flatten(out_dims, spec)
+
+    for i, output in enumerate(flat_outputs):
+        if not isinstance(output, torch.Tensor):
+            result.append(output)
+            continue
+        if id(output) in unwrapped_input_to_orig_input:
+            result.append(unwrapped_input_to_orig_input[id(output)])
+            continue
+        if out_dims is not None:
+            assert flat_out_dims is not None
+            result.append(wrap_fn(output, flat_out_dims[i]))
+        else:
+            result.append(wrap_fn(output))
 
-# https://github.com/pytorch/pytorch/issues/90225
-# If an input was marked as dirty, and the autograd.Function returns the input
-# from the forward, then the grad rule for custom_function_call must also
-# return the corresponding input from the forward() of the Generated autograd.Function
-#
-# We haven't figured out how to do this yet. One possibility is to rely
-# on if the return from the redispatched custom_function_call in Generated.forward
-# has the same object id as one of the inputs,
-# but https://github.com/pytorch/pytorch/issues/90209 means we cannot rely on
-# that property.
-def mark_dirty_error(*args, **kwargs):
-    raise RuntimeError(
-        'NYI: we do not yet support ctx.mark_dirty with functorch transforms. '
-        'Please try to avoid modifying inputs to the autograd.Function in-place '
-        'by using out-of-place operations or by cloning the inputs. '
-        'Please see https://github.com/pytorch/pytorch/issues/90209 for more details'
-    )
+    return pytree.tree_unflatten(result, spec)
 
 
 # NOTE: [functorch vjp and autograd interaction]
@@ -172,8 +193,8 @@ def mark_dirty_error(*args, **kwargs):
 #     return x.exp()
 #
 #   @staticmethod
-#   def setup_context(ctx, outputs, x):
-#     y = outputs
+#   def setup_context(ctx, inputs, output):
+#     y = output
 #     ctx.save_for_backward(y)
 #
 #   @staticmethod
@@ -244,12 +265,20 @@ def custom_function_call_vmap(interpreter, autograd_function, *operands):
     with interpreter.lower():
         unwrapped_output, out_dims = autograd_function.vmap(info, in_dims, *unwrapped_operands)
 
+    # See NOTE [mark_dirty object identity check]
+    def wrap_fn(output, out_dim):
+        return output if out_dim is None else _add_batch_dim(output, out_dim, current_level)
+
     # TODO: raise better error message to the user when they don't follow the API.
     # Should probably mimic the logic of _process_batched_inputs,
    # but that one is hyperspecialized on error messages.
     # https://github.com/pytorch/pytorch/issues/90224
-    output = wrap_batched(unwrapped_output, out_dims, current_level)
-    return output
+    return wrap_outputs_maintaining_identity(
+        unwrapped_output,
+        unwrapped_operands,
+        operands,
+        wrap_fn,
+        out_dims=out_dims)
 
 
 def custom_function_call_vmap_generate_rule(interpreter, autograd_function, *operands):
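
To make the identity-preserving behavior of wrap_outputs_maintaining_identity concrete, here is a minimal sketch that exercises the helper directly (internal API; toy tensors stand in for the wrapped/unwrapped values functorch would actually pass, and it assumes a PyTorch build that contains this change). Outputs that are literally one of the unwrapped inputs are mapped back to the corresponding original input instead of being re-wrapped, which is what lets ctx.mark_dirty's object-identity requirement survive the unwrap/re-wrap round trip.

```python
import torch
from torch._functorch.autograd_function import wrap_outputs_maintaining_identity

orig_a = torch.randn(3)        # stands in for the original (wrapped) input
unwrapped_a = orig_a.detach()  # stands in for its unwrapped value

def wrap_fn(t):
    return t * 1.0             # toy "re-wrap": just produces a fresh tensor

# First output aliases the unwrapped input, second is a genuinely new tensor.
outputs = (unwrapped_a, unwrapped_a + 1)
wrapped = wrap_outputs_maintaining_identity(outputs, (unwrapped_a,), (orig_a,), wrap_fn)

assert wrapped[0] is orig_a        # identity preserved: mapped back to the original input
assert wrapped[1] is not orig_a    # fresh output went through wrap_fn
```
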

0 commit comments
