Commit bed250e

[generate_vmap_rule] Add generate_vmap_rule to autograd.Function
Design document: https://docs.google.com/document/d/1bIQkWXy3J35_20c_a5kchikabBW5M8_uRAhl0BIMwU4/edit

This PR adds a `generate_vmap_rule` option (default False) to autograd.Function. By setting it to True, a user promises us that their autograd.Function's {forward, backward, jvp}, if defined, only use PyTorch operations, in addition to the other limitations of autograd.Function+functorch (such as the user not capturing any Tensors being transformed over from outside of the autograd.Function).

Concretely, the approach is:
- We update `custom_function_call` to accept an additional `generate_vmap_rule` argument.
- The vmap rule for `custom_function_call` with `generate_vmap_rule=True` is: we construct a vmapped version of the autograd.Function and dispatch on it.
- The vmapped version of the autograd.Function can be thought of as follows: if we have an autograd.Function Foo, then VmappedFoo.apply(in_dims, ...) has the same semantics as vmap(Foo.apply, in_dims...)
- VmappedFoo's forward, setup_context, and backward staticmethods are vmapped versions of Foo's staticmethods.
- See the design doc for more motivation and explanation.

Test Plan:
- This PR introduces additional autograd.Functions with the suffix "GenVmap" to autograd_function_db.
- There are also some minor UX tests.

Future:
- jvp support
- Likely more testing to come, but please let me know if you have cases that you want me to test here.

ghstack-source-id: 6905e60
Pull Request resolved: #90966
1 parent 53cb80d commit bed250e
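
To make the new option concrete, here is a minimal sketch (mine, not part of this commit's diff) of an autograd.Function that opts into generate_vmap_rule and is then vmapped over. The class name MyCube is invented for illustration, vmap is taken from functorch as in the tests below, the setup_context argument order mirrors those tests, and the sketch assumes the autograd.Function extension is enabled the same way the tests enable it with _set_autograd_function_extension_enabled.

import torch
from functorch import vmap

class MyCube(torch.autograd.Function):
    # Opting in: forward/setup_context/backward use only PyTorch operations,
    # so functorch is allowed to derive the vmap rule automatically.
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        dx = 3 * x ** 2
        return x ** 3, dx

    @staticmethod
    def setup_context(ctx, outputs, x):
        # Argument order mirrors the tests added in this commit.
        ctx.save_for_backward(x, outputs[1])

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        return grad_output * dx + grad_dx * 6 * x

x = torch.randn(4, 3)
# vmap(MyCube.apply) should behave like applying MyCube.apply per batch element.
out, _ = vmap(MyCube.apply)(x)
assert torch.allclose(out, x ** 3)

This has the same shape as the NumpyCube functions used in the tests, except that every operation stays inside PyTorch, which is exactly the promise generate_vmap_rule asks for.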

7 files changed (+497, -21 lines)

aten/src/ATen/functorch/Interpreter.cpp

Lines changed: 4 additions & 3 deletions
@@ -4,6 +4,7 @@
 #include <ATen/functorch/VmapInterpreter.h>
 #include <ATen/functorch/FunctionalizeInterpreter.h>
 #include <ATen/functorch/ADInterpreters.h>
+#include <ATen/functorch/DynamicLayer.h>
 
 namespace at { namespace functorch {
 
@@ -88,10 +89,10 @@ void sanityCheckStack(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   auto num_args = op.schema().arguments().size();
   foreachTensorInplace(*stack, stack->size() - num_args, stack->size(),
       [](const Tensor& tensor) {
-
-        auto* wrapper = maybeGetTensorWrapper(tensor);
+        auto result = unwrapIfDead(tensor);
+        auto* wrapper = maybeGetTensorWrapper(result);
         TORCH_INTERNAL_ASSERT(wrapper == nullptr);
-        auto* batched = maybeGetBatchedImpl(tensor);
+        auto* batched = maybeGetBatchedImpl(result);
         TORCH_INTERNAL_ASSERT(batched == nullptr);
         return tensor;
       });

test/functorch/test_eager_transforms.py

Lines changed: 28 additions & 1 deletion
@@ -1116,7 +1116,7 @@ def f(x):
 
 class TestAutogradFunctionVmapAPI(TestCase):
     @_set_autograd_function_extension_enabled()
-    def test_no_vmap_staticmethod(self, device):
+    def test_no_vmap_staticmethod_and_no_generate_vmap_rule(self, device):
         class NumpyCube(torch.autograd.Function):
             @staticmethod
             def forward(input):
@@ -1136,6 +1136,33 @@ def backward(ctx, grad_output, grad_saved):
         with self.assertRaisesRegex(RuntimeError, 'does not have a vmap rule defined'):
             vmap(NumpyCube.apply)(x)
 
+    @_set_autograd_function_extension_enabled()
+    def test_has_vmap_staticmethod_and_has_generate_vmap_rule(self, device):
+        class NumpyCube(torch.autograd.Function):
+            generate_vmap_rule = True
+
+            @staticmethod
+            def forward(input):
+                input_np = to_numpy(input)
+                dinput = torch.tensor(3 * input_np ** 2, device=input.device)
+                return torch.tensor(input_np ** 3, device=input.device), dinput
+
+            @staticmethod
+            def setup_context(ctx, outputs, input):
+                ctx.save_for_backward(input, outputs[1])
+
+            @staticmethod
+            def backward(ctx, grad_output, grad_saved):
+                raise RuntimeError("foobar")
+
+            @staticmethod
+            def vmap(infos, in_dims, x):
+                raise RuntimeError("foobar")
+
+        x = torch.randn(3, device=device)
+        with self.assertRaisesRegex(RuntimeError, 'generate_vmap_rule=True and a vmap staticmethod'):
+            vmap(NumpyCube.apply)(x)
+
     @_set_autograd_function_extension_enabled()
     def test_info_object(self, device):
         batch_size = 10
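
The new test above establishes that generate_vmap_rule=True and a hand-written vmap staticmethod are mutually exclusive ways of telling functorch how to batch a function. For contrast, here is a rough sketch (mine, not from this diff) of the manual route; the (info, in_dims, *args) signature comes from the test's vmap staticmethod, while the (batched_output, out_dims) return convention is an assumption, not something shown in this commit.

import torch
from functorch import vmap

class Double(torch.autograd.Function):
    # Manual route: no generate_vmap_rule; supply a vmap staticmethod instead.
    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, outputs, x):
        pass  # nothing needs to be saved for this toy function

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

    @staticmethod
    def vmap(info, in_dims, x):
        # Assumed convention: return the batched output and its out_dims.
        return x * 2, in_dims[0]

x = torch.randn(5, 3)
assert torch.allclose(vmap(Double.apply)(x), x * 2)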

test/functorch/test_ops.py

Lines changed: 9 additions & 3 deletions
@@ -1349,6 +1349,8 @@ def get_vjp(cotangents, *primals):
     xfail('index_reduce', ''), # NYI: forward-AD for index_reduce
     xfail('segment_reduce', 'lengths'), # NYI: forward-AD for segment_reduce
     xfail('native_dropout_backward'), # NYI
+    xfail('CubeGenVmapAutogradFunction'), # NYI
+    xfail('SortGenVmapAutogradFunction'), # https://github.com/pytorch/pytorch/issues/90067
 
 }))
 @opsToleranceOverride('TestOperators', 'test_jvpvjp', (
@@ -1517,6 +1519,9 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
     xfail("_native_batch_norm_legit"),
     xfail('native_dropout_backward'),
     xfail('nn.functional.prelu'),
+
+    xfail('CubeGenVmapAutogradFunction'), # NYI
+    xfail('SortGenVmapAutogradFunction'), # https://github.com/pytorch/pytorch/issues/90067
 }))
 @ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
 @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@@ -1962,9 +1967,9 @@ def test_vjpvmapvmap(self, device, dtype, op):
         args = [sample.input] + list(sample.args)
         kwargs = sample.kwargs
         generator = generate_vmap_inputs(args, kwargs, batch_size=B)
-        for batched_args, in_dims, kwargs in generator:
-            inner_vmapped_op = vmap(op, in_dims)
-            inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
+        for batched_args, inner_in_dims, kwargs in generator:
+            inner_vmapped_op = vmap(op, inner_in_dims)
+            inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B)
             generator = generate_vmap_inputs(batched_args, kwargs)
             for batched_args, in_dims, kwargs in generator:
                 # strategy: compare vjp(vmap(vmap(op)) vs vjp(map(map(op))
@@ -1982,6 +1987,7 @@ def test_vjpvmapvmap(self, device, dtype, op):
                 _, vjp_fn = vjp(mapped_fn, *primals)
                 expected_vjps = vjp_fn(cotangents)
 
+                print(inner_in_dims, in_dims)
                 _, vjp_fn = vjp(vmapped_fn, *primals)
                 result_vjps = vjp_fn(cotangents)
 
torch/_C/_functorch.pyi

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ def unwrap_if_dead(tensor: Tensor) -> Tensor: ...
 def _unwrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
 def _wrap_for_grad(tensor: Tensor, level: int) -> Tensor: ...
 def _unwrap_batched(tensor: Tensor, level: int) -> Tuple[Tensor, Optional[int]]: ...
+def current_level() -> int: ...
 
 def set_autograd_function_allowed(allowed: bool) -> None: ...
 def get_autograd_function_allowed() -> bool: ...
