106 changes: 83 additions & 23 deletions torch/distributed/fsdp/_optim_utils.py
@@ -1,5 +1,6 @@
import copy
import functools
from dataclasses import dataclass
from typing import (
Any,
cast,
@@ -21,14 +22,27 @@
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp._common_utils import _get_param_to_fqns
from torch.distributed.fsdp._common_utils import (
_apply_to_modules,
_get_param_to_fqns,
_module_handles,
clean_tensor_name,
)
from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init
from torch.distributed.fsdp._shard_utils import _gather_state_dict
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle


@dataclass
class FSDPParamInfo:
# The typing will be changed to FSDPState in the future.
state: nn.Module
flat_param: FlatParameter
param_indices: Dict[str, int]
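
For intuition, a toy sketch of how ``param_indices`` ties each cleaned FQN back to a slot in the owning ``FlatParameter`` (stand-in types and values, not the real FSDP objects):

```python
from dataclasses import dataclass, field
from typing import Dict

@dataclass
class ToyFSDPParamInfo:
    state: str       # stand-in for the owning FSDP module/state
    flat_param: str  # stand-in for the FlatParameter holding the originals
    param_indices: Dict[str, int] = field(default_factory=dict)

info = ToyFSDPParamInfo(state="fsdp_module", flat_param="flat_param_0")
for idx, fqn in enumerate(["layer.weight", "layer.bias"]):
    info.param_indices[fqn] = idx  # index into the flat parameter's originals

assert info.param_indices["layer.bias"] == 1
```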


def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
keys = sorted(dictionary.keys())
for k in keys:
@@ -84,7 +98,7 @@ class _OptimStateKey(NamedTuple):
"""

unflat_param_names: Tuple[str, ...]
is_flat_param: bool
is_fsdp_managed: bool
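
The rename from ``is_flat_param`` to ``is_fsdp_managed`` appears to anticipate ``use_orig_params``, where FSDP-managed parameters stay as original ``nn.Parameter``s rather than ``FlatParameter``s, so an ``isinstance`` check no longer identifies them (see the ``fqn_to_fsdp_param_info`` lookup further down). A minimal sketch of the key's shape (toy values):

```python
from typing import NamedTuple, Tuple

class ToyOptimStateKey(NamedTuple):
    # Mirrors _OptimStateKey after this diff: keyed by unflattened FQNs,
    # which are rank-agnostic, unlike rank-local parameter IDs.
    unflat_param_names: Tuple[str, ...]
    is_fsdp_managed: bool

key = ToyOptimStateKey(("block.0.weight", "block.0.bias"), True)
assert key.is_fsdp_managed
```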


def _unflatten_optim_state(
@@ -293,23 +307,21 @@ def _flatten_optim_state_dict(
'`optim_state_dict` must have the keys "state" and '
'"param_groups" to be a valid optimizer state dict'
)
flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model)
param_to_fqns = _get_param_to_fqns(model)
fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)

# Construct the "state" part
flat_osd_state: Dict[_OptimStateKey, Any] = {}
unflat_osd_state = unflat_osd["state"]
for param, unflat_param_names in param_to_fqns.items():
if isinstance(param, FlatParameter): # flatten FSDP parameters' states
assert (
param in flat_param_to_fsdp_module
), f"Check the `flat_param_to_fsdp_module` construction\nparam: {param}"
fsdp_module = flat_param_to_fsdp_module[param]
fqn = unflat_param_names[0]
if fqn in fqn_to_fsdp_param_info:
fsdp_param_info = fqn_to_fsdp_param_info[fqn]
flat_state = _flatten_optim_state(
unflat_osd_state,
unflat_param_names,
fsdp_module,
param,
fsdp_param_info.state,
fsdp_param_info.flat_param,
shard_state,
)
key = _OptimStateKey(tuple(unflat_param_names), True)
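
A toy sketch of what this flattening step produces (assumed semantics; the real code concatenates tensor state in flat-parameter order rather than Python lists):

```python
# Per-FQN optimizer state for the originals of one flat parameter ...
unflat_osd_state = {
    "layer.weight": {"step": 10, "exp_avg": [0.1, 0.2]},
    "layer.bias": {"step": 10, "exp_avg": [0.3]},
}
unflat_param_names = ("layer.weight", "layer.bias")

# ... collapses into a single entry keyed by the _OptimStateKey.
flat_state = {
    "step": 10,                  # zero-dim state must agree across the FQNs
    "exp_avg": [0.1, 0.2, 0.3],  # positive-dim state, concatenated in order
}
flat_osd_state = {(unflat_param_names, True): flat_state}
```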
@@ -670,7 +682,7 @@ def _process_pos_dim_tensor_state(
if not is_pos_dim_tensor_state:
no_tensor_osd["state"][key][state_name] = value
continue
if key.is_flat_param: # FSDP parameter
if key.is_fsdp_managed: # FSDP parameter
sharded_size = FlatParamHandle._get_sharded_size(
value, rank=0, world_size=world_size
)
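
A rough analogue of the sharded-size computation (assumption: dim-0 chunking as in ``FlatParamHandle``, ignoring padding details):

```python
import torch

value = torch.arange(10.0)  # a positive-dim tensor state, e.g. exp_avg
world_size = 4
rank0_shard = value.chunk(world_size)[0]
print(rank0_shard.shape)    # torch.Size([3]) -- rank 0's sharded size
```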
@@ -753,7 +765,7 @@ def _broadcast_pos_dim_tensor_states(
else:
unsharded_tensor = None
shape, dtype = value.shape, value.dtype
if key.is_flat_param: # FSDP parameter
if key.is_fsdp_managed: # FSDP parameter
_broadcast_sharded_pos_dim_tensor_state(
unsharded_tensor,
param_state,
@@ -1079,6 +1091,7 @@ def _map_param_id_to_optim_keys(
group: Optional[dist.ProcessGroup],
param_id_to_param: List[nn.Parameter],
param_to_fqns: Dict[nn.Parameter, List[str]],
fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
) -> Tuple[Dict[int, _OptimStateKey], Dict[_OptimStateKey, int]]:
"""
Construct the local mapping between the `_OptimStateKey` and parameter IDs
@@ -1087,18 +1100,21 @@
"""
rank = dist.get_rank(group)
optim_state_key_to_param_id: Dict[_OptimStateKey, int] = {} # local
r0_param_id_to_optim_state_key: Dict[
int, _OptimStateKey
] = {} # rank 0
r0_param_id_to_optim_state_key: Dict[int, _OptimStateKey] = {} # rank 0

for param_id, param in enumerate(param_id_to_param):
# Do not include parameters without state to avoid empty mappings
# just like in normal `torch.optim.Optimizer.state_dict()`
if param_id not in optim_state_dict["state"]:
continue
fqns = param_to_fqns[param]
is_fsdp_managed = isinstance(param, FlatParameter)
if is_fsdp_managed:
assert fqns[0] in fqn_to_fsdp_param_info
is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
optim_state_key = _OptimStateKey(
unflat_param_names=tuple(param_to_fqns[param]),
is_flat_param=isinstance(param, FlatParameter),
unflat_param_names=tuple(fqns),
is_fsdp_managed=is_fsdp_managed,
)
if rank == 0:
r0_param_id_to_optim_state_key[param_id] = optim_state_key
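
A toy view of the two mappings this loop builds (plain tuples stand in for ``_OptimStateKey``):

```python
# Parameter ID 1 has no optimizer state, so it is skipped, matching
# torch.optim.Optimizer.state_dict() behavior.
optim_state = {0: {"step": 10}, 2: {"step": 10}}
param_id_to_fqns = {0: ("a.weight",), 1: ("a.bias",), 2: ("b.weight",)}

optim_state_key_to_param_id = {}     # built on every rank
r0_param_id_to_optim_state_key = {}  # built on rank 0 only
for param_id, fqns in param_id_to_fqns.items():
    if param_id not in optim_state:
        continue
    optim_state_key_to_param_id[fqns] = param_id
    r0_param_id_to_optim_state_key[param_id] = fqns

assert 1 not in r0_param_id_to_optim_state_key
```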
@@ -1220,6 +1236,7 @@ def _optim_state_dict(
if using_optim_input
else _get_param_id_to_param(optim)
)
fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)

(
param_id_to_optim_state_key,
@@ -1229,20 +1246,23 @@
group,
param_id_to_param,
param_to_fqns,
fqn_to_fsdp_param_info,
)
flat_param_to_fsdp_state = _get_flat_param_to_fsdp_module(model)

# Iterate in rank 0's flattened parameter ID order to ensure aligned
# all-gathers across ranks
for optim_state_key in param_id_to_optim_state_key.values():
param_id = optim_state_key_to_param_id[optim_state_key]
if optim_state_key.is_flat_param:
param = param_id_to_param[param_id]
fsdp_state = flat_param_to_fsdp_state[param]
if optim_state_key.is_fsdp_managed:
# If a parameter maps to multiple unflat_param_names (the non-
# use_orig_params case), they all share the same FSDPParamInfo,
# so the first unflat_param_name suffices to fetch it.
fqn = optim_state_key.unflat_param_names[0]
fsdp_param_info = fqn_to_fsdp_param_info[fqn]
unflat_state = _unflatten_optim_state(
cast(FlatParameter, param),
fsdp_param_info.flat_param,
optim_state_dict["state"][param_id],
fsdp_state,
fsdp_param_info.state,
to_save,
shard_state,
)
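
The inverse of the flattening sketch above (assumed semantics): slice the flat state back into per-FQN entries using each original parameter's numel:

```python
flat_exp_avg = [0.1, 0.2, 0.3]
numels = {"layer.weight": 2, "layer.bias": 1}  # flat-parameter order

offset, unflat_state = 0, {}
for fqn, numel in numels.items():
    unflat_state[fqn] = {"step": 10, "exp_avg": flat_exp_avg[offset:offset + numel]}
    offset += numel

assert unflat_state["layer.bias"]["exp_avg"] == [0.3]
```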
@@ -1269,3 +1289,43 @@
)

return fsdp_osd


def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
"""
Construct the mapping from each parameter's FQN to its corresponding
``FSDPParamInfo`` if the parameter is managed by FSDP. Because
``FlatParameter._fqns`` only stores the first FQN of a shared parameter,
the keys of the mapping are guaranteed to map to unique parameters.
"""

def module_fn(module, prefix, fqn_to_param_info):

Collaborator: Should we add a comment saying we need to use _apply_to_modules to get the global FQN (since the saved FQNs are like local FQNs, not necessarily prefixed from the global root module)?

# TODO: make it work with composable API.
if not isinstance(module, fsdp_file.FullyShardedDataParallel):
return
_lazy_init(module, module)
handles = _module_handles(module, module)
if not handles:
return
flat_param = handles[0].flat_param
fsdp_param_info = FSDPParamInfo(module, flat_param, {})
for idx, local_fqn in enumerate(flat_param._fqns):
fqn = clean_tensor_name(prefix + local_fqn)
if fqn in fqn_to_param_info:
assert fqn_to_param_info[fqn].flat_param is flat_param  # identity check; `==` would compare elementwise
fqn_to_param_info[fqn] = fsdp_param_info
fsdp_param_info.param_indices[fqn] = idx

def return_fn(fqn_to_param_info):
return fqn_to_param_info

fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
# FlatParameter._fqns stores FQNs that are local to the owning FSDP
# instance, not necessarily prefixed from the global root module. Walking
# from `model` (which may not be the FSDP root) with _apply_to_modules()
# supplies the prefix needed to construct the global FQN.
return _apply_to_modules(
model,
module_fn,
return_fn,
fqn_to_param_info,
)
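
A minimal sketch of the global-FQN prefixing that ``_apply_to_modules`` provides here, using a plain module tree (``global_param_fqns`` is a hypothetical helper; no FSDP or distributed init involved):

```python
import torch.nn as nn

def global_param_fqns(model: nn.Module) -> dict:
    """Map each parameter's global FQN to its owning module."""
    out = {}

    def recurse(module: nn.Module, prefix: str) -> None:
        for name, _ in module.named_parameters(recurse=False):
            out[prefix + name] = module  # the prefix globalizes the local name
        for child_name, child in module.named_children():
            recurse(child, f"{prefix}{child_name}.")

    recurse(model, "")
    return out

model = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
assert "1.bias" in global_param_fqns(model)
```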