[FSDP] Override named_parameters() for clean names in summon_full_params()
#74333
Conversation
@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
| """ | ||
| # Monkey patch `named_parameters()` | ||
| torch_named_parameters = torch.nn.Module.named_parameters | ||
| self.named_parameters = self._fsdp_named_parameters # type: ignore[assignment] |
Adding the `type: ignore[assignment]` seems to be the best solution. See python/mypy#2427.
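As a minimal, self-contained sketch of the pattern being discussed (the classes and the prefix value below are hypothetical stand-ins, not the actual FSDP code): assigning a bound method to an instance attribute shadows the class-level method, which is exactly the assignment mypy flags.

```python
class Module:
    """Stand-in for `torch.nn.Module`, yielding a prefixed parameter name."""
    def named_parameters(self):
        yield ("_fsdp_wrapped_module.flat_param", object())

class FSDP(Module):
    def _fsdp_named_parameters(self):
        # Clean-name variant: strip the wrapper prefix (prefix is hypothetical)
        for name, param in Module.named_parameters(self):
            yield (name.replace("_fsdp_wrapped_module.", ""), param)

    def patch(self):
        # Assigning a bound method to an instance attribute shadows the class
        # method; mypy reports "Cannot assign to a method", hence the
        # `type: ignore[assignment]`
        self.named_parameters = self._fsdp_named_parameters  # type: ignore[assignment]

m = FSDP()
m.patch()
names = [n for n, _ in m.named_parameters()]
```

After the patch, lookups on the instance find the clean-name generator instead of the class method.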
Fixes #73890 by monkey patching `torch.nn.Module.named_parameters()`. Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)
```python
    List,
    Optional,
    Generator,
    Iterator,
```
Just sorted imports.
To-Do: look into overriding
```python
        context manager.
        """
        # Determine which logic to use based on the context at call time
        if not hasattr(self, "training_state") or \
```
`nn.Module`s contained in an FSDP instance may not have `training_state` as an attribute but still return `True` for `isinstance(module, FullyShardedDataParallel)`. If you guys think this `hasattr()` check is too hacky, I will see if I can find a more direct solution.
nit: this can be simplified to `if getattr(self, "training_state", None) != TrainingState_.SUMMON_FULL_PARAMS:`
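A quick sketch of why the two checks are equivalent (using a string stand-in for `TrainingState_.SUMMON_FULL_PARAMS`, and assuming `training_state`, when present, is never `None`):

```python
class Obj:
    pass

SUMMON = "SUMMON_FULL_PARAMS"  # stand-in for TrainingState_.SUMMON_FULL_PARAMS

def check_hasattr(obj):
    # Original two-part check: attribute missing, or not in summon state
    return not hasattr(obj, "training_state") or obj.training_state != SUMMON

def check_getattr(obj):
    # getattr with a default folds the existence check and the comparison
    # into one expression; the None default never equals the enum member
    return getattr(obj, "training_state", None) != SUMMON
```

Both return `True` when the attribute is missing and `False` once `training_state` is set to the summon state.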
```python
    def named_parameters(
        self,
        prefix: str = "",
```
Let's take in `*args, **kwargs` so we don't have to change this if the API changes?
Would it be correct to say that PyTorch Core cannot add positional arguments before `prefix` without breaking backward compatibility? In that case, could we only add `**kwargs`?
Using `(*args, **kwargs)` is kind of the convention in Python to pass arguments to the parent method without any modification.
Oh, I misunderstood. I thought we were trying to additionally pass in `*args, **kwargs` instead of replacing `prefix` and `recurse`.
Fixed this now.
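The forwarding convention the thread settles on can be sketched like this (simplified stand-in classes, not the real `nn.Module` API):

```python
class Base:
    def named_parameters(self, prefix="", recurse=True):
        return (prefix, recurse)

class Child(Base):
    def named_parameters(self, *args, **kwargs):
        # Forward everything unchanged, so this override keeps working
        # even if the parent signature gains new parameters later
        return super().named_parameters(*args, **kwargs)

result = Child().named_parameters("layer.", recurse=False)
```

Both positional and keyword call styles reach the parent exactly as the caller wrote them.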
…ummon_full_params()`" Fixes #73890. Differential Revision: [D34937201](https://our.internmc.facebook.com/intern/diff/D34937201)
rohan-varma left a comment:
Looks good, just two minor questions for your consideration; will stamp after that.
```python
                # Remove any instances of the FSDP-specific prefix; there can
                # be multiple in the case of nested FSDP modules
                param_name = param_name.replace(FSDP_PREFIX, "")
                yield (param_name, param)
```
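To illustrate the "multiple occurrences" point (with a hypothetical prefix value, not necessarily the real `FSDP_PREFIX`): `str.replace` removes every occurrence in a single call, which is what makes one `replace` sufficient for nested FSDP wrapping.

```python
FSDP_PREFIX = "_fsdp_wrapped_module."  # hypothetical value for illustration

# With nested FSDP instances, the prefix appears once per wrapping level
name = "_fsdp_wrapped_module.layer1._fsdp_wrapped_module.weight"
clean = name.replace(FSDP_PREFIX, "")
```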
Can we do the following to avoid the duplicated for loop?

```python
in_summon = (training_state == summon_full_params)
for name, param in named_parameters():
    name = name.replace(...) if in_summon else name
    yield (name, param)
```
Great point.
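The suggestion above amounts to something like this single-loop generator (a sketch with a hypothetical prefix value, not the actual FSDP implementation):

```python
FSDP_PREFIX = "_fsdp_wrapped_module."  # hypothetical value for illustration

def clean_named_parameters(pairs, in_summon):
    # One loop instead of two branches; strip the prefix only when the
    # module is inside summon_full_params()
    for name, param in pairs:
        yield (name.replace(FSDP_PREFIX, "") if in_summon else name, param)

pairs = [("_fsdp_wrapped_module.weight", 1), ("bias", 2)]
```

Outside the summon context the names pass through unchanged; inside it they come out clean.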
| """ | ||
| # Determine which logic to use based on the context at call time | ||
| if getattr(self, "training_state", None) != TrainingState_.SUMMON_FULL_PARAMS: | ||
| for param_name, param in torch.nn.Module.named_parameters( |
Wonder if we can use `super().named_parameters` rather than the `torch.nn.Module.named_parameters` call? If a user writes `MyModule` that inherits from `nn.Module` with a custom `named_parameters`, will this work?
Great point.
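The difference the reviewer is pointing at: a hard-coded `Base.named_parameters(self)` call skips any override in between, while `super()` follows the method resolution order. A toy sketch with stand-in classes:

```python
class Base:
    def named_parameters(self):
        return ["base"]

class Custom(Base):
    def named_parameters(self):
        # User override that a wrapper should respect
        return ["custom"]

class WrapperExplicit(Custom):
    def collect(self):
        return Base.named_parameters(self)  # bypasses Custom's override

class WrapperSuper(Custom):
    def collect(self):
        return super().named_parameters()  # honors the MRO

explicit = WrapperExplicit().collect()
via_super = WrapperSuper().collect()
```

Only the `super()` version sees the user's customized names.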
LGTM
fegin left a comment:
LGTM
Hey @awgu.
Stack from ghstack:

- #74333 [FSDP] Override `named_parameters()` for clean names in `summon_full_params()`

Fixes #73890.

Differential Revision: D34937201