Commit fa4ca4e

mrshenli authored and facebook-github-bot committed
Emphasize all DDP forward() outputs must participate in computing loss (#20586)
Summary: CC borguz chenyangyu1988
Pull Request resolved: #20586
Reviewed By: ezyang
Differential Revision: D15373674
Pulled By: mrshenli
fbshipit-source-id: b986918b3592616a9bcc88fba1b8fd53016f68d7
1 parent c941abb commit fa4ca4e

2 files changed: +19 −9 lines changed

torch/csrc/distributed/c10d/reducer.cpp

Lines changed: 9 additions & 7 deletions
@@ -395,17 +395,19 @@ void Reducer::prepare_for_backward(
       "starting a new one. ",
       "",
       "This error indicates that your module has parameters that were ",
-      "not used in producing its output (the return value of `forward`). ",
+      "not used in producing loss. ",
       "",
-      "You can enable unused parameter detection by passing the keyword "
+      "You can enable unused parameter detection by (1) passing the keyword "
       "argument `find_unused_parameters=True` to ",
-      "`torch.nn.parallel.DistributedDataParallel`. ",
+      "`torch.nn.parallel.DistributedDataParallel`; (2) making sure all ",
+      "`forward` function outputs participate in calculating loss. "
       "",
-      "If you already have this argument set, then the distributed data ",
-      "parallel module wasn't able to locate the output tensors in the ",
+      "If you already have done the above two steps, then the distributed ",
+      "data parallel module wasn't able to locate the output tensors in the ",
       "return value of your module's `forward` function. ",
-      "Please include the structure of the return value of `forward` of ",
-      "your module when reporting this issue (e.g. list, dict, iterable).");
+      "Please include the loss function and the structure of the return ",
+      "value of `forward` of your module when reporting this issue (e.g. ",
+      "list, dict, iterable).");
   }

   // Reset accounting.
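For context, a minimal sketch of the situation this message describes, using a hypothetical TwoHeadNet module: one forward output is derived from module parameters but never contributes to the loss, so those parameters never receive gradients.

    import torch
    import torch.nn as nn

    class TwoHeadNet(nn.Module):
        """Hypothetical module whose second output never feeds the loss."""

        def __init__(self):
            super().__init__()
            self.main_head = nn.Linear(10, 10)
            self.aux_head = nn.Linear(10, 10)

        def forward(self, x):
            # Both outputs are derived from module parameters.
            return self.main_head(x), self.aux_head(x)

    # With the usual process-group and device setup (omitted here):
    #   ddp = nn.parallel.DistributedDataParallel(TwoHeadNet().to(rank), device_ids=[rank])
    #   main_out, aux_out = ddp(torch.randn(8, 10).to(rank))
    #   loss = main_out.sum()   # aux_out never participates in the loss, so
    #   loss.backward()         # `aux_head` never receives gradients and the
    #                           # reducer reports the error above (or hangs).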

torch/nn/parallel/distributed.py

Lines changed: 10 additions & 2 deletions
@@ -197,8 +197,16 @@ class DistributedDataParallel(Module):
                                       module's ``forward`` function.
                                       Parameters that don't receive gradients as
                                       part of this graph are preemptively marked
-                                      as being ready to be reduced.
-                                      (default: ``False``)
+                                      as being ready to be reduced. Note that all
+                                      ``forward`` outputs that are derived from
+                                      module parameters must participate in
+                                      calculating loss and later the gradient
+                                      computation. If they don't, this wrapper will
+                                      hang waiting for autograd to produce gradients
+                                      for those parameters. Any outputs derived from
+                                      module parameters that are otherwise unused can
+                                      be detached from the autograd graph using
+                                      ``torch.Tensor.detach``. (default: ``False``)
         check_reduction: when setting to ``True``, it enables DistributedDataParallel
                          to automatically check if the previous iteration's
                          backward reductions were successfully issued at the
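A minimal sketch of the remedies this note describes, continuing the hypothetical TwoHeadNet from the reducer.cpp section: detach the output that never enters the loss, and enable unused parameter detection when wrapping the module.

    import torch.nn as nn

    class TwoHeadNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.main_head = nn.Linear(10, 10)
            self.aux_head = nn.Linear(10, 10)

        def forward(self, x):
            main_out = self.main_head(x)
            # `aux_out` is returned for logging only and never enters the loss;
            # detaching it (``torch.Tensor.detach``) removes it from the autograd
            # graph, so unused parameter detection can pre-mark `aux_head` as
            # ready instead of waiting for gradients that never arrive.
            aux_out = self.aux_head(x).detach()
            return main_out, aux_out

    # ddp = nn.parallel.DistributedDataParallel(
    #     TwoHeadNet().to(rank), device_ids=[rank], find_unused_parameters=True)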
