
Conversation

@kshitij12345 (Collaborator) commented Nov 20, 2022

Ref: pytorch/functorch#680

We introduce a chunk_size kwarg in jacrev to control whether the Jacobian computation is chunked; if it is, chunk_size dictates the maximum size of each chunk.

We try two approaches (a rough sketch of both is given below):

  • Stacked Approach: append each chunk's intermediate result to a list and stack the results at the end.
  • Pre-allocation Approach: pre-allocate a zeros tensor and copy each chunk's result into it.
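
A minimal, illustrative sketch of the two strategies (not the code added in this PR): the basis of the output space is split into chunks and each chunk is pushed through vmap-of-vjp. f, x and chunk_size below are placeholder names.

import torch
from functorch import vjp, vmap

def f(x):
    return x.sin()

x = torch.randn(8)
out, vjp_fn = vjp(f, x)

# One one-hot cotangent per output element, i.e. the rows of the Jacobian.
basis = torch.eye(out.numel(), dtype=out.dtype).view(out.numel(), *out.shape)
chunk_size = 3

# Stacked approach: collect each chunk's rows in a list, concatenate at the end.
rows = [vmap(vjp_fn)(b)[0] for b in basis.split(chunk_size)]
jac_stacked = torch.cat(rows).view(*out.shape, *x.shape)

# Pre-allocation approach: write each chunk's rows into a zeros tensor in place.
jac_prealloc = torch.zeros(out.numel(), *x.shape, dtype=x.dtype)
for i, b in enumerate(basis.split(chunk_size)):
    start = i * chunk_size
    jac_prealloc[start:start + b.shape[0]].copy_(vmap(vjp_fn)(b)[0])
jac_prealloc = jac_prealloc.view(*out.shape, *x.shape)

assert torch.allclose(jac_stacked, jac_prealloc)
assert torch.allclose(jac_stacked, torch.diag(x.cos()))  # Jacobian of sin is diag(cos(x))

With the kwarg introduced here, the same computation is just functorch.jacrev(f, chunk_size=3)(x); chunk_size=None keeps the unchunked, single-vmap behavior.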

For the memory benchmark, see #89376 (comment).

Benchmark CPU: performs better with more chunks / a smaller chunk_size.

NOTE: There seems to be a lot of noise for shape (64, 64).

Details
[----------------------------------------------- jacrev : device cpu : chunks 2 -----------------------------------------------]
                                     |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ---------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080     |               76.2            |          50.9        |                  80.1             
      (128, 128) : chunk_size 8256   |             1172.8            |         783.3        |                1225.5             
      (128, 144) : chunk_size 9288   |             1475.1            |         990.4        |                1548.3             
      (144, 144) : chunk_size 10440  |             1871.3            |        1254.4        |                1971.2             

Times are in milliseconds (ms).

[----------------------------------------------- jacrev : device cpu : chunks 3 ----------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386    |               39.9            |          25.8        |                  58.8             
      (128, 128) : chunk_size 5504  |             1182.6            |         782.2        |                1229.7             
      (128, 144) : chunk_size 6192  |             1483.6            |         995.4        |                1550.6             
      (144, 144) : chunk_size 6960  |             1879.1            |        1257.7        |                1960.5             

Times are in milliseconds (ms).

[----------------------------------------------- jacrev : device cpu : chunks 4 ----------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040    |               41.7            |          50.6        |                  29.1             
      (128, 128) : chunk_size 4128  |             1171.6            |         782.3        |                1226.7             
      (128, 144) : chunk_size 4644  |             1482.2            |         994.6        |                1550.9             
      (144, 144) : chunk_size 5220  |             1870.2            |        1254.5        |                1961.4             

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 100 ---------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41     |               46.8            |          50.5        |                  46.4             
      (128, 128) : chunk_size 165  |              622.2            |         775.2        |                 656.0             
      (128, 144) : chunk_size 185  |              803.9            |         987.3        |                 866.9             
      (144, 144) : chunk_size 208  |             1021.1            |        1251.2        |                1088.2             

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 200 ---------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 20     |               60.9            |          50.2        |                  62.3             
      (128, 128) : chunk_size 82   |              583.1            |         779.4        |                 634.3             
      (128, 144) : chunk_size 92   |              834.1            |        1005.8        |                 472.3             
      (144, 144) : chunk_size 104  |             1053.6            |        1277.0        |                1033.9             

Times are in milliseconds (ms).

[--------------------------------------------- jacrev : device cpu : chunks 300 --------------------------------------------]
                                  |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 13    |              77.7             |          50.4        |                  79.6             
      (128, 128) : chunk_size 55  |             578.9             |         782.3        |                 626.9             
      (128, 144) : chunk_size 61  |             718.2             |        1024.9        |                 800.4             
      (144, 144) : chunk_size 69  |             919.7             |        1313.7        |                1023.0             

Times are in milliseconds (ms).

Benchmark CUDA: performs better with fewer chunks / a larger chunk_size.

Details
[--------------------------------------------- jacrev : device cuda:1 : chunks 2 ----------------------------------------------]
                                     |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: ---------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 2080     |             1485.7            |         923.8        |                1632.3             
      (128, 128) : chunk_size 8256   |            25390.2            |       14103.2        |               33557.4             
      (128, 144) : chunk_size 9288   |              801.7            |       16854.1        |               42894.6             
      (144, 144) : chunk_size 10440  |             1003.5            |       21386.5        |               59648.5             

Times are in microseconds (us).

[--------------------------------------------- jacrev : device cuda:1 : chunks 3 ---------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1386    |             1474.5            |         924.5        |                1655.5             
      (128, 128) : chunk_size 5504  |            25368.9            |       10156.0        |               34022.1             
      (128, 144) : chunk_size 6192  |            25223.0            |       12933.7        |               56418.5             
      (144, 144) : chunk_size 6960  |            24729.3            |       16367.4        |               68744.7             

Times are in microseconds (us).

[--------------------------------------------- jacrev : device cuda:1 : chunks 4 ---------------------------------------------]
                                    |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: --------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 1040    |             1489.2            |         924.4        |                 1679.6            
      (128, 128) : chunk_size 4128  |            25370.4            |        8987.4        |                57201.3            
      (128, 144) : chunk_size 4644  |            32239.1            |       10136.2        |                72406.5            
      (144, 144) : chunk_size 5220  |            40994.3            |       12867.8        |               108653.4            

Times are in microseconds (us).

[------------------------------------------- jacrev : device cuda:1 : chunks 100 --------------------------------------------]
                                   |  with chunk_size and stacked  |  without chunk_size  |  with chunk_size and pre-allocated
1 threads: -------------------------------------------------------------------------------------------------------------------
      (64, 64) : chunk_size 41     |            21121.8            |         924.2        |               22753.5             
      (128, 128) : chunk_size 165  |            23679.7            |       14284.4        |               26758.2             
      (128, 144) : chunk_size 185  |            30082.3            |       18063.3        |               33553.5             
      (144, 144) : chunk_size 208  |            38175.6            |       22839.5        |               42030.0             

Times are in microseconds (us).

Benchmark Script

Details
import functorch
import torch
import sys
from torch.utils.benchmark import Timer
from torch.utils.benchmark import Compare

def prod(l):
    prod = 1
    for el in l:
        prod *= el
    
    return prod

def fn(x, y):
    return x + y, x.sum(0)

shapes = ((64, 64), (128, 128), (128, 144), (144, 144))

for device in ('cpu', 'cuda:1'):
    if device == 'cuda:1':
        chunks = (2, 3, 4, 100,)
    else:
        chunks = (2, 3, 4, 100, 200, 300)
    for chunk in chunks:
        results = []
        for shape in shapes:
            x = torch.zeros(*shape, dtype=torch.float, device=device)
            y = x.sum()
            chunk_size = (prod(shape) + prod(shape[1:])) // chunk
            jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size)
            jacrev_fn_chunked_pre = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size, _preallocate_and_copy=True)
            jacrev_fn = functorch.jacrev(fn, (0, 1), chunk_size=None)

            tasks = [("jacrev_fn_chunked(x, y)", "with chunk_size and stacked"),
                     ("jacrev_fn(x, y)", "without chunk_size"),
                     ("jacrev_fn_chunked_pre(x, y)", "with chunk_size and pre-allocated"),]
            timers = [Timer(stmt=stmt,
                            label=f"jacrev : device {device} : chunks {chunk}",
                            sub_label=f"{shape} : chunk_size {chunk_size}",
                            description=desc,
                            globals=globals())
                      for stmt, desc in tasks]

            
            for i, timer in enumerate(timers):
                results.append(
                    timer.blocked_autorange(min_run_time=2.)
                )
                print(f"\r{i + 1} / {len(timers)} : Shape {shape} : Device {device} : chunks: {chunk}", end="")
                sys.stdout.flush()

        print()
        comparison = Compare(results)
        comparison.print()

@pytorch-bot (bot) commented Nov 20, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89376

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ff0e9e1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@kshitij12345 changed the title from "[WIP] jacrev : for loop approach" to "[WIP] jacrev : Support chunked computation" on Dec 12, 2022
@kshitij12345 changed the title from "[WIP] jacrev : Support chunked computation" to "jacrev : Support chunked computation" on Dec 13, 2022
@kshitij12345 (Collaborator Author) commented Dec 13, 2022

CUDA memory summary for the different approaches (using torch.cuda.memory_summary), with chunks=10

NOTE: Only the interesting part of each summary is copied below.

Stacked Approach

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   83456 B  |    3312 MB |    4983 MB |    4983 MB |
|       from large pool |       0 B  |    3312 MB |    4983 MB |    4983 MB |
|       from small pool |   83456 B  |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| Active memory         |   83456 B  |    3312 MB |    4983 MB |    4983 MB |
|       from large pool |       0 B  |    3312 MB |    4983 MB |    4983 MB |
|       from small pool |   83456 B  |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    3500 MB |    3500 MB |    3500 MB |       0 B  |
|       from large pool |    3498 MB |    3498 MB |    3498 MB |       0 B  |
|       from small pool |       2 MB |       2 MB |       2 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |    1966 KB |   21190 KB |   32092 KB |   30125 KB |
|       from large pool |       0 KB |   19305 KB |   29876 KB |   29876 KB |
|       from small pool |    1966 KB |    1967 KB |    2216 KB |     249 KB |
|---------------------------------------------------------------------------|

Pre-allocation Approach

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   83456 B  |    2152 MB |    4983 MB |    4983 MB |
|       from large pool |       0 B  |    2152 MB |    4983 MB |    4983 MB |
|       from small pool |   83456 B  |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| Active memory         |   83456 B  |    2152 MB |    4983 MB |    4983 MB |
|       from large pool |       0 B  |    2152 MB |    4983 MB |    4983 MB |
|       from small pool |   83456 B  |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    2172 MB |    2172 MB |    2172 MB |       0 B  |
|       from large pool |    2170 MB |    2170 MB |    2170 MB |       0 B  |
|       from small pool |       2 MB |       2 MB |       2 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |    1966 KB |   21108 KB |   32092 KB |   30125 KB |
|       from large pool |       0 KB |   19305 KB |   29876 KB |   29876 KB |
|       from small pool |    1966 KB |    1967 KB |    2216 KB |     249 KB |
|---------------------------------------------------------------------------|

Single Chunk

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   83456 B  |    3316 MB |    3316 MB |    3316 MB |
|       from large pool |       0 B  |    3316 MB |    3316 MB |    3316 MB |
|       from small pool |   83456 B  |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| Active memory         |   83456 B  |    3316 MB |    3316 MB |    3316 MB |
|       from large pool |       0 B  |    3316 MB |    3316 MB |    3316 MB |
|       from small pool |   83456 B  |       0 MB |       0 MB |       0 MB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    3318 MB |    3318 MB |    3318 MB |       0 B  |
|       from large pool |    3316 MB |    3316 MB |    3316 MB |       0 B  |
|       from small pool |       2 MB |       2 MB |       2 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |    1966 KB |    1967 KB |    2131 KB |  168448 B  |
|       from large pool |       0 KB |       0 KB |       0 KB |       0 B  |
|       from small pool |    1966 KB |    1967 KB |    2131 KB |  168448 B  |
|---------------------------------------------------------------------------|

Script:

Details
import functorch
import torch

def prod(l):
    prod = 1
    for el in l:
        prod *= el
    
    return prod

def fn(x, y):
    return x + y, x.sum(0)

shape = (144, 144)
chunk = 10
x = torch.zeros(*shape, dtype=torch.float, device='cuda')
y = x.sum()
chunk_size = (prod(shape) + prod(shape[1:])) // chunk

# Stack approach
# jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size)
# jacrev_fn_chunked(x, y)

# Pre-allocate and copy approach
# jacrev_fn_chunked = functorch.jacrev(fn, (0, 1), chunk_size=chunk_size, _preallocate_and_copy=True)
# jacrev_fn_chunked(x, y)


# Single chunk
jacrev_fn = functorch.jacrev(fn, (0, 1), chunk_size=None)
jacrev_fn(x, y)

print(torch.cuda.memory_summary())

auxiliary objects that will not be differentiated.
Default: False.
chunk_size (None or int): If specified, controls the maximum size of chunk for computing
Jacobian. Default: None.
Collaborator Author:

Need to tweak it.

Specify what happens for None.

Contributor:

If None (default), we will use the maximum chunk size (this is equivalent to doing a single vmap over vjp to compute the jacobian). If not None, then we will compute the jacobian chunk_size rows at a time using vmap to vectorize the computation. Note that chunk_size=1 is equivalent to computing the jacobian row-by-row with a for-loop. If you run into memory issues computing the jacobian, please try to specify a non-None chunk_size.

Something like that
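
As a hedged illustration of those semantics (toy function, not taken from this PR's tests): chunk_size=None, chunk_size=1 and an intermediate chunk_size should all produce the same Jacobians.

import torch
import functorch

def f(x):
    return x.sin().sum(0), x.cos()

x = torch.randn(5, 3)

full = functorch.jacrev(f, chunk_size=None)(x)   # one vmap over vjp
rowwise = functorch.jacrev(f, chunk_size=1)(x)   # row by row, like a for-loop
chunked = functorch.jacrev(f, chunk_size=4)(x)   # 4 rows of the Jacobian at a time

for a, b, c in zip(full, rowwise, chunked):
    assert torch.allclose(a, b) and torch.allclose(a, c)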

@kshitij12345 kshitij12345 marked this pull request as ready for review December 13, 2022 13:24
@kshitij12345 kshitij12345 requested a review from zou3519 December 13, 2022 13:24
@kshitij12345 (Collaborator Author) commented Dec 13, 2022

@zou3519

Just putting this out there for review of the API (chunks vs chunk_size) and of the perf and memory benchmarks for the different approaches.

@zou3519 (Contributor) left a comment:

Code looks correct and clean. Let's discuss the API options with more folks (and I'll think about it as well).

def f(x, y):
    return (x.sin(), x + y), (x + 2, x.sum())

for chunk_size in [1, 2, 3, 4, 7, 10]:
Contributor:

nit: When we have figured out the API, we should test some extreme cases (a rough sketch follows the list):

  • check that chunk_size <= 0 raises an error
  • try chunk_size = 100000 (some big number)
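
A rough sketch of what such tests might look like (toy function; the exact exception type is an assumption, not taken from this PR):

import torch
import functorch

def f(x):
    return x.sin()

x = torch.randn(7)

# chunk_size <= 0 should be rejected (exact exception type not asserted here).
for bad in (0, -1):
    try:
        functorch.jacrev(f, chunk_size=bad)(x)
    except (ValueError, RuntimeError) as e:
        print(f"chunk_size={bad} raised as expected: {e}")

# A chunk_size far larger than the Jacobian should match the unchunked result.
expected = functorch.jacrev(f, chunk_size=None)(x)
actual = functorch.jacrev(f, chunk_size=100000)(x)
assert torch.allclose(expected, actual)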

Comment on lines +341 to +343
def jacrev(func: Callable, argnums: Union[int, Tuple[int]] = 0, *, has_aux=False,
           chunk_size: Optional[int] = None,
           _preallocate_and_copy=False):
Contributor:

cc @samdow @soulitzer @Chillee for API help.

We've got a couple of options here.

  1. Either we have a chunk_size argument, or we have a chunks argument (for the number of total chunks).
  2. _preallocate_and_copy is private, or we expose it publicly.

My opinion is:

  • we should make preallocate_and_copy public. In the long run, memory planning in PT2 should save us, but I don't know how soon that is coming, and the preallocate_and_copy code is simple enough to maintain.
  • I have a slight preference for chunks:
    • If jacrev(f)(x) OOMs, the user needs to try out a chunks/chunk_size argument. If the API is chunks, it is clear what to try next: 2, and they can keep incrementing it until they're satisfied. If the API is chunk_size, the user either tosses in random numbers or has to compute the size of their Jacobian to figure out what the max chunk_size is (see the illustration below), so they know the range of numbers to try.
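
For reference, the "max chunk_size" is just the total number of output elements (one Jacobian row per output element); a quick illustration with the fn used in the benchmarks above:

import torch

def fn(x, y):
    return x + y, x.sum(0)

x = torch.zeros(128, 144)
y = x.sum()

# One Jacobian row per output element, so the maximum useful chunk_size
# is the total number of output elements.
outputs = fn(x, y)
max_chunk_size = sum(o.numel() for o in outputs)
print(max_chunk_size)  # 128*144 + 144 = 18576; the chunks=2 rows above use 18576 // 2 = 9288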

Contributor:

Some nuggets of wisdom from Horace:

  1. chunk_size seems nicer, because (1) it is batch-agnostic and (2) it's like loop unrolling - you unroll some number of things instead of indicating how many times you want to unroll.
  2. _preallocate_and_copy should NOT be public. PT2 is around the corner and already includes this optimization (assuming the entire jacrev call can be captured); one of the goals of compilation is to remove code smells like this.

Also, we can always update chunk_size to have different behavior if we really want (a toy sketch of this follows the list):

  • if the number is >= 1 then it is a chunk size
  • if the number is between 0 and 1 then it is 1/chunks.
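
A toy sketch of how such a dual-mode argument could be normalized (purely illustrative; not part of this PR):

import math

def resolve_chunk_size(chunk_size, total_rows):
    # Illustrative only: values >= 1 are a chunk size; values in (0, 1) mean 1/num_chunks.
    if chunk_size is None:
        return total_rows                      # single chunk over the whole Jacobian
    if chunk_size >= 1:
        return int(chunk_size)
    if 0 < chunk_size < 1:
        num_chunks = round(1 / chunk_size)
        return max(1, math.ceil(total_rows / num_chunks))
    raise ValueError("chunk_size must be None, >= 1, or a fraction in (0, 1)")

assert resolve_chunk_size(None, 18576) == 18576
assert resolve_chunk_size(8256, 18576) == 8256
assert resolve_chunk_size(0.5, 18576) == 9288   # 2 chunks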

@zou3519 added the triaged label on Dec 15, 2022
Comment on lines 1887 to 1893

# NB: numpy is a testing dependency!
import numpy as np

USE_TORCHVISION = False
try:
    import torchvision  # noqa: F401
    USE_TORCHVISION = True
except ImportError:
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)

# TestCase for _slice_argnums, an important helper function


class TestSliceArgnums(TestCase):
def test_invalid_argnum_type(self):
x = torch.randn(3)
args = (x,)
with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
_slice_argnums(args, 0.0)
with self.assertRaisesRegex(RuntimeError, "int or Tuple"):
_slice_argnums(args, [0])
with self.assertRaisesRegex(RuntimeError, "must be int"):
_slice_argnums(args, (0.0,))

args = (0.1, 1.1, 2.1, 3.1, 4.1)

with self.assertRaisesRegex(RuntimeError, "must be int"):
_slice_argnums(args, ((0, 1), 2))

def test_out_of_bounds_argnum_values(self):
x = torch.randn(3)
args = (x,)
with self.assertRaisesRegex(RuntimeError, "positional inputs"):
_slice_argnums(args, 1)
with self.assertRaisesRegex(RuntimeError, "positional inputs"):
_slice_argnums(args, -2)
with self.assertRaisesRegex(RuntimeError, "positional inputs"):
_slice_argnums(args, (-2,))

def test_not_enough_argnums(self):
x = torch.randn(3)
args = (x,)
with self.assertRaisesRegex(RuntimeError, "must be non-empty"):
_slice_argnums(args, ())

def test_duplicate_argnums(self):
x = torch.randn(3)
args = (x, x)
with self.assertRaisesRegex(RuntimeError, "must be unique"):
_slice_argnums(args, (0, 0))
with self.assertRaisesRegex(RuntimeError, "must be unique"):
_slice_argnums(args, (0, -2))

def test_flat_args_with_positive_int_argnum(self):
args = (0.1, 1.1, 2.1, 3.1, 4.1)

res = _slice_argnums(args, 0)
self.assertEqual(res, (0.1,))

res = _slice_argnums(args, 4)
self.assertEqual(res, (4.1,))

def test_flat_args_with_negative_int_argnum(self):
args = (0.1, 1.1, 2.1, 3.1, 4.1)

res = _slice_argnums(args, -1)
self.assertEqual(res, (4.1,))

res = _slice_argnums(args, -5)
self.assertEqual(res, (0.1,))

def test_flat_args_with_tuple_argnum(self):
args = (0.1, 1.1, 2.1, 3.1, 4.1)

res = _slice_argnums(args, (0, 1, 2, 3, 4))
self.assertEqual(res, args)

res = _slice_argnums(args, (0, -3))
self.assertEqual(res, (0.1, 2.1))

def test_pytree_args(self):
args = ((0.1, 1.1), 2.0, [3.1])

res = _slice_argnums(args, 0)
self.assertEqual(res, args[0:1])

res = _slice_argnums(args, (0,))
self.assertEqual(res, args[0:1])

res = _slice_argnums(args, -1)
self.assertEqual(res, args[-1:])

res = _slice_argnums(args, (0, -2))
self.assertEqual(res, args[0:2])

def test_argnums_reorders(self):
args = ((0.1, 1.1, 2.1), 3.1, 4.1)

res = _slice_argnums(args, (1, 0))
self.assertEqual(res, (args[1], args[0]))


class TestGradTransform(TestCase):
def test_primitive(self, device):
x = torch.randn([], device=device)
result = grad(torch.sin)(x)
self.assertEqual(result, torch.cos(x))

def test_composite_simple(self, device):
x = torch.randn(2, 3, 4, device=device)
result = grad(lambda x: torch.flatten(x).sum())(x)
self.assertEqual(result, torch.ones_like(x))

def test_fn_with_kwargs(self, device):
def foo(x, y):
return (x * y).sum()

x = torch.randn(3, device=device)
y = torch.randn(3, device=device)
expected = grad(foo)(x, y)
result = grad(foo)(x, y=y)
self.assertEqual(result, expected)

def test_composite_complicated(self, device):
x = torch.randn(3, device=device)
y = torch.randn(3, 5, device=device)

def foo(x, y):
result = x @ y
return result.sum()

result = grad(foo)(x, y)

x.requires_grad_()
out = foo(x, y)
expected, = torch.autograd.grad(out, x)

self.assertEqual(result, expected)

def test_composite_two_ops(self, device):
N, C = 2, 5
y = torch.randn(N, C, device=device)
targets = torch.randint(0, C, (N,), device=device)

def foo(y, targets):
return F.cross_entropy(y, targets)

result = grad(foo)(y, targets)

y.requires_grad_()
expected, = torch.autograd.grad(foo(y, targets), y)

self.assertEqual(result, expected)

def _test_attributes(self, get_attr_lambda, device):
x = torch.randn(2, 3, 5, dtype=torch.double, device=device)
expected = get_attr_lambda(x)

def foo(x):
self.assertEqual(get_attr_lambda(x), expected)
return x.sum()

grad(foo)(x)

def test_shape(self, device):
self._test_attributes(lambda x: x.shape, device)

def test_dtype(self, device):
self._test_attributes(lambda x: x.dtype, device)

def test_is_cuda(self, device):
self._test_attributes(lambda x: x.is_cuda, device)

def test_numel(self, device):
self._test_attributes(lambda x: x.numel(), device)

def test_inplace(self, device):
x = torch.randn([], device=device)

def foo(x):
return x.clone().sin_()

result = grad(foo)(x)
self.assertEqual(result, x.cos())

def test_inplace_on_view(self, device):
x = torch.randn(3, device=device)

def foo(x):
y = x.clone()
y0 = y[0]
y0.sin_()
return y.sum()

result = grad(foo)(x)

x.requires_grad_()
out = foo(x)
expected, = torch.autograd.grad(out, x)

self.assertEqual(result, expected)

def test_inplace_on_view_base(self, device):
x = torch.randn(3, device=device)

def foo(x):
y = x.clone()
y0 = y[0]
y.sin_()
return y0

result = grad(foo)(x)

x.requires_grad_()
out = foo(x)
expected, = torch.autograd.grad(out, x)

self.assertEqual(result, expected)

def test_inplace_on_captures(self, device):
x = torch.tensor([1., 2., 3.], device=device)
captured = torch.randn(3, device=device)

def foo(x):
captured.copy_(x)
return (x * captured).sum()

with self.assertRaisesRegex(RuntimeError, 'mutate a captured Tensor'):
grad(foo)(x)

def test_nesting_simple(self, device):
x = torch.randn([], device=device)
result = grad(grad(torch.sin))(x)
self.assertEqual(result, -torch.sin(x))

def test_escaped_wrappers_are_marked_as_dead(self, device):
x = torch.randn([], device=device)
escaped = []

def foo(x):
y = x.sin()
escaped.append(y)
return y

grad(foo)(x)
self.assertEqual(torch._C._functorch.dlevel(escaped[0]), -1)

def test_escaped_wrappers_are_ignored(self, device):
x = torch.randn([], device=device)
escaped = []

def foo(x):
y = x.sin()
escaped.append(y)
return y

grad(foo)(x)

something = escaped[0].sum()
self.assertEqual(torch._C._functorch.dlevel(something), 0)
self.assertEqual(something, x.sin().sum())

def test_manual_seed_inside_grad(self, device):
x = torch.randn([], device=device)

def f(x):
torch.manual_seed(0)
return x * torch.randn_like(x)

with freeze_rng_state():
result = grad(f)(x)
x.requires_grad_()
expected, = torch.autograd.grad(f(x), x)
self.assertEqual(result, expected)

def test_vjp(self, device):
x = torch.randn([], device=device)
out, vjp_fn = vjp(torch.sin, x)
self.assertEqual(out, x.sin())

v = torch.randn([], device=device)
result, = vjp_fn(v)
self.assertEqual(result, v * x.cos())

def test_vjp_two_outputs(self, device):
def f(x):
return x, x
result, vjp_fn = vjp(f, torch.tensor(1.))
vjp_fn(result)

def test_conj_bit(self):
x = torch.tensor(1 + 1j)

def foo(x):
assert not x.is_conj()
y = x.conj()
assert y.is_conj()
return y
res = grad(foo)(x)
self.assertEqual(res, torch.ones_like(res))

def test_composed_with_autograd(self, device):
x = torch.randn([], requires_grad=True, device=device)

y = grad(torch.sin)(x)
result, = torch.autograd.grad(y, x)
self.assertEqual(result, -x.sin())

def test_grad_of_vjp_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)

def foo(x, y):
out, vjp_fn = vjp(torch.sin, x)
return grad(lambda y: vjp_fn(y)[0])(y)

result = foo(x, y)
expected = x.cos()
self.assertEqual(result, expected)

def test_vjp_of_grad_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)

def foo(x, y):
out, vjp_fn = vjp(grad(torch.sin), x)
return vjp_fn(y)[0]

result = foo(x, y)
expected = -y * x.sin()
self.assertEqual(result, expected)

def test_grad_of_vjp_of_grad_composition(self, device):
x = torch.randn([], device=device)
y = torch.randn([], device=device)

def foo(x, y):
df, vjp_fn = vjp(grad(lambda x: -torch.cos(x)), x)
return grad(lambda y: vjp_fn(y)[0])(y)

result = foo(x, y)
expected = x.cos()
self.assertEqual(result, expected)

def test_views(self, device):
x = torch.randn([], requires_grad=True, device=device)
y = torch.randn([], requires_grad=True, device=device)

def silly_sin(x):
x = x.view([])
x = x.sin()
return x

def foo(x, y):
z1 = grad(silly_sin)(x)
z2 = torch.cos(y)
return z1 + z2

result = foo(x, y)
grads = torch.autograd.grad(result, [x, y])
self.assertEqual(grads[0], -x.sin())
self.assertEqual(grads[1], -y.sin())

def test_view_inplace_simple(self, device):
def foo(x):
x = x.clone()
x.view([]).sin_()
return x

x = torch.randn([], requires_grad=True, device=device)
result = grad(foo)(x)
self.assertEqual(result, x.cos())

def test_invalid_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
with self.assertRaisesRegex(RuntimeError, 'but only'):
grad(torch.mul, argnums=-3)(x, y)
with self.assertRaisesRegex(RuntimeError, 'but only'):
grad(torch.mul, argnums=2)(x, y)
with self.assertRaisesRegex(RuntimeError, 'int or Tuple'):
grad(torch.mul, argnums=[0])(x, y)
with self.assertRaisesRegex(RuntimeError, 'must be int'):
grad(torch.mul, argnums=('0',))(x, y)
with self.assertRaisesRegex(RuntimeError, 'must be unique'):
grad(torch.mul, argnums=(0, 0))(x, y)
with self.assertRaisesRegex(RuntimeError, 'must be unique'):
grad(torch.mul, argnums=(0, -2))(x, y)

def test_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
gx = grad(torch.mul, argnums=0)(x, y)
self.assertEqual(gx, y)

gy = grad(torch.mul, argnums=1)(x, y)
self.assertEqual(gy, x)

gx, = grad(torch.mul, argnums=(0,))(x, y)
self.assertEqual(gx, y)

gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)
self.assertEqual(gx, y)
self.assertEqual(gy, x)

def test_out_of_order_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
gy, gx = grad(torch.mul, argnums=(1, 0))(x, y)
self.assertEqual(gx, y)
self.assertEqual(gy, x)

def test_negative_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
gx = grad(torch.mul, argnums=-2)(x, y)
self.assertEqual(gx, y)

gy = grad(torch.mul, argnums=-1)(x, y)
self.assertEqual(gy, x)

gx, = grad(torch.mul, argnums=(-2,))(x, y)
self.assertEqual(gx, y)

gx, gy = grad(torch.mul, argnums=(-2, -1))(x, y)
self.assertEqual(gx, y)
self.assertEqual(gy, x)

def test_grad_pytree_inputs(self, device):
x = torch.randn([], device=device)

def f(a, b):
x, y = a
return 1 * x + 2 * y + 3 * b['foo']

args = ((x, x), {'foo': x})

gx, gy = grad(f)(*args)
self.assertEqual(gx, torch.tensor(1., device=device))
self.assertEqual(gy, torch.tensor(2., device=device))

(gx, gy), = grad(f, argnums=(0,))(*args)
self.assertEqual(gx, torch.tensor(1., device=device))
self.assertEqual(gy, torch.tensor(2., device=device))

(gx, gy), gz = grad(f, argnums=(0, 1))(*args)
self.assertEqual(gx, torch.tensor(1., device=device))
self.assertEqual(gy, torch.tensor(2., device=device))
self.assertEqual(gz['foo'], torch.tensor(3., device=device))

def test_grad_aux_tensor(self, device):

x = torch.randn(3, device=device)

with self.assertRaisesRegex(
RuntimeError,
r'grad_and_value\(f\)\(\*args\): output of function f should be a tuple'
):
grad(lambda t: [t, t], has_aux=True)(x)

with self.assertRaisesRegex(
RuntimeError,
r'grad_and_value\(f\)\(\*args\): output of function f should be a tuple'
):
grad(lambda t: (t, t + 2, t + 3), has_aux=True)(x)

def f(t):
y = t.sin()
return y.sum(), t.cos()

out, aux = grad(f, has_aux=True)(x)
self.assertEqual(aux, x.cos())
self.assertEqual(out, x.cos())

def test_grad_aux_pytree(self, device):
def f(x):
y = x.sin()
return y.sum(), {'a': x.cos(), 'b': [x.tan()]}

x = torch.randn(3, device=device)

out, aux = grad(f, has_aux=True)(x)
_, expected_aux = f(x)
self.assertEqual(aux, expected_aux)
self.assertEqual(out, x.cos())

for aux in [1, 1.0, "abc"]:
with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"):
_ = grad(lambda x: (x.sum(), aux), has_aux=True)(x)
with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"):
_ = grad(lambda x: (x.sum(), [x, aux]), has_aux=True)(x)

def test_zero_grad(self, device):
def f(x):
return (x['a']**2.0).sum()
inps = ({'a': torch.randn(10, device=device) + 3, 'b': torch.randn(10, device=device)})
grads = grad(f)(inps)
self.assertNotEqual(grads['a'].sum(), 0.0)
self.assertEqual(grads['b'].sum(), 0.0)

def test_unrelated_grad(self, device):
x = torch.tensor(1., device=device)
y = torch.tensor(2., device=device)

def unrelated(x):
return y

result = grad(unrelated)(x)
self.assertEqual(result, torch.zeros_like(x))

def test_unrelated_vjp(self, device):
x = torch.tensor(1., device=device)
y = torch.tensor(2., device=device)
v = torch.tensor(1., device=device)

def unrelated(x):
return y

out, vjp_fn = vjp(unrelated, x)
result = vjp_fn(v)
expected = (torch.zeros_like(x),)
self.assertEqual(result, expected)

def test_unrelated_vjp_multiple_inputs_outputs(self, device):
w = torch.tensor(3., device=device)
x = torch.tensor(4., device=device)
y = torch.tensor(2., device=device)
v = torch.tensor(1., device=device)

def unrelated(w, x):
return y, y, x

out, vjp_fn = vjp(unrelated, w, x)
result = vjp_fn((v, v, v))
expected = (torch.zeros_like(x), torch.ones_like(x))
self.assertEqual(result, expected)

# TODO: https://github.com/zou3519/functorch/issues/12
@onlyCPU
def test_unrelated_hessian(self, device):
N = 5
M = 3
W = torch.randn(N, M, device=device)

def f(x):
return W @ x

x = torch.randn(M)
result = jacrev(jacrev(f))(x)
expected = torch.zeros(N, M, M, device=device)
self.assertEqual(result, expected)

def test_vjp_pytree_input(self, device):
def f(x):
return x[0] * x[1][0]

x = torch.randn([], device=device)
v = torch.randn([], device=device)
out, vjp_fn = vjp(f, (x, (x, x)))
self.assertEqual(out, x * x)
result = vjp_fn(v)
self.assertEqual(result, ((x * v, (x * v, 0.)),))

def test_vjp_pytree_output(self, device):
def f(x):
return x, (x, x)

x = torch.randn([], device=device)
v1 = torch.randn([], device=device)
v2 = torch.randn([], device=device)
v3 = torch.randn([], device=device)
_, vjp_fn = vjp(f, x)
result, = vjp_fn((v1, (v2, v3)))
self.assertEqual(result, v1 + v2 + v3)

def test_vjp_outputs_can_any_pytree(self, device):
x = torch.randn(2, 3, device=device)
t = torch.randn(2, 3, device=device)

for output in [None, ()]:
with self.assertRaisesRegex(
RuntimeError, r"vjp\(f, \*primals\): Expected f to be a function that has non-empty output"
):
_, vjp_fn = vjp(lambda _: output, x)
vjp_fn(t)

for output in [1, True, 12.2, "abc"]:
with self.assertRaisesRegex(
RuntimeError, r"vjp\(f, \*primals\): expected f\(\*primals\) to return only tensors"
):
_, vjp_fn = vjp(lambda _: output, x)
vjp_fn(t)

# Check list output
output, vjp_fn = vjp(lambda x: [x, x.sum()], x)
vjp_out, = vjp_fn([t, t.sum()])
assert isinstance(output, list) and len(output) == 2
assert isinstance(vjp_out, torch.Tensor)

# Check dict output
output, vjp_fn = vjp(lambda x: {"x": x, "xsum": x.sum()}, x)
vjp_out, = vjp_fn({"x": t, "xsum": t.sum()})
assert isinstance(output, dict) and len(output) == 2 and "xsum" in output
assert isinstance(vjp_out, torch.Tensor)

def composite_output(x):
out = x.sum()
return [
(out, {"a": x, "out": [x, out]}),
]

output, vjp_fn = vjp(composite_output, x)
vjp_out, = vjp_fn([(t.sum(), {"a": t, "out": [t, t.sum()]}), ])
assert isinstance(output, list)
assert isinstance(output[0], tuple) and isinstance(output[0][1], dict)
assert isinstance(vjp_out, torch.Tensor)

def test_vjp_pytree_error(self, device):
def f(x):
return x, (x, x)

x = torch.randn([], device=device)
v1 = torch.randn([], device=device)
v2 = torch.randn([], device=device)
v3 = torch.randn([], device=device)
_, vjp_fn = vjp(f, x)
with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
result, = vjp_fn(((v1, (v2, v3)),))

def test_vjp_aux_tensor(self, device):

x = torch.randn(3, device=device)

with self.assertRaisesRegex(RuntimeError, r'vjp\(f, \*primals\): output of function f should be a tuple'):
vjp(lambda t: [t, t], x, has_aux=True)

with self.assertRaisesRegex(RuntimeError, r'vjp\(f, \*primals\): output of function f should be a tuple'):
vjp(lambda t: (t, t + 2, t + 3), x, has_aux=True)

def f(t):
y = t.sin()
return y, t.cos()

out, vjp_fn, aux = vjp(f, x, has_aux=True)
self.assertEqual(aux, x.cos())
self.assertEqual(out, x.sin())

v = torch.randn(3, device=device)
grad_x, = vjp_fn(v)
self.assertEqual(grad_x, v * x.cos())

def test_vjp_aux_pytree(self, device):
def f(x):
y = x.sin()
return y, {'a': x.cos(), 'b': [x.tan()]}

x = torch.randn(3, device=device)

out, vjp_fn, aux = vjp(f, x, has_aux=True)
expected_out, expected_aux = f(x)
self.assertEqual(out, expected_out)
self.assertEqual(aux, expected_aux)

v = torch.randn(3, device=device)
grad_x, = vjp_fn(v)
self.assertEqual(grad_x, v * x.cos())

for aux in [1, 1.0, "abc"]:
with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"):
_ = vjp(lambda x: (x, aux), x, has_aux=True)
with self.assertRaisesRegex(RuntimeError, r"Expected tensors, got unsupported type"):
_ = vjp(lambda x: (x, [x, aux]), x, has_aux=True)

def test_functional_init(self, device):
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes

self.fc1 = nn.Linear(2, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x

B = 10
weights, fn, _ = functional_init(MLPClassifier, (B,), device=device)(32, 2)
inputs = torch.randn(B, 7, 2, device=device)
vmap(fn)(weights, (inputs,))

def test_functional_init_with_buffers(self, device):
class MLPClassifier(nn.Module):
def __init__(self, hidden_dim=32, n_classes=2):
super().__init__()
self.hidden_dim = hidden_dim
self.n_classes = n_classes

self.fc1 = nn.Linear(2, self.hidden_dim)
self.bn = nn.BatchNorm1d(self.hidden_dim, affine=True)
self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)

def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = self.bn(x)
x = self.fc2(x)
x = F.log_softmax(x, -1)
return x

B = 10
weights, buffers, fn, _, _ = \
functional_init_with_buffers(MLPClassifier, [B], device=device)(32, 2)
inputs = torch.randn(B, 7, 2, device=device)
vmap(fn)(weights, buffers, (inputs,))

def test_advanced_indexing(self, device):
def f(value):
log_prob = torch.ones((), device=device)
val = (torch.zeros(()) > 0)
log_prob[val] = 0
return value

result = grad(f)(torch.randn((), device=device))
self.assertEqual(result, torch.ones_like(result))

def f2(value):
value = value.clone()
value[value > 0] = 0
return value.sum()

x = torch.randn(100, device=device)
result = grad(f2)(x)
self.assertEqual(result, (x <= 0).type_as(x))

def test_tensor_ctor_inside_grad(self, device):
def foo(x):
return x * torch.tensor(2., device=device)

x = torch.tensor(3.14, device=device)
functorch.grad(foo)(x)

@parametrize("op_list_data", [
subtest(([vmap, ], [(4, 2), (64, 3, 32, 32)]), name='vmap'),
subtest(([vmap, vmap], [(4, 3, 2), (64, 3, 32, 32)]), name='vmap_vmap'),
subtest(([grad, ], [(0, ), [], (4, 2), (64, 3, 32, 32)]), name='grad'),
subtest(([grad, grad], [[], ]), name='grad_grad'),
subtest(([vmap, grad], [(4, 2)]), name='vmap_grad'),
])
def test_tensor_print(self, device, op_list_data):

op_list, shapes = op_list_data

for dt in get_all_fp_dtypes():
data = [torch.randn(s, dtype=dt, device=device) for s in shapes]

for x in data:
buf = None

def foo(t):
nonlocal buf
buf = repr(t)
return t.mean()

fn = foo
bdim = 0
for op in reversed(op_list):
if op == vmap:
fn = op(fn, in_dims=bdim)
bdim += 1
else:
fn = op(fn)

expected = f"{repr(x)}"
level = 0
for op in op_list:
level += 1
if op == grad:
expected = f"GradTrackingTensor(lvl={level}, value={expected})"
elif op == vmap:
bdim -= 1
expected = f"BatchedTensor(lvl={level}, bdim={bdim}, value={expected})"

fn(x)
buf = buf.replace("\n", "").replace(" ", "")
expected = expected.replace("\n", "").replace(" ", "")
self.assertEqual(expected, buf)

def test_print_captured_tensor_inside_transform(self, device):
x = torch.tensor([1., 2., 3.], device=device)
out = None

def f(y):
nonlocal out
out = repr(x)
return y

vjp(f, torch.randn(4, device=device))
self.assertEqual(out, repr(x))

def test_no_grad_outside(self, device):
x = torch.randn([], device=device, requires_grad=True)
with torch.no_grad():
y = grad(torch.sin)(x)
self.assertEqual(y, x.cos())
self.assertFalse(y.requires_grad)

def test_no_grad_inside(self, device):
def f(x):
with torch.no_grad():
shift = x ** 2
return x ** 2 - shift

x = torch.randn([], device=device)
y = grad(f)(x)
self.assertEqual(y, 2 * x)
y = grad(grad(f))(x)
self.assertEqual(y, 2)

x = torch.randn([], device=device, requires_grad=True)
y = grad(f)(x)
z, = torch.autograd.grad(y, x)
self.assertEqual(z, 2)

def test_no_grad_mixed(self, device):
def f(x):
with torch.no_grad():
shift = x ** 2
return x ** 2 - shift

x = torch.randn([], device=device, requires_grad=True)
with torch.no_grad():
y = grad(f)(x)

self.assertEqual(y, 2 * x)
self.assertFalse(y.requires_grad)

def test_no_grad_nested_simple(self, device):
def h(x):
with torch.no_grad():
shift = grad(lambda x: 0.25 * x ** 4)(x)
return x ** 3 - shift

x = torch.tensor(1.5, device=device, requires_grad=True)
y = grad(h)(x)
self.assertEqual(y, 3 * x ** 2)

z, = torch.autograd.grad(y, x)
self.assertEqual(z, 6 * x)

def test_no_grad_nested_complicated(self, device):
def f(x):
with torch.no_grad():
shift = x ** 3
return x ** 3 - shift

def g(x):
r1 = grad(f)(x)
with torch.no_grad():
shift = grad(f)(x)
return r1 - shift

x = torch.randn([], requires_grad=True, device=device)
y = grad(g)(x)
# The only differential part of g is x ** 3
self.assertEqual(y, 6 * x)

z, = torch.autograd.grad(y, x)
self.assertEqual(z, 6)

def test_no_grad_value(self, device):
def h(x):
with torch.no_grad():
gvalue, value = grad_and_value(lambda x: x ** 3)(x)
return x ** 3 - value

x = torch.tensor(1.6, device=device, requires_grad=True)
y = grad(h)(x)
self.assertEqual(y, 3 * x ** 2)

z, = torch.autograd.grad(y, x)
self.assertEqual(z, 6 * x)

def test_no_grad_outside_vjp(self, device):
def h(x):
return x ** 2

x = torch.tensor(2., requires_grad=True, device=device)
with torch.no_grad():
out, vjp_fn = vjp(h, x)
y, = vjp_fn(torch.tensor(1., device=device))

self.assertEqual(y, 2 * x)
self.assertFalse(y.requires_grad)
self.assertFalse(out.requires_grad)

def test_no_grad_outside_vjp_fn(self, device):
def h(x):
return x ** 2

x = torch.tensor(3.14, requires_grad=True, device=device)
out, vjp_fn = vjp(h, x)
with torch.no_grad():
y, = vjp_fn(torch.tensor(1., device=device))

self.assertEqual(y, 2 * x)
self.assertFalse(y.requires_grad)
self.assertTrue(out.requires_grad)

z, = torch.autograd.grad(out, x)
self.assertEqual(z, 2 * x)

def test_no_grad_outside_vjp_only(self, device):
def h(x):
return x ** 2

x = torch.tensor(3.14, requires_grad=True, device=device)
with torch.no_grad():
out, vjp_fn = vjp(h, x)
y, = vjp_fn(torch.tensor(1., device=device))

self.assertEqual(y, 2 * x)
self.assertFalse(out.requires_grad)

# This one is a little weird...
self.assertTrue(y.requires_grad)

z, = torch.autograd.grad(y, x)
self.assertEqual(z, 2)


class TestAutogradFunction(TestCase):
@_set_autograd_function_extension_enabled()
def test_set_materialize_grads(self, device):
class A(torch.autograd.Function):
@staticmethod
def forward(x, y):
return x, y

@staticmethod
def setup_context(ctx, inputs, outputs):
ctx.set_materialize_grads(False)

@staticmethod
def backward(ctx, gx, gy):
self.assertIsNotNone(gx)
self.assertIsNone(gy)
return gx, gy

def f(y, x):
x, y = A.apply(x, y)
return x ** 2

x = torch.tensor(2., device=device)
y = torch.tensor(3., device=device)
# grad differentiates w.r.t. arg 0 by default
grad(f)(y, x)
grad(grad(f))(y, x)

@_set_autograd_function_extension_enabled()
def test_needs_input_grads(self, device):
class A(torch.autograd.Function):
@staticmethod
def forward(x, y):
return x * y

@staticmethod
def setup_context(ctx, inputs, outputs):
return

@staticmethod
def backward(ctx, grad_output):
self.assertTrue(ctx.needs_input_grad[0])
self.assertFalse(ctx.needs_input_grad[1])
return None, None

x = torch.tensor(2., device=device)
y = torch.tensor(3., device=device)
# grad differentiates w.r.t. arg 0 by default
grad(A.apply)(x, y)
grad(grad(A.apply))(x, y)

def _get_NumpyCubeNotComposable(self):
class NumpyCubeNotComposable(torch.autograd.Function):
@staticmethod
def forward(input):
input_np = input.cpu().numpy()
return torch.tensor(input_np ** 3, device=input.device), input_np

@staticmethod
def setup_context(ctx, inputs, outputs):
ctx.input_np = outputs[1]
ctx.device = inputs[0].device

@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output, grad_saved):
result_np = 3 * (ctx.input_np ** 2)
return torch.tensor(result_np, device=ctx.device)

return NumpyCubeNotComposable

@_set_autograd_function_extension_enabled()
def test_once_differentiable_autograd_vjp(self, device):
NumpyCubeNotComposable = self._get_NumpyCubeNotComposable()

def f(x):
y, _ = NumpyCubeNotComposable.apply(x)
return y

# regular autograd x vjp
x = torch.randn([], requires_grad=True, device=device)
grad_y = torch.randn_like(x, requires_grad=True)
_, vjp_fn = vjp(f, x)
gx, = vjp_fn(grad_y)

with self.assertRaisesRegex(RuntimeError, "marked with @once_differentiable"):
gx.backward()

# TODO: support torch.autograd.function.once_differentiable
# (or, if impossible, figure out how to raise a nice error)
# https://github.com/pytorch/pytorch/issues/90224
@unittest.expectedFailure
@_set_autograd_function_extension_enabled()
def test_once_differentiable_grad_vjp(self, device):
NumpyCubeNotComposable = self._get_NumpyCubeNotComposable()

# grad x vjp
x = torch.randn([], device=device)
grad_y = torch.randn_like(x)

def h(x, grad_y):
_, vjp_fn = vjp(f, x)
gx, = vjp_fn(grad_y)
return gx

grad(h, argnums=(0, 1))(x, grad_y)

@_set_autograd_function_extension_enabled()
Contributor:

Is the repetition intentional?

Collaborator Author:

Meant to have a negative value for the second one. Thanks!

@kshitij12345 (Collaborator Author) commented:
@pytorchbot merge

@pytorch-bot added the ciflow/trunk label on Dec 19, 2022
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/trunk, Merged, open source, release notes: torch.func, triaged
