jacrev: Support chunked computation #89376
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89376
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit ff0e9e1. This comment was automatically generated by Dr. CI and updates every 15 minutes.
CUDA Memory Summary from the different approaches (using torch.cuda.memory_summary()). NOTE: Have copied only the interesting part of the summary.
[Memory summaries for the Stacked Approach, Pre-allocation Approach, and Single Chunk runs omitted.]
Script:
import functorch
import torch
def prod(l):
prod = 1
for el in l:
prod *= el
return prod
def fn(x, y):
return x + y, x.sum(0)
shape = (144, 144)
chunk = 10
x = torch.zeros(*shape, dtype=torch.float, device='cuda')
y = x.sum()
chunk_size = (prod(shape) + prod(shape[1:])) // chunk
# Stack approach
# jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size)
# jacrev_fn_chunked(x, y)
# Pre-allocate and copy approach
# jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size, _preallocate_and_copy=True)
# jacrev_fn_chunked(x, y)
# Single chunk
jacrev_fn = functorch.jacrev(fn, (0, 1), chunk_size=None)
jacrev_fn(x, y)
print(torch.cuda.memory_summary())
torch/_functorch/eager_transforms.py (Outdated)
        auxiliary objects that will not be differentiated.
        Default: False.
    chunk_size (None or int): If specified, controls the maximum size of chunk for computing
        Jacobian. Default: None.
Need to tweak it.
Specify what happens for None.
If None (default), we will use the maximum chunk size (this is equivalent to doing a single vmap over vjp to compute the jacobian). If not None, then we will compute the jacobian chunk_size rows at a time using vmap to vectorize the computation. Note that chunk_size=1 is equivalent to computing the jacobian row-by-row with a for-loop. If you run into memory issues computing the jacobian, please try to specify a non-None chunk_size.
Something like that
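To make the suggested wording concrete, here is a small usage sketch (assuming the chunk_size keyword lands on functorch.jacrev exactly as proposed in this PR):

import torch
import functorch

def f(x):
    return x.sin()

x = torch.randn(1000)
# chunk_size=None (default): a single vmap over vjp computes all Jacobian rows at once.
jac_full = functorch.jacrev(f, chunk_size=None)(x)
# If that runs out of memory, compute the Jacobian 100 rows at a time instead.
jac_chunked = functorch.jacrev(f, chunk_size=100)(x)
torch.testing.assert_close(jac_full, jac_chunked)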
Just putting this out there for review on the API (chunk vs chunk_size), and the perf and memory benchmarks between the different approaches.
zou3519 left a comment
Code looks correct and clean. Let's discuss the API options with more folks (and I'll think about it as well).
def f(x, y):
    return (x.sin(), x + y), (x + 2, x.sum())

for chunk_size in [1, 2, 3, 4, 7, 10]:
nit: When we have figured out the API, we should test some extreme cases:
- check that chunk_size <= 0 raises an error
- try chunk_size = 100000 (some big number)
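A rough sketch of such edge-case tests, written as a stand-alone script rather than in the repo's TestCase style (the exact error type and message are assumptions):

import torch
import functorch

def f(x):
    return x.sin()

x = torch.randn(7)

# chunk_size <= 0 should raise an error.
for bad in (0, -1):
    raised = False
    try:
        functorch.jacrev(f, chunk_size=bad)(x)
    except Exception:
        raised = True
    assert raised, f"expected chunk_size={bad} to raise"

# A very large chunk_size should degenerate to the single-chunk (chunk_size=None) result.
expected = functorch.jacrev(f, chunk_size=None)(x)
actual = functorch.jacrev(f, chunk_size=100000)(x)
torch.testing.assert_close(actual, expected)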
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
cc @samdow @soulitzer @Chillee for API help.
We've got a couple of options here.
- Either we have a chunk_size argument, or we have a chunks argument (for the number of total chunks).
- _preallocate_and_copy is private, or we expose it publicly.
My opinion is:
- we should make preallocate_and_copy public. In the long run, memory-planning in PT2 should save us, but idk how soon that is coming and the preallocate_and_copy code is simple enough to maintain.
- I have a slight preference for chunks:
  - If jacrev(f)(x) OOMs, the user needs to try out a chunks/chunk_size argument. If the API is chunks, then it is clear what the next number the user should try is: 2, and they can keep incrementing it until they're satisfied. If the API is chunk_size, the user just tosses random numbers, or needs to compute the size of their Jacobian to figure out what the max chunk_size is, so they know what range of numbers to try.
Some nuggets of wisdom from Horace:
- chunk_size seems nicer, because (1) it is batch-agnostic and (2) it's like loop unrolling - you unroll some number of things instead of indicating how many times you want to unroll.
- _preallocate_and_copy should NOT be public. PT2 is around the corner and already includes this optimization (assuming the entire jacrev call can be captured); one of the goals of compilation is to remove code smells like this.
Also, we can always update chunk_size to have different behavior, if we really want.
- if the number is >= 1 then it is a chunk size
- if the number is between 0 and 1 then it is 1/chunks.
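A tiny hypothetical sketch of that dual-meaning normalization (the helper name is made up for illustration):

import math

def resolve_chunk_size(chunk, num_rows):
    # >= 1: interpret the value as a chunk size directly.
    if chunk >= 1:
        return int(chunk)
    # (0, 1): interpret the value as 1/num_chunks, then convert to a chunk size.
    if 0 < chunk < 1:
        num_chunks = round(1 / chunk)
        return math.ceil(num_rows / num_chunks)
    raise ValueError("chunk must be positive")

# e.g. for a 1000-row Jacobian:
# resolve_chunk_size(100, 1000) -> 100 rows per chunk (10 chunks)
# resolve_chunk_size(0.1, 1000) -> 100 rows per chunk (1/0.1 = 10 chunks)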
Is the repetition intentional?
Meant to have negative value for second. Thanks!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Ref: pytorch/functorch#680
We introduce a kwarg chunk_size in jacrev to control whether the Jacobian computation should be chunked; if so, chunk_size dictates the maximum size of the chunks used. We try two approaches: stacking the per-chunk results, and pre-allocating the full output and copying each chunk into it (_preallocate_and_copy).
For the memory benchmark, see #89376 (comment).
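For illustration, here is a simplified, hedged sketch of the two approaches for a function with a single tensor input and output (this is not the PR's actual implementation; reshaping the result to (*output.shape, *input.shape) is omitted for brevity):

import torch
from functorch import vmap, vjp

def jac_row_chunks(f, x, chunk_size):
    # Yield the Jacobian rows of f at x, chunk_size rows at a time.
    y, vjp_fn = vjp(f, x)
    n = y.numel()
    basis = torch.eye(n, dtype=y.dtype, device=y.device).view(n, *y.shape)
    for start in range(0, n, chunk_size):
        (rows,) = vmap(vjp_fn)(basis[start:start + chunk_size])
        yield start, rows  # rows has shape (this_chunk, *x.shape)

def jacrev_stacked(f, x, chunk_size):
    # Stacked approach: keep every chunk alive and concatenate at the end.
    return torch.cat([rows for _, rows in jac_row_chunks(f, x, chunk_size)], dim=0)

def jacrev_preallocated(f, x, chunk_size):
    # Pre-allocate-and-copy approach: write each chunk into a pre-allocated
    # output, so only one chunk of intermediates is alive at a time.
    n = f(x).numel()
    out = x.new_empty(n, *x.shape)
    for start, rows in jac_row_chunks(f, x, chunk_size):
        out[start:start + rows.shape[0]] = rows
    return out

Both variants produce the same values; the trade-off discussed in the review is peak memory (stacking keeps all chunks plus the concatenated copy alive) versus the extra bookkeeping of managing a pre-allocated buffer.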
Benchmark CPU: Performs better with more chunks / smaller chunk_size.
NOTE: There seems to be a lot of noise for shape (64, 64).
[CPU benchmark table collapsed in the original description]
Benchmark CUDA: Performs better with fewer chunks / bigger chunk_size.
[CUDA benchmark table collapsed in the original description]
Benchmark Script
[benchmark script collapsed in the original description]