
Conversation

@awgu (Collaborator) commented Sep 27, 2022

Stack from ghstack:

This adds `summon_full_params(with_grads=True)` for `use_orig_params=True` and `offload_to_cpu=False`. Filling in the `use_orig_params=False` case requires some already-planned refactoring, and the `offload_to_cpu=True` case needs some additional work as well.

Adding this is helpful for debugging `use_orig_params=True` to make sure gradients are being updated correctly.
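
For context, here is a minimal usage sketch of the new flag. It is illustrative only: the toy model, tensor shapes, and the assumption that a distributed process group is already initialized (e.g. via torchrun) are not part of this PR.

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed has already been initialized (e.g. via torchrun)
# and that a GPU is available; the model and shapes are just for illustration.
model = FSDP(nn.Linear(8, 8).cuda(), use_orig_params=True)
model(torch.randn(4, 8, device="cuda")).sum().backward()

# with_grads=True (added by this PR) also all-gathers the sharded gradients,
# so each original parameter's .grad holds the full, unsharded gradient.
with FSDP.summon_full_params(model, with_grads=True):
    for name, param in model.named_parameters():
        print(name, None if param.grad is None else param.grad.shape)
```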

@pytorch-bot bot commented Sep 27, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/85738

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures, 1 Pending

As of commit 8122cbd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) label Sep 27, 2022
@facebook-github-bot facebook-github-bot added the cla signed and oncall: distributed labels Sep 27, 2022
@rohan-varma (Contributor) left a comment

LGTM overall, some mostly minor questions / comments.

```python
self._check_sharded(flat_param.grad)
flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
dist._all_gather_base(padded_unsharded_grad, sharded_grad, self.process_group)
```
Contributor:

Does this mean that the gradient is all zeros if `flat_param.grad` is None on all ranks?

@awgu (Collaborator, author):

We discussed briefly offline, but there are two options:

  1. (As in the PR currently) We use only a single all-gather collective per `FlatParameter`. In that case, if every rank's sharded gradient is `None`, then the unsharded gradient is incorrectly `torch.zeros(unsharded_size)`. If only some ranks' sharded gradients are `None`, then the unsharded gradient has zeros in the corresponding elements.
  2. We use a preceding all-reduce collective per `FlatParameter` to indicate whether each rank's sharded gradient is `None`. This solves the problem from option 1.

Since `summon_full_params(with_grads=True)` is meant for debugging, I can see the argument for pursuing option 2. I can change this in a follow-up PR and add/adjust unit tests accordingly.
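
For illustration, here is a rough sketch of option 2, not the code in this PR: a small all-reduce first establishes whether any rank has a sharded gradient, so that an all-`None` gradient is reported as `None` rather than as zeros. The helper name and arguments are hypothetical.

```python
import torch
import torch.distributed as dist

def _all_gather_grad_or_none(sharded_grad, shard_numel, device, process_group):
    # Option 2 sketch: agree across ranks on whether any sharded gradient
    # exists before all-gathering, instead of silently materializing zeros.
    has_grad = torch.tensor([0 if sharded_grad is None else 1], device=device)
    dist.all_reduce(has_grad, group=process_group)  # SUM by default
    if has_grad.item() == 0:
        return None  # no rank has a gradient -> report None, not zeros

    if sharded_grad is None:
        # Only some ranks have a gradient; contribute zeros so shapes match.
        sharded_grad = torch.zeros(shard_numel, device=device)
    world_size = dist.get_world_size(process_group)
    padded_unsharded_grad = torch.empty(
        shard_numel * world_size, device=device, dtype=sharded_grad.dtype
    )
    dist._all_gather_base(padded_unsharded_grad, sharded_grad, process_group)
    return padded_unsharded_grad
```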

```python
)
param = getattr(module, param_name)
param.grad = view
for i, (
```
Contributor:

Is this for shared parameters? Could you give an example of what this changes when there are shared params?

@awgu (Collaborator, author):

This just makes sure that each shared parameter's `.grad` is also populated. Without this, if we had a model like

```python
lin = nn.Linear(5, 5)
lin.weight = lin.bias
```

then only one of (`lin.weight`, `lin.bias`) would get a `.grad`, since only one of them is accounted for in `_param_infos` / `_params` / etc.
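
To make the shared-parameter case concrete, here is a hedged sketch of what `with_grads=True` enables; the Tied model is illustrative and is not the unit test added in this PR (it also assumes an initialized process group and a GPU).

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(5, 5)
        self.lin2 = nn.Linear(5, 5)
        self.lin2.weight = self.lin1.weight  # shared parameter

    def forward(self, x):
        return self.lin2(self.lin1(x))

model = FSDP(Tied().cuda(), use_orig_params=True)
model(torch.randn(2, 5, device="cuda")).sum().backward()

with FSDP.summon_full_params(model, with_grads=True):
    # Both aliases of the shared weight should see a populated .grad,
    # which is what the extra handling above is for.
    assert model.lin1.weight.grad is not None
    assert model.lin2.weight.grad is not None
```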

awgu pushed a commit that referenced this pull request Sep 27, 2022
ghstack-source-id: 9c80c33
Pull Request resolved: #85738
awgu pushed a commit that referenced this pull request Sep 28, 2022
ghstack-source-id: aee4fa1
Pull Request resolved: #85738
awgu pushed a commit that referenced this pull request Sep 28, 2022
ghstack-source-id: 2ac67ef
Pull Request resolved: #85738
awgu pushed a commit that referenced this pull request Sep 30, 2022
ghstack-source-id: a9751ab
Pull Request resolved: #85738
@facebook-github-bot (Contributor) commented:

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions to be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@awgu awgu added the ciflow/trunk label Oct 7, 2022
@awgu (Collaborator, author) commented Oct 7, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions bot commented Oct 7, 2022

Hey @awgu.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@zhaojuanmao (Contributor) left a comment

Thanks a lot for adding this! It will greatly improve debuggability for users under `summon_full_params` mode.

```python
    sharded_grad = torch.zeros_like(flat_param)  # type: ignore[attr-defined]
else:
    self._check_sharded(flat_param.grad)
    flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
```
Contributor:

Curious why we intentionally fill `flat_param._saved_grad_shard` here?

@awgu (Collaborator, author):

We already use `_saved_grad_shard` for saving the sharded gradient during the backward pass. Here, I reuse the same variable for a different purpose: saving the sharded gradient during `summon_full_params(with_grads=True)`. This is just to avoid creating an entirely new variable that also saves the sharded gradient.
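
To illustrate the point, here is a simplified sketch of the save/restore pattern being described; the helper names are hypothetical, and the real FSDP code handles padding, dtype casts, writeback, and CPU offload that are omitted here.

```python
import torch
import torch.distributed as dist

def _enter_with_grads(flat_param, process_group):
    if flat_param.grad is None:
        sharded_grad = torch.zeros_like(flat_param)
    else:
        # Stash the sharded gradient so it can be restored on exit; this reuses
        # the attribute the backward pass already uses for its sharded gradient.
        flat_param._saved_grad_shard = flat_param.grad
        sharded_grad = flat_param._saved_grad_shard
    world_size = dist.get_world_size(process_group)
    padded_unsharded_grad = torch.empty(
        sharded_grad.numel() * world_size,
        device=sharded_grad.device,
        dtype=sharded_grad.dtype,
    )
    dist._all_gather_base(padded_unsharded_grad, sharded_grad, process_group)
    return padded_unsharded_grad  # later viewed back into per-parameter .grad

def _exit_with_grads(flat_param):
    # Restore the sharded gradient that was saved on entry (if any).
    flat_param.grad = getattr(flat_param, "_saved_grad_shard", None)
```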

facebook-github-bot pushed a commit that referenced this pull request Oct 10, 2022
…5738)

Summary:
This adds `summon_full_params(with_grads=True)` for `use_orig_params=True` and `offload_to_cpu=False`. Filling in the `use_orig_params=False` case requires some already-planned refactoring, and the `offload_to_cpu=True` case needs some additional work as well.

Adding this is helpful for debugging `use_orig_params=True` to make sure gradients are being updated correctly.

Pull Request resolved: #85738
Approved by: https://github.com/rohan-varma

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a95889ba7c1ecd8cb0f90507a6152cb035bcefd1

Reviewed By: seemethere

Differential Revision: D40197192

Pulled By: seemethere

fbshipit-source-id: 742ea641d7f005946e0714181c0a91167fe9fb9d
@facebook-github-bot facebook-github-bot deleted the gh/awgu/110/head branch June 8, 2023 15:21