[FSDP] Add use_orig_params
#84911
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/84911
Note: Links to docs will display an error until the docs builds have been completed. ❗ There is 1 currently active SEV; if your PR is affected, please view it below. ✅ No failures as of commit fa871ec. This comment was automatically generated by Dr. CI and updates every 15 minutes.
rohan-varma
left a comment
Could we have a brief PR description for reviewers to have context?
**Overview**
This PR adds the option to use the original parameters via `use_orig_params=True` in the FSDP constructor.
- This exposes the original parameters rather than the `FlatParameter`s from `named_parameters()`, which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same `FlatParameter` to different parameter groups.
- This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.

For more detailed design explanation, refer to the Quip shared internally.

**Follow-Ups**
See 85831 (removing link to avoid spamming the issue whenever I update this PR).

`test_fsdp_use_orig_params.py` adds ~4 min 46 seconds to the TTS on the AWS cluster.
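For illustration, here is a minimal sketch of the usage this enables; `MyModel` is a placeholder, and the snippet assumes a process group has already been initialized and the module is on the right device:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch: with use_orig_params=True, named_parameters() yields the original
# parameters (not FlatParameters), so parameters flattened into the same
# FlatParameter can still be placed into different optimizer param groups.
model = FSDP(MyModel().cuda(), use_orig_params=True)  # MyModel: placeholder

decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name.endswith("bias") else decay).append(param)

optim = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```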
/easycla As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details. This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.
def get_error_context():
    error_regex = "Optimizer state checkpointing is not supported yet for `use_orig_params=True`"
    return self.assertRaisesRegex(
        expected_exception=NotImplementedError, expected_regex=error_regex
Can we file issues for this and all other unsupported features?
with self.assertRaisesRegex(RuntimeError, "Cannot writeback"):
    # Change the gradient to a new one with 1 added to each dimension
    # to force a shape mismatch when writing back
    if self.rank == 0:
if/else can be condensed to:
param = getattr(fsdp, f"lin{rank}")
lin_weight_shape = param.weight.shape
param.weight = nn.Parameter(...)
param.weight.grad = ....
rohan-varma
left a comment
LGTM overall, thanks for working through this and the great attention to detail! I have two high-level questions:
- Shall we file follow-up issues for all unsupported features, such as optimizer state checkpointing?
- Did we update all necessary documentation mentioning how to use this feature and its caveats/assumptions (such as the gradient writeback)?
flat_param.grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
if self._config.keep_low_precision_grads:
    assert flat_param.grad is not None  # mypy
    flat_param.grad.data = flat_param.grad.to(self._config.param_dtype)
So mixed precision doesn't work with CPU offload? Can we file an issue for this? It seems pretty major.
self._use_orig_params
and self._handles
and self._handles[0].uses_sharded_strategy
and self._handles[0].is_sharded(self._handles[0].flat_param)
Is is_sharded the canonical, recommended way to check if a param is in the sharded state? How about gradients?
In general, is it worth exposing docs on such methods to aid FSDP developers in the future who are looking to do these common sorts of things?
Yes, is_sharded() can be the canonical way to check if a parameter or its gradient is currently sharded.
Do you have any suggestions for how to expose docs / what would help FSDP developers onboard more efficiently? The method is currently documented, but perhaps this is not salient enough.
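For reference, a minimal sketch of that check using the internal names from the condition quoted above (these are internal APIs and subject to change):

```python
# Sketch: mirror the quoted condition to check whether the (single) handle's
# FlatParameter is currently in its sharded form.
def flat_param_is_sharded(fsdp_state) -> bool:
    if not fsdp_state._handles:
        return False
    handle = fsdp_state._handles[0]
    return handle.uses_sharded_strategy and handle.is_sharded(handle.flat_param)
```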
def _sharded_post_load_state_dict_hook(self, *args, **kwargs) -> None:
    pass
    if self._use_orig_params:
        self._register_orig_params()
Is there guidance for FSDP developers on when they will need to call these methods?
I am still working to thoroughly understand how model state dict is implemented, namely the pre/postconditions of the pre/post save and load hooks, e.g. should FlatParameters be registered or should original parameters be registered, what do the prefixes and state dict keys look like at some point in the recursive call stack, etc.
I started trying to retire FlattenParamsWrapper but got quickly stymied by trying to understand those pre/postconditions. Maybe after I figure this out, I can help provide more internal documentation around these invariants.
if torch.cuda.is_available():
    torch.cuda.synchronize()
self._lazy_init()
self._clear_grads_if_needed()
Why do we need to do this for state_dict? Grads being None shouldn't matter there?
I am just using the major FSDP calls as an entry point to release gradient memory as early as possible. In the code crawl I did manually, I found that sometimes people will checkpoint after zero_grad(set_to_none=True) after the optimizer step.
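For context, a hedged sketch of the user pattern described above (generic placeholder names; the point is that state_dict() may run after gradients were already set to None):

```python
# Sketch of the training-loop pattern found in the code crawl: checkpointing
# right after zero_grad(set_to_none=True), which is why state_dict() is also
# used as an entry point to release gradient memory early.
loss = model(batch).sum()                  # model/batch: placeholders
loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)          # original params' .grad become None
torch.save(model.state_dict(), "ckpt.pt")  # checkpoint after grads were cleared
```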
return args, kwargs
self._wait_for_previous_optim_step()
self._needs_pre_forward_unshard.clear()
self._clear_grads_if_needed()
Add a comment to mention that we call this to maintain correctness when the user has called zero_grad(set_to_none=True); it acts as a sort of delayed set_to_none.
Also, is it worth documenting these writeback semantics clearly for the end user?
in_summon_full_params = self.training_state == TrainingState_.SUMMON_FULL_PARAMS
should_clean_name = (
    self.training_state == TrainingState_.SUMMON_FULL_PARAMS
    or self._use_orig_params
why do we need to clean the name when using use_orig_params? Shouldn't the param FQNs be exactly the local param names?
There is still nested wrapping (FSDP -> FPW -> module), so I think the names will be unclean.
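For illustration, a hedged sketch of what the cleaning does; the `_fsdp_wrapped_module.` and `_fpw_module.` prefixes are the wrapper attribute names assumed here:

```python
# Sketch: with nested wrapping (FSDP -> FlattenParamsWrapper -> module), raw
# FQNs carry wrapper prefixes; cleaning strips them so that use_orig_params
# users see the original parameter names.
FSDP_PREFIX = "_fsdp_wrapped_module."  # assumed wrapper attribute names
FPW_PREFIX = "_fpw_module."

def clean_fqn(fqn: str) -> str:
    return fqn.replace(FSDP_PREFIX, "").replace(FPW_PREFIX, "")

print(clean_fqn("_fsdp_wrapped_module._fpw_module.lin.weight"))  # lin.weight
```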
for fsdp_module in FullyShardedDataParallel.fsdp_modules(model)
):
    raise NotImplementedError(
        "Optimizer state checkpointing is not supported yet for `use_orig_params=True`"
have we filed issues for this?
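For context, a hedged sketch of how a user would currently hit this guard; `model` and `optim` are placeholders, and it assumes the optimizer state-dict APIs route through the check above:

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Sketch: with use_orig_params=True anywhere in the wrapped module tree,
# optimizer state checkpointing is expected to raise NotImplementedError.
try:
    osd = FSDP.full_optim_state_dict(model, optim)  # model/optim: placeholders
except NotImplementedError as e:
    print(e)  # "Optimizer state checkpointing is not supported yet for `use_orig_params=True`"
```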
param = self.flat_param._params[i]  # type: ignore[index]
setattr(module, param_name, param)
param.data = view
I feel we need a comment here. I read this for a while, and it seems that it intentionally exposes the 'param' variable as the module's attribute so that its .data can be changed to point to new data later on? Also, param_name is not registered as a parameter here; why?
Correct. More precisely, we never delete the original parameter variable; instead, FlatParamHandle always keeps a reference to the original parameter variable.
Just for knowledge sharing, de-registration can happen in two ways:
1. delattr(module, param_name), where the parameter is stored as module.param_name
2. module._parameters.pop(param_name)

The second way preserves that the parameter is present, i.e. the user may still access module.param_name; however, the parameter will not be returned by named_parameters().
Similarly, registration can happen in two ways:
1. setattr(module, param_name, param)
2. module._parameters[param_name] = param

Since we already setattr(), we do not need to do any further explicit registration.
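To make the explanation above concrete, a small standalone sketch of the two (de-)registration mechanisms on a plain nn.Module (hypothetical module, for illustration only):

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)

# De-registration, way 1: delattr() removes the parameter entirely
# (nn.Module.__delattr__ drops it from lin._parameters).
delattr(lin, "bias")
assert not hasattr(lin, "bias")

# De-registration, way 2: pop from _parameters while keeping our own reference;
# the parameter no longer appears in named_parameters().
weight = lin._parameters.pop("weight")
assert "weight" not in dict(lin.named_parameters())

# Registration, way 1: setattr() with an nn.Parameter; nn.Module.__setattr__
# routes this through register_parameter(), so nothing further is needed.
lin.weight = weight

# Registration, way 2: write into _parameters directly.
lin._parameters["bias"] = nn.Parameter(torch.zeros(4))
```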
oh I see, thanks for the clarification!
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hey @awgu.
zhaojuanmao
left a comment
Awesome PR! It handles so many subtle cases properly and carefully. I especially like the idea of keeping flat_param as the module's attribute while dynamically registering it as a parameter; it seems that this idea simplified the state_dict changes a lot.
| """ | ||
| if not self._handles: | ||
| return | ||
| handle = self._handles[0] |
are we assuming there is only one flat_param_handle per '_fsdp_wrapped_module' for now?
Yes, I have an assert a few lines above 😄
pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py
Lines 3118 to 3122 in fa871ec
p_assert(
    len(self._handles) <= 1,
    "Expects <=1 handle per FSDP instance; needs to be refactored "
    "for >1 handle (e.g. non-recursive wrapping)"
)
Summary:
**Overview**
This PR adds the option to use the original parameters via `use_orig_params=True` in the FSDP constructor.
- This exposes the original parameters rather than the `FlatParameter`s from `named_parameters()`, which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same `FlatParameter` to different parameter groups.
- This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.

For more detailed design explanation, refer to the Quip shared internally.

**Follow-Ups**
See 85831 (removing link to avoid spamming the issue whenever I update this PR).

`test_fsdp_use_orig_params.py` adds ~4 min 46 seconds to the TTS on the AWS cluster.

Pull Request resolved: #84911
Approved by: https://github.com/rohan-varma
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/be682befbc836a07d5d070bb569450429526a64b
Reviewed By: seemethere
Differential Revision: D40197130
Pulled By: seemethere
fbshipit-source-id: fbf36e28bd06f49c8cb31febce86c26bc7ba7a34
ghstack-source-id: 5ee0687 Pull Request resolved: pytorch/pytorch#84911
ghstack-source-id: e936ff5 Pull Request resolved: pytorch/pytorch#84911
Stack from ghstack:
- #86122 [FSDP][2/N] Remove `_fsdp_wrapped_module.flat_param`
- #86117 [FSDP][1/N] Retire `FlattenParamsWrapper`
- #85738 [FSDP] Add initial `summon_full_params(with_grads=True)`
- #84911 [FSDP] Add `use_orig_params`

Overview
This PR adds the option to use the original parameters via `use_orig_params=True` in the FSDP constructor.
- This exposes the original parameters rather than the `FlatParameter`s from `named_parameters()`, which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same `FlatParameter` to different parameter groups.
- This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.

For more detailed design explanation, refer to the Quip shared internally.

Follow-Ups
See 85831 (removing link to avoid spamming the issue whenever I update this PR).

`test_fsdp_use_orig_params.py` adds ~4 min 46 seconds to the TTS on the AWS cluster.