[FSDP][1/N] Update summon_full_params(with_grads) None gradient
#87314
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87314
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 988a6ea
This comment was automatically generated by Dr. CI and updates every 15 minutes.
… gradient" This PR changes `summon_full_params(with_grads=True)`'s behavior to be such that if all ranks have `flat_param.grad = None`, then the original parameters will correctly have `orig_param.grad = None`. This is achieved with a preliminary all-reduce. Note that if a particular original parameter's gradient is `None` on all of the containing ranks, but not all ranks' `flat_param.grad = None`, then that particular gradient is still going to be set to zeros. This can be handled if desired in follow-up work. [ghstack-poisoned]
@torch.no_grad()
def unshard_grad(self):
    """
    Unshards the handle's ``FlatParameter`` 's gradient. If all ranks have
nit: maybe add a comment that `unshard_grad` is not on the critical path and is only used by `summon_full_params()`, since it calls all-reduce and may have a performance impact
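For context, here is a minimal usage sketch of the path this method serves, `summon_full_params(with_grads=True)`. The toy `Linear` model, the `.cuda()` call, and the assumption that the process group is already initialized are illustrative, not part of this PR:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g. via torchrun)
# and that each rank has set its CUDA device; the Linear model is a toy.
model = FSDP(torch.nn.Linear(8, 8).cuda())

# No backward pass has run yet, so flat_param.grad is None on every rank.
with FSDP.summon_full_params(model, with_grads=True):
    for param in model.parameters():
        # With this PR, the original parameters keep grad = None instead of
        # having zero gradients materialized for them.
        assert param.grad is None
```

The extra all-reduce cost is therefore confined to explicit `summon_full_params(with_grads=True)` calls rather than the regular forward/backward path, which is the reviewer's point above.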
ghstack-source-id: a0651cd Pull Request resolved: pytorch#87314
…ytorch#87314) This PR changes `summon_full_params(with_grads=True)`'s behavior to be such that if all ranks have `flat_param.grad = None`, then the original parameters will correctly have `orig_param.grad = None`. This is achieved with a preliminary all-reduce. Note that if a particular original parameter's gradient is `None` on all of the containing ranks, but not all ranks' `flat_param.grad = None`, then that particular gradient is still going to be set to zeros. This can be handled if desired in follow-up work. Pull Request resolved: pytorch#87314 Approved by: https://github.com/zhaojuanmao
Stack from ghstack:
- #87308 [FSDP][2/N] Fix grad zero vs. `None` edge case
- #87314 [FSDP][1/N] Update `summon_full_params(with_grads)` `None` gradient

This PR changes `summon_full_params(with_grads=True)`'s behavior to be such that if all ranks have `flat_param.grad = None`, then the original parameters will correctly have `orig_param.grad = None`. This is achieved with a preliminary all-reduce. Note that if a particular original parameter's gradient is `None` on all of the containing ranks, but not all ranks' `flat_param.grad = None`, then that particular gradient is still going to be set to zeros. This can be handled if desired in follow-up work.
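For illustration, a minimal sketch of the preliminary all-reduce described above; the standalone helper `_all_ranks_have_no_grad` and its exact placement are assumptions rather than the actual FSDP internals:

```python
import torch
import torch.distributed as dist

def _all_ranks_have_no_grad(flat_param: torch.nn.Parameter,
                            process_group) -> bool:
    """Hypothetical helper: returns True only if flat_param.grad is None
    on every rank, which is when orig_param.grad should stay None."""
    # Each rank contributes 1 if it holds a sharded gradient, 0 otherwise.
    num_grads = torch.zeros(1, device=flat_param.device)
    if flat_param.grad is not None:
        num_grads += 1
    dist.all_reduce(num_grads, group=process_group)  # defaults to SUM
    # If the sum is 0, no rank has a gradient, so keep grad = None;
    # otherwise, ranks whose shard had no gradient fall back to zeros,
    # which is the remaining edge case mentioned above.
    return bool(num_grads.item() == 0)
```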