
Commit 3e6bb52

Michael Carilli authored and facebook-github-bot committed
Reference amp tutorial (recipe) from core amp docs (#44725)
Summary: https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html is live. Core amp docs should reference it. Also I fixed some typos in the `zero_grad` docs we ignored when git was behaving weirdly during ngimel's merge of #44423.

Pull Request resolved: #44725

Reviewed By: mruberry

Differential Revision: D23723807

Pulled By: ngimel

fbshipit-source-id: ca0b76365f8ca908bd978e3b38bf81857fa6c2a3
1 parent a011b86 commit 3e6bb52

File tree

4 files changed, +16 -11 lines changed


docs/source/amp.rst

Lines changed: 2 additions & 1 deletion

@@ -14,7 +14,8 @@ are much faster in ``float16``. Other ops, like reductions, often require the dy
 range of ``float32``. Mixed precision tries to match each op to its appropriate datatype.
 
 Ordinarily, "automatic mixed precision training" uses :class:`torch.cuda.amp.autocast` and
-:class:`torch.cuda.amp.GradScaler` together, as shown in the :ref:`Automatic Mixed Precision examples<amp-examples>`.
+:class:`torch.cuda.amp.GradScaler` together, as shown in the :ref:`Automatic Mixed Precision examples<amp-examples>`
+and `Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
 However, :class:`autocast` and :class:`GradScaler` are modular, and may be used separately if desired.
 
 .. contents:: :local:
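As context for the note above that :class:`autocast` and :class:`GradScaler` are modular and may be used separately, here is a minimal sketch of running autocast on its own (no GradScaler), e.g. for inference. The model and input are hypothetical placeholders, not part of this patch, and a CUDA device is assumed:

import torch

# Hypothetical model and input; any CUDA float32 module and tensor would do.
model = torch.nn.Linear(128, 64).cuda()
data = torch.randn(8, 128, device="cuda")

# autocast alone: eligible ops inside the region run in float16.
# No gradient scaling is involved because no backward pass happens here.
with torch.no_grad():
    with torch.cuda.amp.autocast():
        out = model(data)

print(out.dtype)  # torch.float16 for autocast-eligible ops such as Linear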

docs/source/notes/amp_examples.rst

Lines changed: 4 additions & 0 deletions

@@ -19,6 +19,10 @@ gradients by minimizing gradient underflow, as explained :ref:`here<gradient-sca
 :class:`torch.cuda.amp.autocast` and :class:`torch.cuda.amp.GradScaler` are modular.
 In the samples below, each is used as its individual documentation suggests.
 
+(Samples here are illustrative. See the
+`Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_
+for a runnable walkthrough.)
+
 .. contents:: :local:
 
 Typical Mixed Precision Training
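The hunk above ends at the "Typical Mixed Precision Training" section; a minimal sketch of that typical autocast-plus-GradScaler pattern follows, assuming a CUDA device. The model, optimizer, loss, and data below are placeholder stand-ins, not taken from the patch:

import torch

# Hypothetical model, optimizer, loss, and data stand-ins.
model = torch.nn.Linear(128, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

for _ in range(3):  # a few dummy iterations
    inputs = torch.randn(16, 128, device="cuda")
    targets = torch.randn(16, 1, device="cuda")

    optimizer.zero_grad()
    # Forward pass runs eligible ops in float16 under autocast.
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(inputs), targets)
    # Scale the loss, backprop, then unscale and step via the scaler.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()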

torch/nn/modules/module.py

Lines changed: 3 additions & 3 deletions

@@ -1315,11 +1315,11 @@ def requires_grad_(self: T, requires_grad: bool = True) -> T:
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
-        under `torch.optimizer` for more contexts.
+        under :class:`torch.optim.Optimizer` for more context.
 
         Arguments:
-            set_to_none (bool): instead of setting to zero, set the grad to None.
-                See :meth:`torch.optim.optimizer.zero_grad` for details.
+            set_to_none (bool): instead of setting to zero, set the grads to None.
+                See :meth:`torch.optim.Optimizer.zero_grad` for details.
         """
         if getattr(self, '_is_replica', False):
             warnings.warn(
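To illustrate the ``set_to_none`` argument documented in the hunk above, a small sketch of ``Module.zero_grad`` in both modes; the tiny ``Linear`` module and random inputs are just stand-ins:

import torch

m = torch.nn.Linear(4, 2)

# First pass: populate gradients, then clear them to zero tensors.
m(torch.randn(3, 4)).sum().backward()
m.zero_grad(set_to_none=False)
print(m.weight.grad)  # tensor of zeros

# Second pass: populate gradients, then drop them entirely.
m(torch.randn(3, 4)).sum().backward()
m.zero_grad(set_to_none=True)
print(m.weight.grad)  # None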

torch/optim/optimizer.py

Lines changed: 7 additions & 7 deletions

@@ -165,18 +165,18 @@ def update_group(group, new_group):
         self.__setstate__({'state': state, 'param_groups': param_groups})
 
     def zero_grad(self, set_to_none: bool = False):
-        r"""Set the gradients of all optimized :class:`torch.Tensor` s to zero.
+        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.
 
         Arguments:
-            set_to_none (bool): instead of setting to zero, set the grad to None.
+            set_to_none (bool): instead of setting to zero, set the grads to None.
                 This is will in general have lower memory footprint, and can modestly improve performance.
                 However, it changes certain behaviors. For example:
-                1. When user tries to access the gradient value and perform manual ops on it.
-                A None attribute or a Tensor full of 0s will be different.
-                2. If the user requests `zero_grad(set_to_none=True)` followed by a backward pass, `.grad` s
+                1. When the user tries to access a gradient and perform manual ops on it,
+                a None attribute or a Tensor full of 0s will behave differently.
+                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                 are guaranteed to be None for params that did not receive a gradient.
-                3. `torch.optim` optimizers have a different behavior if the gradient is 0 or None
-                (in one case it does the step with a gradient of 0 and in the other it skip
+                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
+                (in one case it does the step with a gradient of 0 and in the other it skips
                 the step altogether).
         """
         for group in self.param_groups: