@@ -169,7 +169,9 @@ Here's an ordinary example of an L2 penalty without gradient scaling or autocast
         loss = loss_fn(output, target)
 
         # Creates gradients
-        grad_params = torch.autograd.grad(loss, model.parameters(), create_graph=True)
+        grad_params = torch.autograd.grad(outputs=loss,
+                                          inputs=model.parameters(),
+                                          create_graph=True)
 
         # Computes the penalty term and adds it to the loss
         grad_norm = 0
@@ -184,8 +186,8 @@ Here's an ordinary example of an L2 penalty without gradient scaling or autocast
 
         optimizer.step()
 
-To implement a gradient penalty *with* gradient scaling, the loss passed to
-:func:`torch.autograd.grad` should be scaled. The resulting gradients
+To implement a gradient penalty *with* gradient scaling, the ``outputs`` Tensor(s)
+passed to :func:`torch.autograd.grad` should be scaled. The resulting gradients
 will therefore be scaled, and should be unscaled before being combined to create the
 penalty value.
@@ -203,8 +205,10 @@ Here's how that looks for the same L2 penalty::
             output = model(input)
             loss = loss_fn(output, target)
 
-        # Scales the loss for autograd.grad's backward pass, resulting in scaled grad_params
-        scaled_grad_params = torch.autograd.grad(scaler.scale(loss), model.parameters(), create_graph=True)
+        # Scales the loss for autograd.grad's backward pass, producing scaled_grad_params
+        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
+                                                 inputs=model.parameters(),
+                                                 create_graph=True)
 
         # Creates unscaled grad_params before computing the penalty. scaled_grad_params are
         # not owned by any optimizer, so ordinary division is used instead of scaler.unscale_:
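
For context, here is a minimal end-to-end sketch of the scaled gradient-penalty pattern this diff documents, combining the scaled ``outputs`` passed to :func:`torch.autograd.grad` with the unscale-by-division step described in the second hunk. The toy model, data, and hyperparameters are illustrative placeholders, not part of the patch::

    import torch
    from torch.cuda.amp import GradScaler, autocast

    # Placeholder setup (assumed for illustration; not from the patch)
    model = torch.nn.Linear(16, 16).cuda()
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    data = [(torch.randn(8, 16, device="cuda"), torch.randn(8, 16, device="cuda"))
            for _ in range(10)]

    scaler = GradScaler()

    for input, target in data:
        optimizer.zero_grad()
        with autocast():
            output = model(input)
            loss = loss_fn(output, target)

        # Scales the loss for autograd.grad's backward pass, producing scaled gradients
        scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss),
                                                 inputs=model.parameters(),
                                                 create_graph=True)

        # Unscales by ordinary division: these gradients are not owned by any
        # optimizer, so scaler.unscale_ does not apply here
        inv_scale = 1. / scaler.get_scale()
        grad_params = [p * inv_scale for p in scaled_grad_params]

        # Computes the penalty term under autocast and adds it to the loss
        with autocast():
            grad_norm = 0
            for grad in grad_params:
                grad_norm += grad.pow(2).sum()
            grad_norm = grad_norm.sqrt()
            loss = loss + grad_norm

        # Scaled backward pass and the usual scaler-driven step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

Plain division by ``scaler.get_scale()`` is used because ``scaler.unscale_`` only operates on gradients owned by an optimizer.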