[FSDP] Add initial summon_full_params(with_grads=True)
#85738
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/85738. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures, 1 pending as of commit 8122cbd. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
This adds `summon_full_params(with_grads=True)` for `use_orig_params=True` and `offload_to_cpu=False`. Filling in the `use_orig_params=False` case requires some already-planned refactoring, and the `offload_to_cpu=True` case needs some additional work as well.
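A hedged usage sketch of what the new flag enables, assuming a typical single-node FSDP setup with `use_orig_params=True` (the only case this PR covers); the toy model, shapes, and launch details are illustrative, not part of the PR:

```python
# Illustrative only: assumes torch.distributed is already initialized
# (e.g., launched via torchrun) and that CUDA is available.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Linear(5, 5).cuda(), use_orig_params=True)
loss = model(torch.randn(2, 5, device="cuda")).sum()
loss.backward()

# With with_grads=True, each rank sees the full (unsharded) parameters *and*
# their gradients, which helps verify that gradients are updated correctly.
with FSDP.summon_full_params(model, with_grads=True):
    for name, param in model.named_parameters():
        grad_shape = None if param.grad is None else param.grad.shape
        print(name, tuple(param.shape), grad_shape)
```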
rohan-varma
left a comment
LGTM overall, some mostly minor questions / comments.
self._check_sharded(flat_param.grad)
flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
dist._all_gather_base(padded_unsharded_grad, sharded_grad, self.process_group)
Does it mean that the gradient is all zeros if `flat_param.grad = None` on all ranks?
We discussed briefly offline, but there are two options:

1. (As in the PR currently) We only use a single all-gather collective per `FlatParameter`. In that case, if all ranks' sharded gradient is `None`, then the unsharded gradient is incorrectly `torch.zeros(unsharded_size)`. If only some ranks' sharded gradients are `None`, then the unsharded gradient zeros those corresponding elements.
2. We use a preceding all-reduce collective per `FlatParameter` to indicate whether each rank's sharded gradient is `None` or not. This solves the problem from 1.

Since `summon_full_params(with_grads=True)` is meant for debugging, I can see the argument for pursuing 2. I can change this in a follow-up PR and add/adjust unit tests accordingly.
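For reference, a minimal sketch of option 2, assuming a standard `torch.distributed` setup; `flat_param`, `padded_unsharded_grad`, and `process_group` stand in for the FSDP internals shown in the diff, and the helper name is hypothetical:

```python
import torch
import torch.distributed as dist

def _all_gather_grad_or_none(flat_param, padded_unsharded_grad, process_group):
    # Preceding all-reduce: learn whether *any* rank has a real sharded gradient.
    has_grad = torch.tensor(
        [flat_param.grad is not None], dtype=torch.int32, device=flat_param.device
    )
    dist.all_reduce(has_grad, op=dist.ReduceOp.MAX, group=process_group)
    if has_grad.item() == 0:
        # All ranks had grad=None, so report "no gradient" instead of zeros.
        return None
    # Ranks without a gradient contribute zeros so the all-gather is well defined.
    sharded_grad = (
        flat_param.grad if flat_param.grad is not None else torch.zeros_like(flat_param)
    )
    dist._all_gather_base(padded_unsharded_grad, sharded_grad, process_group)
    return padded_unsharded_grad
```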
)
param = getattr(module, param_name)
param.grad = view
for i, (
Is this for shared parameters? Could you give an example of what this changes when there are shared params?
This just makes sure that each shared parameter's `.grad` is also populated. Without this, if we had a model like

lin = nn.Linear(5, 5)
lin.weight = lin.bias

then only one of (`lin.weight`, `lin.bias`) would get a `.grad`, since only one of them is accounted for in `_param_infos` / `_params` / etc.
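A minimal sketch of that idea, with a hypothetical `_shared_param_infos` layout (the tuple fields and helper name are illustrative, not FSDP's exact metadata):

```python
def _populate_shared_param_grads(flat_param):
    # Each entry points a "shared" parameter back at the primary parameter
    # whose .grad view was already assigned from the unsharded gradient.
    for param_name, module, prim_param_name, prim_module in flat_param._shared_param_infos:
        prim_param = getattr(prim_module, prim_param_name)
        shared_param = getattr(module, param_name)
        # Reuse the same gradient view so that, e.g., lin.bias.grad is
        # populated even though only lin.weight appears in _param_infos.
        shared_param.grad = prim_param.grad
```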
@pytorchbot merge

Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
zhaojuanmao
left a comment
Thanks a lot for adding this! It will improve debuggability a lot for users under `summon_full_params` mode.
sharded_grad = torch.zeros_like(flat_param)  # type: ignore[attr-defined]
else:
self._check_sharded(flat_param.grad)
flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
Curious why we intentionally fill `flat_param._saved_grad_shard` here?
We already use `_saved_grad_shard` for saving the sharded gradient during the backward pass. Here, I use the same variable for a different purpose: saving the sharded gradient during `summon_full_params(with_grads=True)`. This is just to avoid creating an entirely new variable that also saves the sharded gradient.
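A minimal sketch of the stash/restore pattern described here, using a hypothetical context-manager helper and simplified `FlatParameter` handling (not the actual FSDP code path):

```python
import contextlib

@contextlib.contextmanager
def _expose_unsharded_grad(flat_param, unsharded_grad):
    # Stash the sharded gradient in the same slot the backward pass uses.
    flat_param._saved_grad_shard = flat_param.grad
    # Expose the full (unsharded) gradient for the duration of the context.
    # (Assumes flat_param.data has already been swapped to its unsharded
    # storage so the gradient shape lines up.)
    flat_param.grad = unsharded_grad
    try:
        yield
    finally:
        # Restore the sharded gradient on exit and drop the stash.
        flat_param.grad = flat_param._saved_grad_shard
        del flat_param._saved_grad_shard
```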
[FSDP] Add initial `summon_full_params(with_grads=True)` (#85738)

Summary: This adds `summon_full_params(with_grads=True)` for `use_orig_params=True` and `offload_to_cpu=False`. Filling in the `use_orig_params=False` case requires some already-planned refactoring, and the `offload_to_cpu=True` case needs some additional work as well. Adding this is helpful for debugging `use_orig_params=True` to make sure gradients are being updated correctly.

Pull Request resolved: #85738
Approved by: https://github.com/rohan-varma
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a95889ba7c1ecd8cb0f90507a6152cb035bcefd1
Reviewed By: seemethere
Differential Revision: D40197192
Pulled By: seemethere
fbshipit-source-id: 742ea641d7f005946e0714181c0a91167fe9fb9d
Stack from ghstack:

- #86122 [FSDP][2/N] Remove `_fsdp_wrapped_module.flat_param`
- #86117 [FSDP][1/N] Retire `FlattenParamsWrapper`
- #85738 [FSDP] Add initial `summon_full_params(with_grads=True)`
- #84911 [FSDP] Add `use_orig_params`

This adds `summon_full_params(with_grads=True)` for `use_orig_params=True` and `offload_to_cpu=False`. Filling in the `use_orig_params=False` case requires some already-planned refactoring, and the `offload_to_cpu=True` case needs some additional work as well.

Adding this is helpful for debugging `use_orig_params=True` to make sure gradients are being updated correctly.