[shard] use scatter in shard_parameter API #72160
Conversation
This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to each rank, we split the tensor beforehand and only send the part needed to the corresponding rank, which greatly reduces the communication overhead. Differential Revision: [D33933419](https://our.internmc.facebook.com/intern/diff/D33933419/) [ghstack-poisoned]
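For readers less familiar with the pattern, here is a minimal standalone sketch of the idea, assuming a dim-0 split that divides evenly across ranks; the helper name scatter_parameter and its arguments are illustrative, not the actual shard_parameter implementation:

import torch
import torch.distributed as dist

def scatter_parameter(full_tensor, src_rank):
    # Hypothetical sketch: split the parameter along dim 0 on the source rank and
    # scatter one chunk per rank, instead of broadcasting the whole tensor everywhere.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    shard_size = full_tensor.size(0) // world_size  # assumes an even split for simplicity
    local_shard = torch.empty(
        (shard_size,) + tuple(full_tensor.size()[1:]),
        dtype=full_tensor.dtype,
        device=full_tensor.device,
    )
    if rank == src_rank:
        # Only the source rank materializes the per-rank chunks.
        tensors_to_scatter = [
            t.contiguous() for t in torch.chunk(full_tensor, world_size, dim=0)
        ]
    else:
        # Non-source ranks pass None; they only receive their own chunk.
        tensors_to_scatter = None
    dist.scatter(local_shard, scatter_list=tensors_to_scatter, src=src_rank)
    return local_shard

Compared with dist.broadcast, each rank receives only roughly 1/world_size of the data.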
pritamdamania87
left a comment
LGTM, just had one comment regarding the change in distributed_c10d.py.
if scatter_list:
    raise ValueError(
        "Argument ``scatter_list`` must NOT be specified "
        "on non-source ranks."
    )
Is this an unintended change?
This is an intended change, actually; it matches the behavior in gather, where dist.gather allows the user to pass in gather_list on non-dst ranks but we internally treat it as an empty list:
pytorch/torch/distributed/distributed_c10d.py, line 2253 in defde3b:
    gather_list = []
I think the dist.gather behavior makes more sense: the user can call the collective on all ranks SPMD, and inside dist.gather we ignore the gather/scatter_list on non-dst/src ranks. Otherwise, the user has to write conditional code every time to perform a gather/scatter, i.e. for scatter it must be:
if current_rank == src_rank:
    # The source rank provides the full scatter_list.
    dist.scatter(local_tensor, scatter_list=tensors_to_scatter, src=src_rank, group=pg)
else:
    # Non-source ranks must pass None for scatter_list.
    dist.scatter(local_tensor, scatter_list=None, src=src_rank, group=pg)
which might be very inconvenient for the user. Our low-level primitives pg.gather/pg.scatter enforce that non-dst/src ranks must pass None.
Is it possible to make this logic (setting the scatter/gather list to None for non-src ranks) centralized somewhere in Python, since we have enforced this in C++?
It was embedded in dist.gather:

pytorch/torch/distributed/distributed_c10d.py, line 2261 in c1a4714:
    output_tensors = [gather_list] if dst == my_rank else []

and in dist.scatter, pytorch/torch/distributed/distributed_c10d.py, line 2334 in c1a4714:
    input_tensors = []
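For reference, one possible shape of such a centralized helper (purely hypothetical, not existing c10d code), mirroring the two snippets above:

def _normalize_io_list(tensor_list, my_rank, root_rank):
    # Hypothetical helper: on the root (src/dst) rank, wrap the user-provided
    # gather/scatter list for the process-group API; on every other rank drop it,
    # so the same collective call can be issued on all ranks (SPMD style).
    if my_rank == root_rank:
        return [tensor_list if tensor_list is not None else []]
    return []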
The thing is that our docs clearly say it should not be specified on non-src ranks:
scatter_list (list[Tensor]): List of tensors to scatter (default is
None, must be specified on the source rank)
So this change would be confusing for users, where the docs say one thing but the actual behavior is different.
However, both the scatter and gather docs do not mention that scatter_list "must ONLY be specified on the source rank", so I guess this part is a bit confusing to the user. Anyway, I think this should be a separate discussion around the consistency of our c10d APIs, not related to this PR specifically. I just updated the PR to make sure we are passing None on non-dst/src ranks, and removed this unintentional change.
Created an issue to track this conversation: #74323
fduwjj
left a comment
We can finally get rid of broadcast. Does GH support diff review memes?
# Reshape to get shard for each rank and we don't want autograd
# recording here for the narrow op and 'tensor_to_scatter' should be a
# leaf variable in the autograd graph.
tensor_to_scatter = tensor.narrow(
    sharding_spec.dim,  # type: ignore[arg-type]
    shard_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
    shard_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
).clone().detach().contiguous()
tensors_to_scatter.append(
    tensor_to_scatter
)
Once your new sharding spec extension PR is ready, I guess the logic here can reuse what's been written in the shard API?
Yeah, agreed, this might just be using ChunkShardingSpec.shard.
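For context, ChunkShardingSpec-style even chunking boils down to computing per-rank offsets and sizes along the sharding dim. A minimal standalone sketch (the helper name and the ceil-division rounding, which mirrors torch.chunk, are illustrative rather than the actual ChunkShardingSpec.shard code):

def chunk_offsets_and_sizes(dim_size: int, world_size: int):
    # Ceil-division split: each rank gets ceil(dim_size / world_size) elements,
    # except possibly the trailing rank(s), which get whatever remains (maybe 0).
    full_chunk = (dim_size + world_size - 1) // world_size
    offsets, sizes = [], []
    current = 0
    for _ in range(world_size):
        size = max(min(full_chunk, dim_size - current), 0)
        offsets.append(current)
        sizes.append(size)
        current += size
    return offsets, sizes

# Example: a dimension of size 10 split across 4 ranks.
offsets, sizes = chunk_offsets_and_sizes(10, 4)
# offsets == [0, 3, 6, 9], sizes == [3, 3, 3, 1]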
current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

# Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
Nit: You can remove this comment now 😊
pritamdamania87
left a comment
I guess this PR needs to be rebased on top of #72130 or vice-versa?
Yep, plan to rebase soon (hopefully can land that PR and then rebase to master directly).
pritamdamania87
left a comment
Requesting changes since I think the scatter behavior should align with our docs.
Summary: Pull Request resolved: #72160 This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to each rank, we split the tensor beforehand and only send the part needed to the corresponding rank, which greatly reduces the communication overhead. ghstack-source-id: 151643718 Test Plan: test_shard_parameter test_shard_parameter_errors Reviewed By: pritamdamania87 Differential Revision: D33933419 fbshipit-source-id: c823c5d0066a9fe7451c07cbacb30a3bbd361af4
Original commit changeset: c823c5d0066a Original Phabricator Diff: D33933419 #72160 broke some sharded tensor tests. Let's revert it first and reland it again once the test failure has been fixed. Differential Revision: [D35418031](https://our.internmc.facebook.com/intern/diff/D35418031/) [ghstack-poisoned]
Summary: Pull Request resolved: #75295 Original commit changeset: c823c5d0066a Original Phabricator Diff: D33933419 (288de54) #72160 broke some sharded tensor tests. Let's revert it first and reland it again once the test failure has been fixed. ghstack-source-id: 153135278 (Note: this ignores all push blocking failures!) Test Plan: CI Reviewed By: wanchaol Differential Revision: D35418031 fbshipit-source-id: a4435a62d4487a927d0cf79624afe4f0951809c8
This pull request has been reverted by 57ba615. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).
This is a reland of PR #72160. The previous PR failed in some cases where uneven scatter happens, so this PR makes some additional fixes to ensure it scatters correctly:
1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to #75535.
2. Resize the shard to the same size before calling `dist.scatter`, and resize it back to the original layout after receiving from scatter.

Differential Revision: [D35726920](https://our.internmc.facebook.com/intern/diff/D35726920/) [ghstack-poisoned]
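To illustrate fix 2, here is a simplified sketch, not the exact code in this PR: pad every shard up to the largest shard size so the scatter uses uniform buffers, then trim the received tensor back to this rank's true shard size. The helper name and arguments are assumptions; shard_sizes is assumed to be known on every rank from the sharding metadata, and the process group is assumed to be initialized.

import torch
import torch.distributed as dist

def scatter_uneven_shards(shards, shard_sizes, src_rank, dtype, device, tail_shape):
    # `shards` matters only on src_rank; `shard_sizes[r]` is the true first-dim
    # size of rank r's shard and is known on all ranks from the sharding metadata.
    rank = dist.get_rank()
    max_size = max(shard_sizes)
    recv = torch.empty((max_size,) + tuple(tail_shape), dtype=dtype, device=device)
    if rank == src_rank:
        # Pad every shard up to the same size so the scatter sees uniform buffers.
        scatter_list = []
        for shard in shards:
            padded = torch.zeros_like(recv)
            padded[: shard.size(0)] = shard
            scatter_list.append(padded)
    else:
        scatter_list = None
    dist.scatter(recv, scatter_list=scatter_list, src=src_rank)
    # Resize back to this rank's original shard layout after the scatter.
    return recv[: shard_sizes[rank]].contiguous()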
Summary: Pull Request resolved: #75991 This is a reland of PR #72160. The previous PR failed in some cases where uneven scatter happens, so this PR makes some additional fixes to ensure it scatters correctly: 1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to #75535. 2. Resize the shard to the same size before calling `dist.scatter`, and resize it back to the original layout after receiving from scatter. ghstack-source-id: 154725305 Test Plan: test_sharded_tensor test_linear test_megatron_prototype test_embedding/embeddingbag Reviewed By: pritamdamania87 Differential Revision: D35726920 fbshipit-source-id: d9bd0e44f47ef5b9e9add0dc66c5fda99e93943a