 import copy
 import functools
+from dataclasses import dataclass
 from typing import (
     Any,
     cast,
 import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed.fsdp._common_utils import _get_param_to_fqns
+from torch.distributed.fsdp._common_utils import (
+    _apply_to_modules,
+    _get_param_to_fqns,
+    _module_handles,
+    clean_tensor_name,
+)
 from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
 from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init
 from torch.distributed.fsdp._shard_utils import _gather_state_dict
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle


+@dataclass
+class FSDPParamInfo:
+    # The typing will be changed to FSDPState in the future.
+    state: nn.Module
+    flat_param: FlatParameter
+    param_indices: Dict[str, int]
+
+
 def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
     keys = sorted(dictionary.keys())
     for k in keys:
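For orientation, here is a minimal sketch, not part of the commit, of what one `FSDPParamInfo` groups together once the mapping is built by `_get_fqn_to_fsdp_param_info()` further down; the module and parameter names are hypothetical.

```python
# Hypothetical illustration only. Suppose one FSDP instance flattened
# "lin.weight" and "lin.bias" into a single FlatParameter. The FQN-keyed
# mapping would then hold one shared FSDPParamInfo for both parameters:
#
#   info = FSDPParamInfo(
#       state=fsdp_module,        # the owning FSDP module (FSDPState later)
#       flat_param=flat_param,    # the FlatParameter covering both params
#       param_indices={"lin.weight": 0, "lin.bias": 1},
#   )
#   fqn_to_fsdp_param_info = {"lin.weight": info, "lin.bias": info}
```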
@@ -84,7 +98,7 @@ class _OptimStateKey(NamedTuple):
     """

     unflat_param_names: Tuple[str, ...]
-    is_flat_param: bool
+    is_fsdp_managed: bool


 def _unflatten_optim_state(
@@ -293,23 +307,21 @@ def _flatten_optim_state_dict(
             '`optim_state_dict` must have the keys "state" and '
             '"param_groups" to be a valid optimizer state dict'
         )
-    flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model)
     param_to_fqns = _get_param_to_fqns(model)
+    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)

     # Construct the "state" part
     flat_osd_state: Dict[_OptimStateKey, Any] = {}
     unflat_osd_state = unflat_osd["state"]
     for param, unflat_param_names in param_to_fqns.items():
-        if isinstance(param, FlatParameter):  # flatten FSDP parameters' states
-            assert (
-                param in flat_param_to_fsdp_module
-            ), f"Check the `flat_param_to_fsdp_module` construction\nparam: {param}"
-            fsdp_module = flat_param_to_fsdp_module[param]
+        fqn = unflat_param_names[0]
+        if fqn in fqn_to_fsdp_param_info:
+            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
             flat_state = _flatten_optim_state(
                 unflat_osd_state,
                 unflat_param_names,
-                fsdp_module,
-                param,
+                fsdp_param_info.state,
+                fsdp_param_info.flat_param,
                 shard_state,
             )
             key = _OptimStateKey(tuple(unflat_param_names), True)
@@ -670,7 +682,7 @@ def _process_pos_dim_tensor_state(
             if not is_pos_dim_tensor_state:
                 no_tensor_osd["state"][key][state_name] = value
                 continue
-            if key.is_flat_param:  # FSDP parameter
+            if key.is_fsdp_managed:  # FSDP parameter
                 sharded_size = FlatParamHandle._get_sharded_size(
                     value, rank=0, world_size=world_size
                 )
@@ -753,7 +765,7 @@ def _broadcast_pos_dim_tensor_states(
             else:
                 unsharded_tensor = None
             shape, dtype = value.shape, value.dtype
-            if key.is_flat_param:  # FSDP parameter
+            if key.is_fsdp_managed:  # FSDP parameter
                 _broadcast_sharded_pos_dim_tensor_state(
                     unsharded_tensor,
                     param_state,
@@ -1079,6 +1091,7 @@ def _map_param_id_to_optim_keys(
     group: Optional[dist.ProcessGroup],
     param_id_to_param: List[nn.Parameter],
     param_to_fqns: Dict[nn.Parameter, List[str]],
+    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
 ) -> Tuple[Dict[int, _OptimStateKey], Dict[_OptimStateKey, int]]:
     """
     Construct the local mapping between the `_OptimStateKey` and parameter IDs
@@ -1087,18 +1100,21 @@
     """
     rank = dist.get_rank(group)
     optim_state_key_to_param_id: Dict[_OptimStateKey, int] = {}  # local
-    r0_param_id_to_optim_state_key: Dict[
-        int, _OptimStateKey
-    ] = {}  # rank 0
+    r0_param_id_to_optim_state_key: Dict[int, _OptimStateKey] = {}  # rank 0

     for param_id, param in enumerate(param_id_to_param):
         # Do not include parameters without state to avoid empty mappings
         # just like in normal `torch.optim.Optimizer.state_dict()`
         if param_id not in optim_state_dict["state"]:
             continue
+        fqns = param_to_fqns[param]
+        is_fsdp_managed = isinstance(param, FlatParameter)
+        if is_fsdp_managed:
+            assert fqns[0] in fqn_to_fsdp_param_info
+        is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
         optim_state_key = _OptimStateKey(
-            unflat_param_names=tuple(param_to_fqns[param]),
-            is_flat_param=isinstance(param, FlatParameter),
+            unflat_param_names=tuple(fqns),
+            is_fsdp_managed=is_fsdp_managed,
         )
         if rank == 0:
             r0_param_id_to_optim_state_key[param_id] = optim_state_key
@@ -1220,6 +1236,7 @@ def _optim_state_dict(
         if using_optim_input
         else _get_param_id_to_param(optim)
     )
+    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)

     (
         param_id_to_optim_state_key,
@@ -1229,20 +1246,23 @@
         group,
         param_id_to_param,
         param_to_fqns,
+        fqn_to_fsdp_param_info,
     )
-    flat_param_to_fsdp_state = _get_flat_param_to_fsdp_module(model)

     # Iterate in rank 0's flattened parameter ID order to ensure aligned
     # all-gathers across ranks
     for optim_state_key in param_id_to_optim_state_key.values():
         param_id = optim_state_key_to_param_id[optim_state_key]
-        if optim_state_key.is_flat_param:
-            param = param_id_to_param[param_id]
-            fsdp_state = flat_param_to_fsdp_state[param]
+        if optim_state_key.is_fsdp_managed:
+            # If there are multiple unflat_param_names (not use_orig_params),
+            # they share the same FSDPParamInfo. So the first unflat_param_name
+            # is sufficient to fetch the FSDPParamInfo.
+            fqn = optim_state_key.unflat_param_names[0]
+            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
             unflat_state = _unflatten_optim_state(
-                cast(FlatParameter, param),
+                fsdp_param_info.flat_param,
                 optim_state_dict["state"][param_id],
-                fsdp_state,
+                fsdp_param_info.state,
                 to_save,
                 shard_state,
             )
@@ -1269,3 +1289,43 @@ def _optim_state_dict(
         )

     return fsdp_osd
+
+
+def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
+    """
+    Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo``
+    if the param is managed by FSDP. ``FlatParameter._fqns`` only stores the first
+    FQN of a shared parameter. So the keys in the mapping are guaranteed to map
+    to unique parameters.
+    """
+
+    def module_fn(module, prefix, fqn_to_param_info):
+        # TODO: make it work with composable API.
+        if not isinstance(module, fsdp_file.FullyShardedDataParallel):
+            return
+        _lazy_init(module, module)
+        handles = _module_handles(module, module)
+        if not handles:
+            return
+        flat_param = handles[0].flat_param
+        fsdp_param_info = FSDPParamInfo(module, flat_param, {})
+        for idx, local_fqn in enumerate(flat_param._fqns):
+            fqn = clean_tensor_name(prefix + local_fqn)
+            if fqn in fqn_to_param_info:
+                assert fqn_to_param_info[fqn].flat_param == flat_param
+            fqn_to_param_info[fqn] = fsdp_param_info
+            fsdp_param_info.param_indices[fqn] = idx
+
+    def return_fn(fqn_to_param_info):
+        return fqn_to_param_info
+
+    fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
+    # FlatParameter._fqns stores the local fqn, starting from the root of the
+    # FSDP. Using _apply_to_modules() with model (may not be the FSDP root
+    # module) allows us to construct the global fqn.
+    return _apply_to_modules(
+        model,
+        module_fn,
+        return_fn,
+        fqn_to_param_info,
+    )
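For completeness, a rough usage sketch of the new helper. This is not part of the commit: it assumes a process group has already been initialized, that FSDP can wrap the toy model on the current device, and that the private helper above is importable; the function name `_inspect_fsdp_param_infos` and the model layout are made up for illustration.

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def _inspect_fsdp_param_infos(model: nn.Module) -> None:
    # Build the FQN -> FSDPParamInfo mapping introduced in this commit.
    fqn_to_info = _get_fqn_to_fsdp_param_info(model)
    for fqn, info in fqn_to_info.items():
        # Each FSDP-managed FQN maps to the owning FSDP state, the
        # FlatParameter it was flattened into, and its index within it.
        print(fqn, info.param_indices[fqn], info.flat_param.numel())

# e.g., after dist.init_process_group(...) on each rank:
#   _inspect_fsdp_param_infos(FSDP(nn.Linear(8, 8).cuda()))
```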