
Commit 44779d9

fegin authored and pytorchmergebot committed
[FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_info to map from original FQN to flat_param (#89899)
**Motivation:** Add a helper to map from the FQN to the corresponding flat_param. The helper gets the flat_param directly from the FSDP state and the flat-parameter handle, since the flat_param is not registered to the module when `use_orig_params` is True.

Pull Request resolved: #89899
Approved by: https://github.com/awgu
1 parent f7cdd3a commit 44779d9
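
The sketch below is a minimal, hypothetical usage example of the new helper and is not part of this commit: the toy model, the single-rank gloo process group, and calling the private `_get_fqn_to_fsdp_param_info` outside of FSDP internals are assumptions for illustration, while the helper itself and the `FSDPParamInfo` fields come from the diff below.

```python
# Hypothetical usage sketch (assumptions: CPU-only toy model, single-rank gloo
# process group, direct use of the private helper added in this commit).
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._optim_utils import _get_fqn_to_fsdp_param_info

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)), use_orig_params=True)
fqn_to_info = _get_fqn_to_fsdp_param_info(model)

for fqn, info in fqn_to_info.items():
    # Each original FQN maps to the FSDPParamInfo of its owning flat parameter;
    # param_indices records this FQN's position inside that FlatParameter.
    print(fqn, info.param_indices[fqn], type(info.flat_param).__name__)

dist.destroy_process_group()
```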

File tree

1 file changed: +83 −23 lines


torch/distributed/fsdp/_optim_utils.py

Lines changed: 83 additions & 23 deletions
@@ -1,5 +1,6 @@
 import copy
 import functools
+from dataclasses import dataclass
 from typing import (
     Any,
     cast,
@@ -21,14 +22,27 @@
 import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed.fsdp._common_utils import _get_param_to_fqns
+from torch.distributed.fsdp._common_utils import (
+    _apply_to_modules,
+    _get_param_to_fqns,
+    _module_handles,
+    clean_tensor_name,
+)
 from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
 from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init
 from torch.distributed.fsdp._shard_utils import _gather_state_dict
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle


+@dataclass
+class FSDPParamInfo:
+    # The typing will be changed to FSDPState in the future.
+    state: nn.Module
+    flat_param: FlatParameter
+    param_indices: Dict[str, int]
+
+
 def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
     keys = sorted(dictionary.keys())
     for k in keys:
@@ -84,7 +98,7 @@ class _OptimStateKey(NamedTuple):
     """

     unflat_param_names: Tuple[str, ...]
-    is_flat_param: bool
+    is_fsdp_managed: bool


 def _unflatten_optim_state(
@@ -293,23 +307,21 @@ def _flatten_optim_state_dict(
             '`optim_state_dict` must have the keys "state" and '
             '"param_groups" to be a valid optimizer state dict'
         )
-    flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model)
     param_to_fqns = _get_param_to_fqns(model)
+    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)

     # Construct the "state" part
     flat_osd_state: Dict[_OptimStateKey, Any] = {}
     unflat_osd_state = unflat_osd["state"]
     for param, unflat_param_names in param_to_fqns.items():
-        if isinstance(param, FlatParameter):  # flatten FSDP parameters' states
-            assert (
-                param in flat_param_to_fsdp_module
-            ), f"Check the `flat_param_to_fsdp_module` construction\nparam: {param}"
-            fsdp_module = flat_param_to_fsdp_module[param]
+        fqn = unflat_param_names[0]
+        if fqn in fqn_to_fsdp_param_info:
+            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
             flat_state = _flatten_optim_state(
                 unflat_osd_state,
                 unflat_param_names,
-                fsdp_module,
-                param,
+                fsdp_param_info.state,
+                fsdp_param_info.flat_param,
                 shard_state,
             )
             key = _OptimStateKey(tuple(unflat_param_names), True)
@@ -670,7 +682,7 @@ def _process_pos_dim_tensor_state(
             if not is_pos_dim_tensor_state:
                 no_tensor_osd["state"][key][state_name] = value
                 continue
-            if key.is_flat_param:  # FSDP parameter
+            if key.is_fsdp_managed:  # FSDP parameter
                 sharded_size = FlatParamHandle._get_sharded_size(
                     value, rank=0, world_size=world_size
                 )
@@ -753,7 +765,7 @@ def _broadcast_pos_dim_tensor_states(
             else:
                 unsharded_tensor = None
             shape, dtype = value.shape, value.dtype
-            if key.is_flat_param:  # FSDP parameter
+            if key.is_fsdp_managed:  # FSDP parameter
                 _broadcast_sharded_pos_dim_tensor_state(
                     unsharded_tensor,
                     param_state,
@@ -1079,6 +1091,7 @@ def _map_param_id_to_optim_keys(
     group: Optional[dist.ProcessGroup],
     param_id_to_param: List[nn.Parameter],
     param_to_fqns: Dict[nn.Parameter, List[str]],
+    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
 ) -> Tuple[Dict[int, _OptimStateKey], Dict[_OptimStateKey, int]]:
     """
     Construct the local mapping between the `_OptimStateKey` and parameter IDs
@@ -1087,18 +1100,21 @@ def _map_param_id_to_optim_keys(
     """
     rank = dist.get_rank(group)
     optim_state_key_to_param_id: Dict[_OptimStateKey, int] = {}  # local
-    r0_param_id_to_optim_state_key: Dict[
-        int, _OptimStateKey
-    ] = {}  # rank 0
+    r0_param_id_to_optim_state_key: Dict[int, _OptimStateKey] = {}  # rank 0

     for param_id, param in enumerate(param_id_to_param):
         # Do not include parameters without state to avoid empty mappings
         # just like in normal `torch.optim.Optimizer.state_dict()`
         if param_id not in optim_state_dict["state"]:
             continue
+        fqns = param_to_fqns[param]
+        is_fsdp_managed = isinstance(param, FlatParameter)
+        if is_fsdp_managed:
+            assert fqns[0] in fqn_to_fsdp_param_info
+        is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
         optim_state_key = _OptimStateKey(
-            unflat_param_names=tuple(param_to_fqns[param]),
-            is_flat_param=isinstance(param, FlatParameter),
+            unflat_param_names=tuple(fqns),
+            is_fsdp_managed=is_fsdp_managed,
         )
         if rank == 0:
             r0_param_id_to_optim_state_key[param_id] = optim_state_key
@@ -1220,6 +1236,7 @@ def _optim_state_dict(
         if using_optim_input
        else _get_param_id_to_param(optim)
     )
+    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)

     (
         param_id_to_optim_state_key,
@@ -1229,20 +1246,23 @@ def _optim_state_dict(
         group,
         param_id_to_param,
         param_to_fqns,
+        fqn_to_fsdp_param_info,
     )
-    flat_param_to_fsdp_state = _get_flat_param_to_fsdp_module(model)

     # Iterate in rank 0's flattened parameter ID order to ensure aligned
     # all-gathers across ranks
     for optim_state_key in param_id_to_optim_state_key.values():
         param_id = optim_state_key_to_param_id[optim_state_key]
-        if optim_state_key.is_flat_param:
-            param = param_id_to_param[param_id]
-            fsdp_state = flat_param_to_fsdp_state[param]
+        if optim_state_key.is_fsdp_managed:
+            # If there are multiple unflat_param_names (not use_orig_params),
+            # they share the same FSDPParamInfo. So the first unflat_param_name
+            # is sufficient to fetch the FSDPParamInfo.
+            fqn = optim_state_key.unflat_param_names[0]
+            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
             unflat_state = _unflatten_optim_state(
-                cast(FlatParameter, param),
+                fsdp_param_info.flat_param,
                 optim_state_dict["state"][param_id],
-                fsdp_state,
+                fsdp_param_info.state,
                 to_save,
                 shard_state,
             )
@@ -1269,3 +1289,43 @@ def _optim_state_dict(
         )

     return fsdp_osd
+
+
+def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
+    """
+    Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo``
+    if the param is managed by FSDP. ``FlatParameter._fqns`` only stores the first
+    FQN of a shared parameter. So the keys in the mapping are guaranteed to map
+    to unique parameters.
+    """
+
+    def module_fn(module, prefix, fqn_to_param_info):
+        # TODO: make it work with composable API.
+        if not isinstance(module, fsdp_file.FullyShardedDataParallel):
+            return
+        _lazy_init(module, module)
+        handles = _module_handles(module, module)
+        if not handles:
+            return
+        flat_param = handles[0].flat_param
+        fsdp_param_info = FSDPParamInfo(module, flat_param, {})
+        for idx, local_fqn in enumerate(flat_param._fqns):
+            fqn = clean_tensor_name(prefix + local_fqn)
+            if fqn in fqn_to_param_info:
+                assert fqn_to_param_info[fqn].flat_param == flat_param
+            fqn_to_param_info[fqn] = fsdp_param_info
+            fsdp_param_info.param_indices[fqn] = idx
+
+    def return_fn(fqn_to_param_info):
+        return fqn_to_param_info
+
+    fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
+    # FlatParameter._fqns stores the local fqn, starting from the root of the
+    # FSDP. Using _apply_to_modules() with model (may not be the FSDP root
+    # module) allows us to construct the global fqn.
+    return _apply_to_modules(
+        model,
+        module_fn,
+        return_fn,
+        fqn_to_param_info,
+    )
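
For readers skimming the new `_get_fqn_to_fsdp_param_info` helper above, here is a small standalone sketch of the invariant its `module_fn` loop establishes. The `ToyFSDPParamInfo` class and `build_fqn_map` function are hypothetical stand-ins for illustration, not FSDP's real classes: every FQN owned by one flat parameter shares a single info object, and `param_indices` records each FQN's index within it.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

# Toy stand-ins (assumptions for illustration; not FSDP's FlatParameter/FSDPParamInfo).
@dataclass
class ToyFSDPParamInfo:
    state: Any
    flat_param: Any
    param_indices: Dict[str, int] = field(default_factory=dict)

def build_fqn_map(fqns_per_flat_param: List[Tuple[Any, List[str]]]) -> Dict[str, ToyFSDPParamInfo]:
    # Mirrors module_fn's loop: one shared info object per flat parameter,
    # keyed by every FQN it covers, with each FQN's index recorded.
    fqn_to_param_info: Dict[str, ToyFSDPParamInfo] = {}
    for flat_param, fqns in fqns_per_flat_param:
        info = ToyFSDPParamInfo(state=None, flat_param=flat_param)
        for idx, fqn in enumerate(fqns):
            fqn_to_param_info[fqn] = info
            info.param_indices[fqn] = idx
    return fqn_to_param_info

fp0, fp1 = object(), object()
mapping = build_fqn_map([(fp0, ["0.weight", "0.bias"]), (fp1, ["1.weight", "1.bias"])])
assert mapping["0.weight"] is mapping["0.bias"]                     # shared info per flat param
assert mapping["0.bias"].param_indices == {"0.weight": 0, "0.bias": 1}
assert mapping["1.weight"].flat_param is fp1
```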
