
Conversation

awgu (Collaborator) commented Oct 19, 2022

Stack from ghstack:

This PR changes the behavior of `summon_full_params(with_grads=True)` so that if all ranks have `flat_param.grad = None`, the original parameters correctly end up with `orig_param.grad = None`. This is achieved with a preliminary all-reduce. Note that if a particular original parameter's gradient is `None` on all of the ranks that contain it, but not all ranks have `flat_param.grad = None`, then that gradient is still set to zeros. This can be handled in follow-up work if desired.
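As a rough illustration of the mechanism described above (the preliminary all-reduce), each rank can contribute whether its `flat_param.grad` is `None`, and the gradients can be left as `None` only if every rank agrees. A minimal sketch, assuming an initialized process group; the helper name is hypothetical and this is not the PR's actual code:

```python
# Hypothetical sketch only -- the name _all_ranks_have_no_grad and the structure
# are illustrative, not the PR's actual implementation.
import torch
import torch.distributed as dist


def _all_ranks_have_no_grad(flat_param, group=None) -> bool:
    # Encode "my flat_param.grad is None" as 1.0, otherwise 0.0.
    flag = torch.tensor(
        [1.0 if flat_param.grad is None else 0.0], device=flat_param.device
    )
    # MIN across ranks: the result stays 1.0 only if *every* rank has grad = None.
    dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=group)
    return bool(flag.item())


# If this returns True, the unflattened original parameters can keep
# orig_param.grad = None instead of exposing zero-filled gradients.
```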

pytorch-bot bot commented Oct 19, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87314

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 988a6ea:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot bot added the release notes: distributed (fsdp) label Oct 19, 2022
@torch.no_grad()
def unshard_grad(self):
    """
    Unshards the handle's ``FlatParameter`` 's gradient. If all ranks have
Contributor

nit: maybe add a comment that `unshard_grad` is not used on the critical path and is only used in `summon_full_params()`, since it calls all-reduce and may have a performance impact
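Since `unshard_grad()` is only exercised through `summon_full_params(with_grads=True)` rather than the training loop, the user-facing effect of this PR can be shown with a hedged sketch (assuming an FSDP-wrapped `model`, an initialized process group, and that no backward pass has run yet):

```python
# Illustrative usage only; shows the behavior this PR targets, not the internal
# unshard_grad() implementation.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

with FSDP.summon_full_params(model, with_grads=True):
    for param in model.parameters():
        # With this PR: every rank's flat_param.grad is None before backward,
        # so the unflattened parameters report grad = None instead of zeros.
        assert param.grad is None
```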

pytorch-bot bot added the ciflow/trunk label Oct 21, 2022
awgu pushed a commit to awgu/pytorch that referenced this pull request Oct 21, 2022
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
Pull Request resolved: pytorch#87314
Approved by: https://github.com/zhaojuanmao
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
facebook-github-bot deleted the gh/awgu/135/head branch June 8, 2023 15:23