Conversation

Collaborator

@awgu awgu commented Sep 13, 2022

Stack from ghstack:

Overview
This PR adds the option to use the original parameters via use_orig_params=True in the FSDP constructor.

  • This exposes the original parameters rather than the FlatParameters from named_parameters(), which means that the optimizer runs on the original parameters. Hence, users may assign original parameters from the same FlatParameter to different parameter groups (see the sketch after this list).
  • This enables decoupling the original parameter variables from their storage without changing the variables themselves, which is critical for our upcoming execution-order-based non-recursive wrapping policy.
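
To illustrate the first point, here is a minimal, hypothetical sketch (not taken from this PR's tests): it assumes a process group has already been initialized and uses a toy nn.Linear with arbitrary hyperparameters.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group(...) has already been called.
model = FSDP(nn.Linear(8, 8), use_orig_params=True)

# With use_orig_params=True, named_parameters() yields the original weight and
# bias parameters (not a single FlatParameter), so they can be assigned to
# different optimizer parameter groups even though they share a FlatParameter.
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name.endswith("bias") else decay).append(param)

optim = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
)
```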

For more detailed design explanation, refer to the Quip shared internally.

Follow-Ups
See 85831 (removing link to avoid spamming the issue whenever I update this PR).

test_fsdp_use_orig_params.py adds ~4 min 46 seconds to the TTS on the AWS cluster.

[ghstack-poisoned]

pytorch-bot bot commented Sep 13, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/84911

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit fa871ec:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Contributor

@rohan-varma rohan-varma left a comment

Could we have a brief PR description for reviewers to have context?

@awgu awgu marked this pull request as draft September 15, 2022 00:42
Andrew Gu added 2 commits September 15, 2022 17:11
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 28, 2022
Andrew Gu added 3 commits September 30, 2022 18:49
@facebook-github-bot
Contributor

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@rohan-varma rohan-varma self-requested a review October 6, 2022 00:42
def get_error_context():
error_regex = "Optimizer state checkpointing is not supported yet for `use_orig_params=True`"
return self.assertRaisesRegex(
expected_exception=NotImplementedError, expected_regex=error_regex
Contributor

Can we file issues for this and all other unsupported features?

with self.assertRaisesRegex(RuntimeError, "Cannot writeback"):
# Change the gradient to a new one with 1 added to each dimension
# to force a shape mismatch when writing back
if self.rank == 0:
Contributor

if/else can be condensed to:

param = getattr(fsdp, f"lin{rank}")
lin_weight_shape = param.weight.shape
param.weight = nn.Parameter(...)
param.weight.grad = ....

Contributor

@rohan-varma rohan-varma left a comment

LGTM overall, thanks for working through this and for the great attention to detail! I have 2 high-level questions:

  1. Shall we file follow-up issues for all unsupported features, such as optimizer state checkpointing?
  2. Did we update all necessary documentation mentioning how to use this feature and its caveats/assumptions (such as the gradient writeback)?

flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined]
if self._config.keep_low_precision_grads:
assert flat_param.grad is not None # mypy
flat_param.grad.data = flat_param.grad.to(self._config.param_dtype)
Contributor

So mixed precision doesn't work with CPU offload? Can we file an issue for this? It seems pretty major.

self._use_orig_params
and self._handles
and self._handles[0].uses_sharded_strategy
and self._handles[0].is_sharded(self._handles[0].flat_param)
Contributor

Is is_sharded the canonical, recommended way to check if a param is in the sharded state? How about gradients?

In general, is it worth exposing docs on such methods to aid FSDP developers who are looking to do these common sorts of things in the future?

Collaborator Author

Yes, is_sharded() can be the canonical way to check if a parameter or its gradient is currently sharded.

Do you have any suggestions for how to expose docs / what would help FSDP developers onboard more efficiently? The method is currently documented, but perhaps this is not salient enough.

def _sharded_post_load_state_dict_hook(self, *args, **kwargs) -> None:
pass
if self._use_orig_params:
self._register_orig_params()
Contributor

Is there guidance for FSDP developers on when they will need to call these methods?

Collaborator Author

I am still working to thoroughly understand how the model state dict is implemented, namely the pre/postconditions of the pre/post save and load hooks: e.g., should FlatParameters or the original parameters be registered, and what do the prefixes and state dict keys look like at a given point in the recursive call stack?

I started trying to retire FlattenParamsWrapper but quickly got stymied trying to understand those pre/postconditions. Maybe after I figure this out, I can help provide more internal documentation around these invariants.

if torch.cuda.is_available():
torch.cuda.synchronize()
self._lazy_init()
self._clear_grads_if_needed()
Contributor

Why do we need to do this for state_dict? Grads being None shouldn't matter there, right?

Collaborator Author

I am just using the major FSDP calls as entry points to release gradient memory as early as possible. In a manual code crawl, I found that people sometimes checkpoint after calling zero_grad(set_to_none=True) after the optimizer step.
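
For context, here is a hedged sketch of the user pattern being described (generic PyTorch, no FSDP internals; the model, optimizer, and checkpoint path are made up):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
optim = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(2, 8)).sum().backward()
optim.step()
optim.zero_grad(set_to_none=True)          # gradients are released here...
torch.save(model.state_dict(), "ckpt.pt")  # ...and the checkpoint is taken afterwards
```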

return args, kwargs
self._wait_for_previous_optim_step()
self._needs_pre_forward_unshard.clear()
self._clear_grads_if_needed()
Contributor

Add a comment to mention that we call this for correctness when the user has called zero_grad(set_to_none=True); it acts as a sort of delayed set_to_none.

Also, is it worth documenting these writeback semantics clearly for the end user?

in_summon_full_params = self.training_state == TrainingState_.SUMMON_FULL_PARAMS
should_clean_name = (
self.training_state == TrainingState_.SUMMON_FULL_PARAMS
or self._use_orig_params
Contributor

why do we need to clean the name when using use_orig_params? Shouldn't the param FQNs be exactly the local param names?

Collaborator Author

There is still nested wrapping (FSDP -> FPW -> module), so I think the names will be unclean.

for fsdp_module in FullyShardedDataParallel.fsdp_modules(model)
):
raise NotImplementedError(
"Optimizer state checkpointing is not supported yet for `use_orig_params=True`"
Contributor

have we filed issues for this?

Comment on lines +1107 to +1109
param = self.flat_param._params[i] # type: ignore[index]
setattr(module, param_name, param)
param.data = view
Contributor

I feel we need a comment here. I read this for a while, and it seems that it intentionally exposes the 'param' variable as the module's attr so that its .data can be changed to point to new data later on? Also, the param_name is not registered as a parameter here; why?

Collaborator Author

@awgu awgu Oct 7, 2022

Correct. More precisely, we never delete the original parameter variable; instead, FlatParamHandle always keeps a reference to the original parameter variable.

Just for knowledge sharing, de-registration can happen in two ways:

  1. delattr(module, param_name) where the parameter is stored as module.param_name.
  2. module._parameters.pop(param_name).

The second way keeps the attribute present, i.e. the user may still access module.param_name; however, the parameter will not be returned by named_parameters().

Similarly, registration can happen in two ways:

  1. setattr(module, param_name, param)
  2. module._parameters[param_name] = param

Since we already setattr(), we do not need to do any further explicit registration.
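
Here is a small sketch of these two paths on a plain nn.Module (illustrative only, not FSDP code), just to make the named_parameters() behavior concrete:

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)

# De-registration (2): pop from _parameters; the parameter no longer appears
# in named_parameters().
bias = lin._parameters.pop("bias")
assert "bias" not in dict(lin.named_parameters())

# Registration (2): assign directly into _parameters to register it again.
lin._parameters["bias"] = bias
assert "bias" in dict(lin.named_parameters())

# De-registration (1): delattr removes the parameter entirely.
delattr(lin, "weight")
assert not hasattr(lin, "weight")

# Registration (1): setattr with an nn.Parameter registers it again.
lin.weight = nn.Parameter(torch.empty(4, 4))
assert "weight" in dict(lin.named_parameters())
```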

Contributor

oh I see, thanks for the clarification!

Collaborator Author

awgu commented Oct 7, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Contributor

github-actions bot commented Oct 7, 2022

Hey @awgu.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

Contributor

@zhaojuanmao zhaojuanmao left a comment

Awesome PR! It handles so many subtle cases properly and carefully. I especially like the idea of keeping flat_param as the model's attribute while dynamically registering it as the model's parameter; it seems that this idea simplified the state_dict changes a lot.

"""
if not self._handles:
return
handle = self._handles[0]
Contributor

are we assuming there is only one flat_param_handle per '_fsdp_wrapped_module' for now?

Collaborator Author

Yes, I have an assert a few lines above 😄

p_assert(
    len(self._handles) <= 1,
    "Expects <=1 handle per FSDP instance; needs to be refactored "
    "for >1 handle (e.g. non-recursive wrapping)"
)

facebook-github-bot pushed a commit that referenced this pull request Oct 10, 2022
Pull Request resolved: #84911
Approved by: https://github.com/rohan-varma

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/be682befbc836a07d5d070bb569450429526a64b

Reviewed By: seemethere

Differential Revision: D40197130

Pulled By: seemethere

fbshipit-source-id: fbf36e28bd06f49c8cb31febce86c26bc7ba7a34
Rick0317 pushed a commit to Rick0317/pytorch that referenced this pull request Oct 18, 2022
ghstack-source-id: 5ee0687
Pull Request resolved: pytorch/pytorch#84911
Rick0317 pushed a commit to Rick0317/pytorch that referenced this pull request Oct 18, 2022
ghstack-source-id: e936ff5
Pull Request resolved: pytorch/pytorch#84911
@facebook-github-bot facebook-github-bot deleted the gh/awgu/95/head branch June 8, 2023 15:40

Labels

ciflow/trunk · cla signed · Merged · oncall: distributed · release notes: distributed (fsdp) · skip-pr-sanity-checks
