Conversation

@fegin (Contributor) commented Nov 30, 2022

@pytorch-bot bot commented Nov 30, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89899

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 97ff650:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…ginal FQN to flat_param"


**Motivation:**
Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True.


[ghstack-poisoned]
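To make the motivation concrete, here is a minimal, hypothetical sketch of the lookup the PR describes. This is NOT the actual FSDP implementation: plain dataclasses stand in for the FSDP state and the flat-parameter handle, and a list of floats stands in for the flattened tensor; the real PyTorch names and structures differ.

```python
# Hedged sketch: map original FQNs to the flat_param that contains them,
# reading flat_param from the handle rather than from module attributes
# (mirroring use_orig_params=True, where flat_param is not registered on
# the module). All class and function names here are invented stand-ins.
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FlatParamHandle:
    flat_param: List[float]  # stand-in for the flattened parameter tensor
    fqns: List[str]          # original FQNs covered by this flat_param


@dataclass
class FSDPState:
    handles: List[FlatParamHandle] = field(default_factory=list)


def fqn_to_flat_param(state: FSDPState) -> Dict[str, List[float]]:
    """Map each original FQN to the flat_param that contains it."""
    mapping: Dict[str, List[float]] = {}
    for handle in state.handles:
        for fqn in handle.fqns:
            # Keep only the first entry for a shared parameter's FQN.
            mapping.setdefault(fqn, handle.flat_param)
    return mapping
```

For example, a state with one handle covering `layer1.weight` and `layer1.bias` yields a two-key mapping, both keys pointing at the same flat_param.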
@awgu (Collaborator) left a comment

I made an initial pass with some conceptual questions.



```python
def _get_fqn_to_fsdp_param_info(
    model: nn.Module, dedup_shared_fqns: Set[str]
```

@awgu: The name `dedup_shared_fqns` is a bit unclear to me. Is the deduplication just happening because this is a set, and does this data structure mainly represent the FQNs to add as keys to the returned `fqn_to_param_info`?

@fegin (Nov 30, 2022): Added a comment. In general, we only need to track the first FQN of a shared parameter; the same holds for `_get_param_to_fqns()` and `param_to_fqns`.

@fegin: After checking the code, I realized this protection is redundant. Please see the comment I added in the code.

```python
class FSDPParamInfo:
    state: nn.Module
    flat_param: FlatParameter
    fqn_indices: Dict[str, int]
```
@awgu: Could we add a comment explaining what these indices mean? (My understanding is that this maps each FQN in `FlatParameter._fqns` to its index in that list.)

It also looks like this is not used in this PR (but probably in the next one?).

@fegin (Nov 30, 2022): Changed to `param_indices`, and yes, it is used in the next PR.

@fegin fegin changed the title [FSDP][optim_state_dict][2/N] Add a helper to map from original FQN to flat_param [FSDP][optim_state_dict][2/N] Add _get_fqn_to_fsdp_param_info to map from original FQN to flat_param Nov 30, 2022
…nfo to map from original FQN to flat_param"


**Motivation:**
Add a helper to map from the FQN to the corresponding flat_param. The helper will directly get flat_param from fsdp_state and flat_handler as flat_param is not registered to the module if `use_orig_params` is True.


[ghstack-poisoned]
@fegin fegin requested a review from wanchaol as a code owner November 30, 2022 22:56
@awgu (Collaborator) left a comment

LGTM! I left some nits.


```python
def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
    """
    Construct the maaping from a param's fqn to its corresponding FSDPParamInfo
```

@awgu: nit:

```diff
-    Construct the maaping from a param's fqn to its corresponding FSDPParamInfo
+    Construct the mapping from a param's fqn to its corresponding FSDPParamInfo
```

```python
    """
    Construct the maaping from a param's fqn to its corresponding FSDPParamInfo
    if the param is managed by FSDP. FlatParam only stores the first FQN of a
    shared parameter. So the keys in the mapping are guranteed to map to unique
    parameters.
    """
    def module_fn(module, prefix, fqn_to_param_info):
```

@awgu: nit:

```diff
-    shared parameter. So the keys in the mapping are guranteed to map to unique
+    shared parameter. So the keys in the mapping are guaranteed to map to unique
```

@awgu: Should we add a comment saying we need to use `_apply_to_modules` to get the global FQN (since the saved FQNs are local FQNs, not necessarily prefixed from the global root module)?
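To illustrate the local-vs-global FQN point, here is a hedged, self-contained sketch. The nested-dict model of a module tree and the function name are invented for illustration; the real code walks `nn.Module`s via `_apply_to_modules` and carries the accumulated prefix the same way.

```python
# Invented illustration: FQNs saved on a submodule (e.g. "weight") are
# local to it, so a prefix-carrying traversal must prepend the path from
# the root module to recover the global FQN (e.g. "layer1.linear.weight").
from typing import Dict, List, Union

Tree = Dict[str, Union["Tree", List[str]]]


def collect_global_fqns(tree: Tree, prefix: str = "") -> List[str]:
    out: List[str] = []
    for name, child in tree.items():
        path = f"{prefix}{name}"
        if isinstance(child, dict):
            # Recurse into the submodule, extending the prefix.
            out.extend(collect_global_fqns(child, prefix=f"{path}."))
        else:
            # Leaf: list of local FQNs saved on this submodule.
            out.extend(f"{path}.{local}" for local in child)
    return out
```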

```python
def _get_fqn_to_fsdp_param_info(model: nn.Module) -> Dict[str, FSDPParamInfo]:
    """
    Construct the maaping from a param's fqn to its corresponding FSDPParamInfo
    if the param is managed by FSDP. FlatParam only stores the first FQN of a
```

@awgu (Dec 7, 2022): nit: I got confused at first, but I think I understand now. Let me know if this understanding is actually wrong.

```diff
-    if the param is managed by FSDP. FlatParam only stores the first FQN of a
+    if the param is managed by FSDP. FlatParameter._fqns only stores the first FQN of a
```

(Add backticks if you want, though maybe not, to stay consistent with the rest of the comment.)

@fegin: This is correct, thanks!

@fegin (Contributor) commented Dec 7, 2022

@pytorchbot merge -f "The failing test is not related."

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging: check the merge workflow status.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
…from original FQN to flat_param (pytorch#89899)


Pull Request resolved: pytorch#89899
Approved by: https://github.com/awgu
@facebook-github-bot facebook-github-bot deleted the gh/fegin/48/head branch June 8, 2023 17:16