
Commit c676100

Author: Andrew Gu
[Easy][FSDP] Update full osd warning
ghstack-source-id: a4989e3
Pull Request resolved: #75109
1 parent: 2bfa018

1 file changed (+4, -2)

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 4 additions & 2 deletions
@@ -2056,7 +2056,8 @@ def full_optim_state_dict(
         contained in ``model`` are mapped back to their unflattened parameters.

         .. warning:: This needs to be called on all ranks since synchronization
-            primitives are used.
+            primitives are used. However, the state dict is only populated on
+            rank 0. All other ranks return an empty :class:`dict`.

         .. warning:: Unlike ``torch.optim.Optimizer.state_dict()``, this method
             uses full parameter names as keys instead of parameter IDs.
@@ -2087,7 +2088,8 @@ def full_optim_state_dict(
             full_osd (Dict[str, Any]): A :class:`dict` containing the optimizer
                 state for ``model`` 's original unflattened parameters and
                 including keys "state" and "param_groups" following the
-                convention of :meth:`torch.optim.Optimizer.state_dict`.
+                convention of :meth:`torch.optim.Optimizer.state_dict` if on
+                rank 0, and an empty :class:`dict` otherwise.
         """
         osd = optim.state_dict()
         osd_state, osd_param_groups = osd["state"], osd["param_groups"]  # alias
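
A minimal usage sketch of the contract the updated docstring describes, assuming the FSDP.full_optim_state_dict(model, optim) call form and an already-initialized torch.distributed process group; the save_full_optim_state helper and the checkpoint path are illustrative, not part of this commit:

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_full_optim_state(model: FSDP, optim: torch.optim.Optimizer, path: str) -> None:
    # Every rank must make this call since collectives are used internally,
    # but only rank 0 receives the populated optimizer state dict.
    full_osd = FSDP.full_optim_state_dict(model, optim)
    if dist.get_rank() == 0:
        # Keys follow torch.optim.Optimizer.state_dict(): "state" and "param_groups".
        torch.save(full_osd, path)
    # On all other ranks, full_osd is an empty dict, so there is nothing to save.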
