[FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_info to map from original FQN to flat_param #89899
Closed
Changes from all commits (9 commits)
- 7d3dc02 — [FSDP][optim_state_dict][2/N] Add a helper to map from original FQN t… (fegin)
- 793e02e — Update on "[FSDP][optim_state_dict][2/N] Add a helper to map from ori… (fegin)
- 9926172 — Update on "[FSDP][optim_state_dict][2/N] Add a helper to map from ori… (fegin)
- 87e89c5 — Update on "[FSDP][optim_state_dict][2/N] Add a helper to map from ori… (fegin)
- 11ff9d0 — Update on "[FSDP][optim_state_dict][2/N] Add a helper to map from ori… (fegin)
- 3f67a55 — Update on "[FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_i… (fegin)
- 23688d1 — Update on "[FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_i… (fegin)
- ec6526f — Update on "[FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_i… (fegin)
- 97ff650 — Update on "[FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_i… (fegin)
Diff view
```diff
@@ -1,5 +1,6 @@
 import copy
 import functools
+from dataclasses import dataclass
 from typing import (
     Any,
     cast,
```
```diff
@@ -21,14 +22,27 @@
 import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
 import torch.nn as nn
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-from torch.distributed.fsdp._common_utils import _get_param_to_fqns
+from torch.distributed.fsdp._common_utils import (
+    _apply_to_modules,
+    _get_param_to_fqns,
+    _module_handles,
+    clean_tensor_name,
+)
 from torch.distributed.fsdp._fsdp_extensions import _ext_chunk_tensor
 from torch.distributed.fsdp._runtime_utils import _clear_grads_if_needed, _lazy_init
 from torch.distributed.fsdp._shard_utils import _gather_state_dict
 from torch.distributed.fsdp.api import ShardingStrategy
 from torch.distributed.fsdp.flat_param import FlatParameter, FlatParamHandle
 
 
+@dataclass
+class FSDPParamInfo:
+    # The typing will be changed to FSDPState in the future.
+    state: nn.Module
+    flat_param: FlatParameter
+    param_indices: Dict[str, int]
+
+
 def sorted_items(dictionary: Dict[str, Any]) -> Iterator[Tuple[str, Any]]:
     keys = sorted(dictionary.keys())
     for k in keys:
```
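For reference, the new `FSDPParamInfo` container pairs an FSDP state/module with its `FlatParameter` and records each original FQN's index inside that flat parameter. The following is a minimal, hypothetical sketch of that mapping pattern using toy stand-in types (not the real FSDP classes): several FQNs that were flattened together share one info object.

```python
from dataclasses import dataclass, field
from typing import Dict, List

# Toy stand-ins for FlatParameter and FSDPParamInfo; only the mapping
# pattern introduced in this PR is illustrated.
@dataclass
class ToyFlatParam:
    fqns: List[str]  # original (unflattened) parameter names

@dataclass
class ToyParamInfo:
    owner: str  # stands in for the owning FSDP state/module
    flat_param: ToyFlatParam
    param_indices: Dict[str, int] = field(default_factory=dict)

def build_fqn_map(owner: str, flat_param: ToyFlatParam) -> Dict[str, ToyParamInfo]:
    # Every FQN flattened into this flat parameter shares one info object,
    # and param_indices remembers each FQN's position within it.
    info = ToyParamInfo(owner, flat_param)
    mapping: Dict[str, ToyParamInfo] = {}
    for idx, fqn in enumerate(flat_param.fqns):
        mapping[fqn] = info
        info.param_indices[fqn] = idx
    return mapping

flat = ToyFlatParam(fqns=["layer.weight", "layer.bias"])
fqn_map = build_fqn_map("fsdp_root", flat)
assert fqn_map["layer.weight"] is fqn_map["layer.bias"]  # shared info object
assert fqn_map["layer.bias"].param_indices["layer.bias"] == 1
```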
```diff
@@ -84,7 +98,7 @@ class _OptimStateKey(NamedTuple):
     """
 
     unflat_param_names: Tuple[str, ...]
-    is_flat_param: bool
+    is_fsdp_managed: bool
 
 
 def _unflatten_optim_state(
```
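As context for the renamed field: optimizer state entries are keyed by unflattened parameter names rather than by local parameter IDs, so ranks that happen to number their parameters differently still agree on which entry is which. A simplified sketch of that idea with a toy key type (not the actual `_OptimStateKey`):

```python
from typing import Any, Dict, NamedTuple, Tuple

# Simplified stand-in for _OptimStateKey; field names follow the diff.
class ToyOptimStateKey(NamedTuple):
    unflat_param_names: Tuple[str, ...]
    is_fsdp_managed: bool

# Two hypothetical ranks store the same state under different local param IDs.
rank0_state: Dict[int, Any] = {0: {"step": 10}}
rank1_state: Dict[int, Any] = {3: {"step": 10}}

# Keying by FQNs makes the entries line up regardless of local IDs.
rank0_keys = {ToyOptimStateKey(("layer.weight",), True): 0}
rank1_keys = {ToyOptimStateKey(("layer.weight",), True): 3}

key = ToyOptimStateKey(("layer.weight",), True)
assert rank0_state[rank0_keys[key]] == rank1_state[rank1_keys[key]]
```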
```diff
@@ -293,23 +307,21 @@ def _flatten_optim_state_dict(
             '`optim_state_dict` must have the keys "state" and '
             '"param_groups" to be a valid optimizer state dict'
         )
-    flat_param_to_fsdp_module = _get_flat_param_to_fsdp_module(model)
     param_to_fqns = _get_param_to_fqns(model)
+    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
 
     # Construct the "state" part
     flat_osd_state: Dict[_OptimStateKey, Any] = {}
     unflat_osd_state = unflat_osd["state"]
     for param, unflat_param_names in param_to_fqns.items():
-        if isinstance(param, FlatParameter):  # flatten FSDP parameters' states
-            assert (
-                param in flat_param_to_fsdp_module
-            ), f"Check the `flat_param_to_fsdp_module` construction\nparam: {param}"
-            fsdp_module = flat_param_to_fsdp_module[param]
+        fqn = unflat_param_names[0]
+        if fqn in fqn_to_fsdp_param_info:
+            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
             flat_state = _flatten_optim_state(
                 unflat_osd_state,
                 unflat_param_names,
-                fsdp_module,
-                param,
+                fsdp_param_info.state,
+                fsdp_param_info.flat_param,
                 shard_state,
             )
             key = _OptimStateKey(tuple(unflat_param_names), True)
```
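The loop above now dispatches on whether a parameter's first FQN appears in the FQN map instead of checking `isinstance(param, FlatParameter)`. Below is a toy sketch of that dispatch; the data, the fake "flattening", and the non-FSDP branch are illustrative assumptions, since the hunk does not show how the real code handles unmanaged parameters.

```python
from typing import Any, Dict, List

# Toy inputs standing in for unflat_osd["state"], param_to_fqns, and the
# FQN -> param-info map; flattening is represented by a marker dict.
unflat_state: Dict[str, Any] = {
    "layer.weight": {"step": 1},
    "head.weight": {"step": 1},
}
param_to_fqns: Dict[str, List[str]] = {
    "p0": ["layer.weight"],  # FSDP-managed in this toy setup
    "p1": ["head.weight"],   # not managed by FSDP in this toy setup
}
fqn_to_info = {"layer.weight": "info-for-layer"}

flat_state: Dict[Any, Any] = {}
for param, fqns in param_to_fqns.items():
    fqn = fqns[0]
    if fqn in fqn_to_info:
        # The real code calls _flatten_optim_state(...) with info.state
        # and info.flat_param here.
        flat_state[(tuple(fqns), True)] = {"flattened": unflat_state[fqn]}
    else:
        # Non-FSDP parameters keep their state as-is (illustrative only).
        flat_state[(tuple(fqns), False)] = unflat_state[fqn]

assert (("layer.weight",), True) in flat_state
```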
```diff
@@ -670,7 +682,7 @@ def _process_pos_dim_tensor_state(
         if not is_pos_dim_tensor_state:
             no_tensor_osd["state"][key][state_name] = value
             continue
-        if key.is_flat_param:  # FSDP parameter
+        if key.is_fsdp_managed:  # FSDP parameter
             sharded_size = FlatParamHandle._get_sharded_size(
                 value, rank=0, world_size=world_size
             )
```
```diff
@@ -753,7 +765,7 @@ def _broadcast_pos_dim_tensor_states(
         else:
             unsharded_tensor = None
         shape, dtype = value.shape, value.dtype
-        if key.is_flat_param:  # FSDP parameter
+        if key.is_fsdp_managed:  # FSDP parameter
             _broadcast_sharded_pos_dim_tensor_state(
                 unsharded_tensor,
                 param_state,
```
```diff
@@ -1079,6 +1091,7 @@ def _map_param_id_to_optim_keys(
     group: Optional[dist.ProcessGroup],
     param_id_to_param: List[nn.Parameter],
     param_to_fqns: Dict[nn.Parameter, List[str]],
+    fqn_to_fsdp_param_info: Dict[str, FSDPParamInfo],
 ) -> Tuple[Dict[int, _OptimStateKey], Dict[_OptimStateKey, int]]:
     """
     Construct the local mapping between the `_OptimStateKey` and parameter IDs
```
```diff
@@ -1087,18 +1100,21 @@ def _map_param_id_to_optim_keys(
     """
     rank = dist.get_rank(group)
     optim_state_key_to_param_id: Dict[_OptimStateKey, int] = {}  # local
-    r0_param_id_to_optim_state_key: Dict[
-        int, _OptimStateKey
-    ] = {}  # rank 0
+    r0_param_id_to_optim_state_key: Dict[int, _OptimStateKey] = {}  # rank 0
 
     for param_id, param in enumerate(param_id_to_param):
         # Do not include parameters without state to avoid empty mappings
         # just like in normal `torch.optim.Optimizer.state_dict()`
         if param_id not in optim_state_dict["state"]:
             continue
+        fqns = param_to_fqns[param]
+        is_fsdp_managed = isinstance(param, FlatParameter)
+        if is_fsdp_managed:
+            assert fqns[0] in fqn_to_fsdp_param_info
+        is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info
         optim_state_key = _OptimStateKey(
-            unflat_param_names=tuple(param_to_fqns[param]),
-            is_flat_param=isinstance(param, FlatParameter),
+            unflat_param_names=tuple(fqns),
+            is_fsdp_managed=is_fsdp_managed,
         )
         if rank == 0:
             r0_param_id_to_optim_state_key[param_id] = optim_state_key
```
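The two-step computation of `is_fsdp_managed` above first asserts that a `FlatParameter`'s leading FQN was registered in the map, then takes map membership itself as the final answer, which also covers FSDP-managed parameters that are not `FlatParameter` instances. A hedged sketch of that logic in isolation, with a toy class standing in for the real `FlatParameter`:

```python
from typing import Dict, List

class ToyFlatParameter:
    """Stand-in for torch.distributed.fsdp.flat_param.FlatParameter."""

def is_fsdp_managed(param, fqns: List[str], fqn_to_info: Dict[str, object]) -> bool:
    managed = isinstance(param, ToyFlatParameter)
    if managed:
        # A flat parameter's first FQN must have been registered in the map.
        assert fqns[0] in fqn_to_info
    # Membership in the FQN map is the general test; it also covers setups
    # where FSDP-managed parameters are not flat parameters.
    return fqns[0] in fqn_to_info

fqn_map = {"layer.weight": object()}
assert is_fsdp_managed(ToyFlatParameter(), ["layer.weight"], fqn_map)
assert not is_fsdp_managed(object(), ["head.weight"], fqn_map)
```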
```diff
@@ -1220,6 +1236,7 @@ def _optim_state_dict(
         if using_optim_input
         else _get_param_id_to_param(optim)
     )
+    fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model)
 
     (
         param_id_to_optim_state_key,
```
```diff
@@ -1229,20 +1246,23 @@ def _optim_state_dict(
         group,
         param_id_to_param,
         param_to_fqns,
+        fqn_to_fsdp_param_info,
     )
-    flat_param_to_fsdp_state = _get_flat_param_to_fsdp_module(model)
 
     # Iterate in rank 0's flattened parameter ID order to ensure aligned
     # all-gathers across ranks
     for optim_state_key in param_id_to_optim_state_key.values():
         param_id = optim_state_key_to_param_id[optim_state_key]
-        if optim_state_key.is_flat_param:
-            param = param_id_to_param[param_id]
-            fsdp_state = flat_param_to_fsdp_state[param]
+        if optim_state_key.is_fsdp_managed:
+            # If there are multiple unflat_param_names (not use_orig_params),
+            # they share the same FSDPParamInfo. So the first unflat_param_name
+            # is sufficient to fetch the FSDPParamInfo.
+            fqn = optim_state_key.unflat_param_names[0]
+            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
             unflat_state = _unflatten_optim_state(
-                cast(FlatParameter, param),
+                fsdp_param_info.flat_param,
                 optim_state_dict["state"][param_id],
-                fsdp_state,
+                fsdp_param_info.state,
                 to_save,
                 shard_state,
             )
```
```diff
@@ -1269,3 +1289,43 @@ def _optim_state_dict(
     )
 
     return fsdp_osd
+
+
+def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
+    """
+    Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo``
+    if the param is managed by FSDP. ``FlatParameter._fqns`` only stores the first
+    FQN of a shared parameter. So the keys in the mapping are guaranteed to map
+    to unique parameters.
+    """
+
+    def module_fn(module, prefix, fqn_to_param_info):
+        # TODO: make it work with composable API.
+        if not isinstance(module, fsdp_file.FullyShardedDataParallel):
+            return
+        _lazy_init(module, module)
+        handles = _module_handles(module, module)
+        if not handles:
+            return
+        flat_param = handles[0].flat_param
+        fsdp_param_info = FSDPParamInfo(module, flat_param, {})
+        for idx, local_fqn in enumerate(flat_param._fqns):
+            fqn = clean_tensor_name(prefix + local_fqn)
+            if fqn in fqn_to_param_info:
+                assert fqn_to_param_info[fqn].flat_param == flat_param
+            fqn_to_param_info[fqn] = fsdp_param_info
+            fsdp_param_info.param_indices[fqn] = idx
+
+    def return_fn(fqn_to_param_info):
+        return fqn_to_param_info
+
+    fqn_to_param_info: Dict[str, FSDPParamInfo] = {}
+    # FlatParameter._fqns stores the local fqn, starting from the root of the
+    # FSDP. Using _apply_to_modules() with model (may not be the FSDP root
+    # module) allows us to construct the global fqn.
+    return _apply_to_modules(
+        model,
+        module_fn,
+        return_fn,
+        fqn_to_param_info,
+    )
```

Review comment from awgu (Collaborator) on `module_fn`, marked as resolved: "Should we add a comment saying we need to use …"
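`FlatParameter._fqns` are local to each FSDP instance, so the helper relies on `_apply_to_modules` to carry a dotted prefix down the module tree and on `clean_tensor_name` to strip wrapper segments, producing FQNs rooted at `model`. The sketch below shows only that prefix-accumulation idea using public `torch.nn` APIs; `WRAPPER_PREFIX`, `clean_name`, and `apply_to_modules` are hypothetical stand-ins, not FSDP's real helpers.

```python
from typing import Callable, Dict
import torch.nn as nn

WRAPPER_PREFIX = "_wrapped_module."  # hypothetical wrapper segment to strip

def clean_name(name: str) -> str:
    # Rough analogue of clean_tensor_name's role: drop wrapper segments.
    return name.replace(WRAPPER_PREFIX, "")

def apply_to_modules(root: nn.Module, fn: Callable[[nn.Module, str], None]) -> None:
    # Visit every module together with its dotted prefix relative to `root`.
    def recurse(module: nn.Module, prefix: str) -> None:
        fn(module, prefix)
        for child_name, child in module.named_children():
            recurse(child, f"{prefix}{child_name}.")
    recurse(root, "")

# Collect global parameter FQNs from "local" names (here: each module's own
# parameter names), mirroring prefix + local_fqn in the new helper.
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
global_fqns: Dict[str, int] = {}

def record(module: nn.Module, prefix: str) -> None:
    for idx, (local_fqn, _) in enumerate(module.named_parameters(recurse=False)):
        global_fqns[clean_name(prefix + local_fqn)] = idx

apply_to_modules(model, record)
assert "0.weight" in global_fqns and "2.bias" in global_fqns
```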