[FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_info to map from original FQN to flat_param #89899
Conversation
…o flat_param [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89899
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 Failure as of commit 97ff650.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
…ginal FQN to flat_param" [ghstack-poisoned]
…ginal FQN to flat_param" **Motivation:** Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True. [ghstack-poisoned]
…ginal FQN to flat_param" **Motivation:** Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True. [ghstack-poisoned]
…ginal FQN to flat_param" **Motivation:** Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True. [ghstack-poisoned]
awgu
left a comment
I made an initial pass with some conceptual questions.
def _get_fqn_to_fsdp_param_info(
    model: nn.Module, dedup_shared_fqns: Set[str]
The name dedup_shared_fqns is a bit unclear to me. Is the deduplication just happening because this is a set, but this data structure mainly just represents the FQNs to add as keys to the returned fqn_to_param_info?
Will add a comment. In general, we only need to track the first FQN of a shared parameter. The same applies to `_get_param_to_fqns()` and `param_to_fqns`.
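As a toy illustration of the shared-parameter case being discussed (not the PR's code, and using plain `named_parameters()` rather than FSDP's bookkeeping): a parameter shared between two modules surfaces only once, under the first FQN that reaches it, so that first FQN is the only one that needs tracking.

```python
import torch.nn as nn

# Toy illustration: named_parameters() deduplicates by object identity, so a
# shared Parameter is yielded only once, under the first FQN that reaches it.
linear = nn.Linear(4, 4)
tied = nn.Linear(4, 4)
tied.weight = linear.weight  # tie the weights: one Parameter, two FQNs
model = nn.Sequential(linear, tied)

print([fqn for fqn, _ in model.named_parameters()])
# ['0.weight', '0.bias', '1.bias'] -- '1.weight' is dropped as a duplicate
```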
After checking the code, I realized this protection is redundant. Please check the comment I added to the code.
class FSDPParamInfo:
    state: nn.Module
    flat_param: FlatParameter
    fqn_indices: Dict[str, int]
Could we add a comment explaining what these indices mean? (My understanding is that this maps each FQN in FlatParameter._fqns to its corresponding index in FlatParameter._fqns.)
It also looks like this is not used for this PR (but probably for the next?).
Changed to param_indices, and yes, it is used in the next PR.
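For reference, a rough sketch of how the renamed field could be documented (hypothetical comments, not the PR's exact code; `FlatParameter` is FSDP's internal flattened-parameter class, so it is left as a string annotation here and exact types in the PR may differ):

```python
from dataclasses import dataclass, field
from typing import Dict

import torch.nn as nn


@dataclass
class FSDPParamInfo:
    # The FSDP-managed module (or FSDP state) that owns the FlatParameter.
    state: nn.Module
    # FSDP's internal flattened parameter that holds the original params.
    flat_param: "FlatParameter"
    # Maps a param's FQN to the index of that param within flat_param._fqns,
    # so later PRs can locate a param's slot inside the flat_param.
    param_indices: Dict[str, int] = field(default_factory=dict)
```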
…nfo to map from original FQN to flat_param" **Motivation:** Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True. [ghstack-poisoned]
awgu
left a comment
LGTM! I left some nits.
def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
    """
    Construct the maaping from a param's fqn to its corresponding FSDPParamInfo
nit:
- Construct the maaping from a param's fqn to its corresponding FSDPParamInfo
+ Construct the mapping from a param's fqn to its corresponding FSDPParamInfo
| """ | ||
| Construct the maaping from a param's fqn to its corresponding FSDPParamInfo | ||
| if the param is managed by FSDP. FlatParam only stores the first FQN of a | ||
| shared parameter. So the keys in the mapping are guranteed to map to unique |
nit:
- shared parameter. So the keys in the mapping are guranteed to map to unique
+ shared parameter. So the keys in the mapping are guaranteed to map to unique
    shared parameter. So the keys in the mapping are guranteed to map to unique
    parameters.
    """
    def module_fn(module, prefix, fqn_to_param_info):
Should we add a comment saying we need to use _apply_to_modules to get the global FQN (since the saved FQNs are like local FQNs, not necessarily prefixed from the global root module)?
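A toy illustration of the prefixing in question (hypothetical names, not the PR's traversal): the visitor receives the prefix accumulated from the root and joins it with module-local FQNs to produce FQNs rooted at the top-level model.

```python
import torch.nn as nn


def collect_global_fqns(model: nn.Module):
    """Walk the module tree carrying the accumulated prefix, the way a
    per-module visitor would, so module-local FQNs become global FQNs."""
    global_fqns = []

    def visit(module: nn.Module, prefix: str) -> None:
        # Stand-in for the locally saved FQNs (e.g. FlatParameter._fqns);
        # here we just use the module's direct parameters.
        for local_fqn, _ in module.named_parameters(recurse=False):
            global_fqns.append(prefix + local_fqn)
        for name, child in module.named_children():
            visit(child, prefix + name + ".")

    visit(model, "")
    return global_fqns


model = nn.Sequential(nn.Linear(2, 2), nn.Sequential(nn.Linear(2, 2)))
print(collect_global_fqns(model))
# ['0.weight', '0.bias', '1.0.weight', '1.0.bias']
```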
def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
    """
    Construct the maaping from a param's fqn to its corresponding FSDPParamInfo
    if the param is managed by FSDP. FlatParam only stores the first FQN of a
nit: I got confused at first, but I think I understand. Let me know if this is actually the wrong understanding.
- if the param is managed by FSDP. FlatParam only stores the first FQN of a
+ if the param is managed by FSDP. FlatParameter._fqns only stores the first FQN of a
(add backticks if you want, but maybe not, to keep it consistent with the rest of the comment)
This is correct, thanks!
@pytorchbot merge -f "The failing test is not related."
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…from original FQN to flat_param (pytorch#89899) **Motivation:** Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True. Pull Request resolved: pytorch#89899 Approved by: https://github.com/awgu
Stack from ghstack (oldest at bottom):
Motivation:
Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True.
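For readers following the stack, here is a minimal structural sketch of what such a mapping helper might look like. It is not the PR's code: `get_fsdp_state_and_flat_param` is a hypothetical stand-in for however the real helper reaches fsdp_state and the flat-param handle (which it must do because flat_param is not registered on the module when `use_orig_params` is True), and all names and signatures below are assumptions.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

import torch.nn as nn


@dataclass
class FSDPParamInfo:
    state: nn.Module
    flat_param: object  # FSDP's internal FlatParameter
    param_indices: Dict[str, int] = field(default_factory=dict)


def get_fsdp_state_and_flat_param(module: nn.Module) -> Optional[Tuple[nn.Module, object]]:
    """Hypothetical stand-in: return (fsdp_state, flat_param) if `module` is
    FSDP-wrapped, else None. The real helper reads these from fsdp_state and
    the flat-param handle rather than from registered module parameters."""
    ...


def fqn_to_fsdp_param_info_sketch(model: nn.Module) -> Dict[str, FSDPParamInfo]:
    fqn_to_param_info: Dict[str, FSDPParamInfo] = {}

    def visit(module: nn.Module, prefix: str) -> None:
        result = get_fsdp_state_and_flat_param(module)
        if result is not None:
            fsdp_state, flat_param = result
            info = FSDPParamInfo(fsdp_state, flat_param, {})
            # flat_param._fqns holds module-local FQNs; only the first FQN of a
            # shared parameter is stored, so keys map to unique parameters.
            for idx, local_fqn in enumerate(flat_param._fqns):
                global_fqn = prefix + local_fqn
                info.param_indices[global_fqn] = idx
                fqn_to_param_info[global_fqn] = info
        for name, child in module.named_children():
            visit(child, prefix + name + ".")

    visit(model, "")
    return fqn_to_param_info
```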