@@ -251,6 +251,7 @@ def vjp(func: Callable, *primals, has_aux: bool = False):
 
 Case 2: Using ``vjp`` inside ``torch.no_grad`` context manager:
 
+>>> # xdoctest: +SKIP(failing)
 >>> with torch.no_grad():
 >>>     vjp(f)(x)
 
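
For orientation, the hunk above only shows the ``torch.no_grad`` interaction; a minimal self-contained sketch of the basic ``torch.func.vjp`` call (the function and variable names here are illustrative, not taken from the patched docstring) looks roughly like this:

    import torch
    from torch.func import vjp

    def f(x):  # illustrative function, not from the docstring
        return x.sin().sum()

    x = torch.randn(5)
    output, vjp_fn = vjp(f, x)                   # forward pass plus a function that pulls back cotangents
    (grad_x,) = vjp_fn(torch.ones_like(output))  # gradient of f at x, one entry per primal
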
@@ -1286,6 +1287,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
 
 Example of using ``grad``:
 
+>>> # xdoctest: +SKIP
 >>> from torch.func import grad
 >>> x = torch.randn([])
 >>> cos_x = grad(lambda x: torch.sin(x))(x)
@@ -1297,6 +1299,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
 
 When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:
 
+>>> # xdoctest: +SKIP
 >>> from torch.func import grad, vmap
 >>> batch_size, feature_size = 3, 5
 >>>
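
The per-sample-gradient example above is cut off by the diff context; the ``vmap(grad(...))`` pattern it demonstrates, sketched here in full with illustrative names, is roughly:

    import torch
    from torch.func import grad, vmap

    batch_size, feature_size = 3, 5

    def model(weights, feature_vec):
        # simple linear model with a ReLU, acting on a single (unbatched) example
        return feature_vec.dot(weights).relu()

    def compute_loss(weights, example, target):
        y = model(weights, example)
        return ((y - target) ** 2).mean()

    weights = torch.randn(feature_size, requires_grad=True)
    examples = torch.randn(batch_size, feature_size)
    targets = torch.randn(batch_size)

    # grad differentiates w.r.t. weights (argnums=0 by default); vmap maps that
    # gradient computation over the leading dimension of examples and targets.
    per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)
    # per_sample_grads has shape (batch_size, feature_size)
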
@@ -1317,6 +1320,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
 
 Example of using ``grad`` with ``has_aux`` and ``argnums``:
 
+>>> # xdoctest: +SKIP
 >>> from torch.func import grad
 >>> def my_loss_func(y, y_pred):
 >>>     loss_per_sample = (0.5 * y_pred - y) ** 2
@@ -1327,13 +1331,14 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
 >>> y_true = torch.rand(4)
 >>> y_preds = torch.rand(4, requires_grad=True)
 >>> out = fn(y_true, y_preds)
->>> > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
+>>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
 
 .. note::
 Using PyTorch ``torch.no_grad`` together with ``grad``.
 
 Case 1: Using ``torch.no_grad`` inside a function:
 
+>>> # xdoctest: +SKIP
 >>> def f(x):
 >>>     with torch.no_grad():
 >>>         c = x ** 2
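
The ``has_aux``/``argnums`` example is split across the two hunks above and its middle lines are elided; the complete pattern it illustrates (reconstructed here as a sketch, so the elided docstring lines may differ in detail) is:

    import torch
    from torch.func import grad

    def my_loss_func(y, y_pred):
        loss_per_sample = (0.5 * y_pred - y) ** 2
        loss = loss_per_sample.mean()
        # with has_aux=True, only the first return value is differentiated;
        # the second is passed through unchanged as auxiliary output
        return loss, (y_pred, loss_per_sample)

    fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
    y_true = torch.rand(4)
    y_preds = torch.rand(4, requires_grad=True)
    out = fn(y_true, y_preds)
    # out is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
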
@@ -1343,6 +1348,7 @@ def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Calla
 
 Case 2: Using ``grad`` inside ``torch.no_grad`` context manager:
 
+>>> # xdoctest: +SKIP
 >>> with torch.no_grad():
 >>>     grad(f)(x)
 
@@ -1433,11 +1439,12 @@ def functionalize(func: Callable, *, remove: str = 'mutations') -> Callable:
 
 Example::
 
+>>> # xdoctest: +SKIP
 >>> import torch
 >>> from torch.fx.experimental.proxy_tensor import make_fx
 >>> from torch.func import functionalize
 >>>
->>> A function that uses mutations and views, but only on intermediate tensors.
+>>> # A function that uses mutations and views, but only on intermediate tensors.
 >>> def f(a):
 ...     b = a + 1
 ...     c = b.view(-1)
@@ -1490,17 +1497,17 @@ def forward(self, a_1):
     return view_copy_1
 
 
->>> A function that mutates its input tensor
+>>> # A function that mutates its input tensor
 >>> def f(a):
 ...     b = a.view(-1)
 ...     b.add_(1)
 ...     return a
 ...
 >>> f_no_mutations_and_views_traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt)
->>>
->>> All mutations and views have been removed,
->>> but there is an extra copy_ in the graph to correctly apply the mutation to the input
->>> after the function has completed.
+>>> #
+>>> # All mutations and views have been removed,
+>>> # but there is an extra copy_ in the graph to correctly apply the mutation to the input
+>>> # after the function has completed.
 >>> print(f_no_mutations_and_views_traced.code)
 
 
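
As a usage note for the ``functionalize`` example above, here is a short sketch of how the functionalized function can be both called directly and traced; the input tensor is an assumption, since the docstring's ``inpt`` is defined outside the hunk shown:

    import torch
    from torch.func import functionalize
    from torch.fx.experimental.proxy_tensor import make_fx

    def f(a):
        b = a.view(-1)
        b.add_(1)   # mutates the input through a view
        return a

    inpt = torch.zeros(2, 2)  # illustrative input; not the docstring's `inpt`

    # Calling the functionalized function directly produces the same values as eager execution.
    assert torch.allclose(f(inpt.clone()), functionalize(f)(inpt.clone()))

    # Tracing it with remove='mutations_and_views' yields a graph without mutation or view ops,
    # apart from the trailing copy_ back into the input that the docstring describes.
    traced = make_fx(functionalize(f, remove='mutations_and_views'))(inpt.clone())
    print(traced.code)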