
Commit 2fd142a

Michael Carilli authored and facebook-github-bot committed
Small clarification to amp gradient penalty example (#44667)
Summary: Requested by https://discuss.pytorch.org/t/what-is-the-correct-way-of-computing-a-grad-penalty-using-amp/95827/3
Pull Request resolved: #44667
Reviewed By: mruberry
Differential Revision: D23692768
Pulled By: ngimel
fbshipit-source-id: 83c61b94e79ef9f86abed2cc066f188dce0c8456
1 parent aedce77 commit 2fd142a

File tree: 1 file changed (+9, -5 lines)


docs/source/notes/amp_examples.rst

Lines changed: 9 additions & 5 deletions
@@ -169,7 +169,9 @@ Here's an ordinary example of an L2 penalty without gradient scaling or autocast
         loss = loss_fn(output, target)
 
         # Creates gradients
-        grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True)
+        grad_params = torch.autograd.grad(outputs=loss,
+                                          inputs=model.parameters(),
+                                          create_graph=True)
 
         # Computes the penalty term and adds it to the loss
         grad_norm = 0
@@ -184,8 +186,8 @@ Here's an ordinary example of an L2 penalty without gradient scaling or autocast
 
         optimizer.step()
 
-To implement a gradient penalty *with* gradient scaling, the loss passed to
-:func:`torch.autograd.grad` should be scaled. The resulting gradients
+To implement a gradient penalty *with* gradient scaling, the ``outputs`` Tensor(s)
+passed to :func:`torch.autograd.grad` should be scaled. The resulting gradients
 will therefore be scaled, and should be unscaled before being combined to create the
 penalty value.
 
@@ -203,8 +205,10 @@ Here's how that looks for the same L2 penalty::
             output = model(input)
             loss = loss_fn(output, target)
 
-        # Scales the loss for autograd.grad's backward pass, resulting in scaled grad_params
-        scaled_grad_params = torch.autograd.grad(scaler.scale(loss), model.parameters(), create_graph=True)
+        # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
+        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
+                                                 inputs=model.parameters(),
+                                                 create_graph=True)
 
         # Creates unscaled grad_params before computing the penalty. scaled_grad_params are
         # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
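For anyone arriving here from the forum thread, the hunks above show only the lines that changed. Below is a minimal sketch of the full gradient-penalty loop they belong to; it is an illustration, not the exact contents of amp_examples.rst, and the toy model, data, learning rate, and epoch count are invented for the example (it also assumes a CUDA device, since it uses ``torch.cuda.amp``)::

    import torch
    from torch.cuda.amp import autocast, GradScaler

    # Toy setup (hypothetical; the real doc assumes model, optimizer, loss_fn, and data already exist).
    device = "cuda"
    model = torch.nn.Linear(16, 16).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.MSELoss()
    data = [(torch.randn(8, 16, device=device), torch.randn(8, 16, device=device))
            for _ in range(4)]

    scaler = GradScaler()

    for epoch in range(2):
        for input, target in data:
            optimizer.zero_grad()
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)

            # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
            scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                     inputs=model.parameters(),
                                                     create_graph=True)

            # Unscales the gradients before computing the penalty. scaled_grad_params are
            # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_.
            inv_scale = 1. / scaler.get_scale()
            grad_params = [p * inv_scale for p in scaled_grad_params]

            # Computes the penalty term and adds it to the loss.
            with autocast():
                grad_norm = 0
                for grad in grad_params:
                    grad_norm += grad.pow(2).sum()
                grad_norm = grad_norm.sqrt()
                loss = loss + grad_norm

            # Applies scaling to the backward call as usual; accumulates correctly scaled leaf gradients.
            scaler.scale(loss).backward()

            # step() and update() proceed as usual.
            scaler.step(optimizer)
            scaler.update()

The two points this commit clarifies are visible in the middle of the loop: the Tensor passed as ``outputs`` to :func:`torch.autograd.grad` is the scaled loss, and the returned gradients are divided by ``scaler.get_scale()`` before the penalty term is formed.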
