
Conversation

@wanchaol
Collaborator

@wanchaol wanchaol commented Feb 2, 2022

Stack from ghstack (oldest at bottom):

This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to each rank, we split the tensor beforehand and only send the part needed to the corresponding rank, which greatly reduces the communication overhead.

Differential Revision: [D33933419](https://our.internmc.facebook.com/intern/diff/D33933419/)

[ghstack-poisoned]
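For illustration, the communication change can be sketched roughly as below. This is a minimal sketch, not the PR's actual `shard_parameter` code: the function and variable names are made up, it assumes every rank holds a same-shaped copy of the parameter (only the source rank's values matter), and it assumes the sharded dimension divides evenly by the world size (uneven shards are handled in the later reland, see below).

import torch
import torch.distributed as dist

def shard_by_scatter(tensor, dim, src_rank, pg=None):
    # Illustrative only: each rank receives just its own chunk via dist.scatter,
    # instead of every rank receiving the full tensor via dist.broadcast and
    # narrowing it locally.
    rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    chunk_size = tensor.size(dim) // world_size  # assumes even division
    shard_shape = list(tensor.shape)
    shard_shape[dim] = chunk_size
    local_shard = tensor.new_empty(shard_shape)
    scatter_list = None
    if rank == src_rank:
        # Split the parameter up front so only the needed part travels to each rank.
        scatter_list = [c.contiguous() for c in torch.chunk(tensor, world_size, dim=dim)]
    dist.scatter(local_shard, scatter_list=scatter_list, src=src_rank, group=pg)
    return local_shard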
@pytorch-bot

pytorch-bot bot commented Feb 2, 2022

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/8c1733c27c28d3bf34287999a75ba1d459f3b65d/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default
Add ciflow labels to this PR to trigger more builds:

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-binary-conda ciflow/binaries, ciflow/binaries_conda, ciflow/default ✅ triggered
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
linux-binary-manywheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk, ciflow/xla ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
windows-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries_libtorch, ciflow/default ✅ triggered
windows-binary-wheel ciflow/binaries, ciflow/binaries_wheel, ciflow/default ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-bionic-rocm4.5-py3.7 ciflow/linux, ciflow/rocm 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped

@facebook-github-bot
Contributor

facebook-github-bot commented Feb 2, 2022

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 0af895f (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot facebook-github-bot added the oncall: distributed (Add this issue/PR to distributed oncall triage queue) label Feb 2, 2022
@wanchaol wanchaol requested a review from fduwjj February 2, 2022 02:03
wanchaol pushed a commit that referenced this pull request Feb 2, 2022
Pull Request resolved: #72160

This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to each rank, we split the tensor beforehand and only send the part needed to the corresponding rank, which greatly reduces the communication overhead.
ghstack-source-id: 148163484

Differential Revision: [D33933419](https://our.internmc.facebook.com/intern/diff/D33933419/)
Contributor

@pritamdamania87 pritamdamania87 left a comment


LGTM, just had one comment regarding the change in distributed_c10d.py.

Comment on lines 2329 to 2335
if scatter_list:
    raise ValueError(
        "Argument ``scatter_list`` must NOT be specified "
        "on non-source ranks."
    )
Contributor


Is this an unintended change?

Collaborator Author


This is an intended change actually; it's to match the behavior of `dist.gather`, which allows the user to pass in `gather_list` on non-dst ranks, but we internally treat it as an empty list.

I think the `dist.gather` behavior makes more sense: the user can call the collective on all ranks SPMD-style, and inside `dist.gather` we ignore the gather/scatter_list on non-dst/src ranks. Otherwise, users have to write conditional code every time they perform a gather/scatter, i.e. for scatter it must be:

if current_rank == src_rank:
    dist.scatter(local_tensor, scatter_list=tensors_to_scatter, src=src_rank, group=pg)
else:
    dist.scatter(local_tensor, scatter_list=None, src=src_rank, group=pg)    

which might be very inconvenient for users. Our low-level primitives `pg.gather`/`pg.scatter` enforce that non-dst/src ranks must pass `None`.
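For contrast, the SPMD-style call being argued for here is sketched below, reusing the variable names from the conditional snippet above; whether the list is ignored or rejected on non-src ranks is exactly the behavior under discussion:

# Same call on every rank; scatter_list is only read on src_rank.
dist.scatter(local_tensor, scatter_list=tensors_to_scatter, src=src_rank, group=pg)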

Contributor


Is it possible to centralize this logic (setting the scatter/gather list to `None` for non-src ranks) somewhere in Python, since we have enforced this in C++?

Collaborator Author


It was embedded in `dist.gather`:

output_tensors = [gather_list] if dst == my_rank else []

and in `dist.scatter`:

input_tensors = []

I'm just removing the error raising. Are you suggesting we have a util function to deal with that? I guess that's a small piece of code only for gather/scatter, so maybe we don't need to generalize it out.
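If such a util were wanted, a hypothetical helper (not an existing function in `distributed_c10d.py`) mirroring the `[gather_list] if dst == my_rank else []` pattern quoted above could look like this:

import torch.distributed as dist

def _nest_on_root(tensor_list, root_rank):
    # Hypothetical util: ProcessGroup.gather/scatter expect a list of tensor lists
    # that is non-empty only on the dst/src rank; all other ranks pass [].
    return [tensor_list] if dist.get_rank() == root_rank else []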

Contributor


The thing is that our docs clearly say it should only be specified on the src rank:

        scatter_list (list[Tensor]): List of tensors to scatter (default is
            None, must be specified on the source rank)

So this change would be confusing for users, since the docs would say one thing while the actual behavior is different.

Collaborator Author


Neither the scatter nor the gather doc mentions that scatter_list "must ONLY be specified on the source rank", so I guess this part is a bit confusing to the user. Anyway, I think that should be a separate discussion around the consistency of our c10d APIs, not related to this PR specifically. I just updated the PR to make sure we pass None on non-dst/src ranks and removed this unintentional change.

Collaborator Author


Created an issue to track this conversation: #74323

Contributor

@fduwjj fduwjj left a comment


We can get rid of broadcast finally. Does GH support diff review meme?

Comment on lines 458 to 468
# Reshape to get shard for each rank and we don't want autograd
# recording here for the narrow op and 'tensor_to_scatter' should be a
# leaf variable in the autograd graph.
tensor_to_scatter = tensor.narrow(
    sharding_spec.dim,  # type: ignore[arg-type]
    shard_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
    shard_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
).clone().detach().contiguous()
tensors_to_scatter.append(
    tensor_to_scatter
)
Contributor


Once your new sharding spec extension PR is ready, I guess the logic here can reuse what's been written in the shard API?

Collaborator Author


Yeah, agreed; this might just use `ChunkShardingSpec.shard`.


current_offsets[sharding_spec.dim] += chunked_dim_size # type: ignore[index]

# Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
Contributor


Nit: You can remove this comment now 😊

Contributor

@pritamdamania87 pritamdamania87 left a comment


I guess this PR needs to be rebased on top of #72130 or vice-versa?

@wanchaol
Collaborator Author

Yep, plan to rebase soon (hopefully we can land that PR and then rebase onto master directly).

wanchaol pushed a commit that referenced this pull request Mar 10, 2022
Pull Request resolved: #72160

This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to each rank, we split the tensor beforehand and only send the part needed to the corresponding rank, which greatly reduces the communication overhead.
ghstack-source-id: 150986659

Differential Revision: [D33933419](https://our.internmc.facebook.com/intern/diff/D33933419/)
wanchaol pushed a commit that referenced this pull request Mar 10, 2022
Pull Request resolved: #72160

This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to each rank, we split the tensor beforehand and only send the part needed to the corresponding rank, which greatly reduces the communication overhead.
ghstack-source-id: 151007579

Differential Revision: [D33933419](https://our.internmc.facebook.com/intern/diff/D33933419/)
Contributor

@pritamdamania87 pritamdamania87 left a comment


Requesting changes since I think the scatter behavior should align with our docs.

facebook-github-bot pushed a commit that referenced this pull request Mar 21, 2022
Summary:
Pull Request resolved: #72160

This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to each rank, we split the tensor beforehand and only send the part needed to the corresponding rank, which greatly reduces the communication overhead.
ghstack-source-id: 151643718

Test Plan:
test_shard_parameter
test_shard_parameter_errors

Reviewed By: pritamdamania87

Differential Revision: D33933419

fbshipit-source-id: c823c5d0066a9fe7451c07cbacb30a3bbd361af4
@github-actions
Contributor

Hey @wanchaol.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

@facebook-github-bot facebook-github-bot deleted the gh/wanchaol/200/head branch March 25, 2022 14:18
shahofblah pushed a commit that referenced this pull request Mar 25, 2022
Summary:
Pull Request resolved: #72160

This PR switches the `shard_parameter` API to use `dist.scatter` instead of `dist.broadcast`. Instead of sending the whole tensor to each rank, we split the tensor beforehand and only send the part needed to the corresponding rank, which greatly reduces the communication overhead.
ghstack-source-id: 151643718

Test Plan:
test_shard_parameter
test_shard_parameter_errors

Reviewed By: pritamdamania87

Differential Revision: D33933419

fbshipit-source-id: c823c5d0066a9fe7451c07cbacb30a3bbd361af4
(cherry picked from commit b1b553e)
fduwjj added a commit that referenced this pull request Apr 6, 2022
Original commit changeset: c823c5d0066a

Original Phabricator Diff: D33933419

#72160 broke some sharded tensor tests. Let's revert it first and reland it again once the test failure has been fixed.

Differential Revision: [D35418031](https://our.internmc.facebook.com/intern/diff/D35418031/)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Apr 6, 2022
Summary:
Pull Request resolved: #75295

Original commit changeset: c823c5d0066a

Original Phabricator Diff: D33933419 (288de54)

#72160 broke some sharded tensor tests. Let's revert it first and reland it again once the test failure has been fixed.
ghstack-source-id: 153135278

(Note: this ignores all push blocking failures!)

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D35418031

fbshipit-source-id: a4435a62d4487a927d0cf79624afe4f0951809c8
pytorchmergebot pushed a commit that referenced this pull request Apr 6, 2022
Summary:
Pull Request resolved: #75295

Original commit changeset: c823c5d0066a

Original Phabricator Diff: D33933419 (288de54)

#72160 broke some sharded tensor tests. Let's revert it first and reland it again once the test failure has been fixed.
ghstack-source-id: 153135278

(Note: this ignores all push blocking failures!)

Test Plan: CI

Reviewed By: wanchaol

Differential Revision: D35418031

fbshipit-source-id: a4435a62d4487a927d0cf79624afe4f0951809c8
(cherry picked from commit a20d119)
@facebook-github-bot
Contributor

This pull request has been reverted by 57ba615. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert, and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

NesrineMHB pushed a commit to NesrineMHB/pytorch that referenced this pull request Apr 8, 2022
Original commit changeset: c823c5d0066a

Original Phabricator Diff: D33933419

pytorch/pytorch#72160 broke some sharded tensor tests. Let's revert it first and reland it again once the test failure has been fixed.

Differential Revision: [D35418031](https://our.internmc.facebook.com/intern/diff/D35418031/)

ghstack-source-id: 153135278
Pull Request resolved: pytorch/pytorch#75295
wanchaol pushed a commit that referenced this pull request Apr 18, 2022
This is a reland of PR #72160. The previous PR failed in some cases where uneven scatter happens, so this PR makes some additional fixes to ensure it scatters correctly.
1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to #75535.
2. Resize the shards to the same size before calling `dist.scatter`, and resize them back to the original layout after receiving from scatter.

Differential Revision: [D35726920](https://our.internmc.facebook.com/intern/diff/D35726920/)

[ghstack-poisoned]
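Point 2 of the reland can be sketched as below. This is an illustrative outline, not the reland's actual code; it assumes every rank holds a same-shaped full tensor (as in `shard_parameter`, where only the source rank's values matter), pads the trailing shards up to the largest chunk size so `dist.scatter` sees uniformly sized tensors, then trims the padding off after the collective.

import torch
import torch.distributed as dist

def scatter_uneven_shards(tensor, dim, src_rank, pg=None):
    # Sketch of handling uneven sharding: pad every chunk to the largest chunk
    # size before scattering, then narrow the received buffer back down to this
    # rank's true shard size.
    rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    full = tensor.size(dim)
    max_size = (full + world_size - 1) // world_size          # largest chunk size
    my_size = max(min(full - rank * max_size, max_size), 0)   # this rank's real shard size

    scatter_list = None
    if rank == src_rank:
        scatter_list = []
        for r in range(world_size):
            start = min(r * max_size, full)
            length = min(full - start, max_size)
            chunk = tensor.narrow(dim, start, length)
            if length < max_size:
                # Pad the smaller trailing chunks so every message has the same shape.
                pad_shape = list(chunk.shape)
                pad_shape[dim] = max_size - length
                chunk = torch.cat([chunk, chunk.new_zeros(pad_shape)], dim=dim)
            scatter_list.append(chunk.contiguous())

    recv_shape = list(tensor.shape)
    recv_shape[dim] = max_size
    recv = tensor.new_empty(recv_shape)
    dist.scatter(recv, scatter_list=scatter_list, src=src_rank, group=pg)
    # Trim the padding off to recover the original (possibly smaller) shard.
    return recv.narrow(dim, 0, my_size).contiguous()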
wanchaol pushed a commit that referenced this pull request Apr 18, 2022
Pull Request resolved: #75991

This is a reland of PR #72160. The previous PR failed in some cases where uneven scatter happens, so this PR makes some additional fixes to ensure it scatters correctly.
1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to #75535.
2. Resize the shards to the same size before calling `dist.scatter`, and resize them back to the original layout after receiving from scatter.
ghstack-source-id: 154155925

Differential Revision: [D35726920](https://our.internmc.facebook.com/intern/diff/D35726920/)
wanchaol pushed a commit that referenced this pull request Apr 25, 2022
Pull Request resolved: #75991

This is a reland of PR #72160. The previous PR failed in some cases where uneven scatter happens, so this PR makes some additional fixes to ensure it scatters correctly.
1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to #75535.
2. Resize the shards to the same size before calling `dist.scatter`, and resize them back to the original layout after receiving from scatter.
ghstack-source-id: 154682435

Differential Revision: [D35726920](https://our.internmc.facebook.com/intern/diff/D35726920/)
wanchaol pushed a commit that referenced this pull request Apr 25, 2022
Pull Request resolved: #75991

This is a reland of PR #72160. The previous PR failed in some cases where uneven scatter happens, so this PR makes some additional fixes to ensure it scatters correctly.
1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to #75535.
2. Resize the shards to the same size before calling `dist.scatter`, and resize them back to the original layout after receiving from scatter.
ghstack-source-id: 154725305

Differential Revision: [D35726920](https://our.internmc.facebook.com/intern/diff/D35726920/)
facebook-github-bot pushed a commit that referenced this pull request Apr 25, 2022
Summary:
Pull Request resolved: #75991

This is a reland of PR #72160. The previous PR failed in some cases where uneven scatter happens, so this PR makes some additional fixes to ensure it scatters correctly.
1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to #75535.
2. Resize the shards to the same size before calling `dist.scatter`, and resize them back to the original layout after receiving from scatter.
ghstack-source-id: 154725305

Test Plan:
test_sharded_tensor
test_linear
test_megatron_prototype
test_embedding/embeddingbag

Reviewed By: pritamdamania87

Differential Revision: D35726920

fbshipit-source-id: d9bd0e44f47ef5b9e9add0dc66c5fda99e93943a
pytorchmergebot pushed a commit that referenced this pull request Apr 25, 2022
Summary:
Pull Request resolved: #75991

This is a reland of PR #72160. The previous PR failed in some cases where uneven scatter happens, so this PR makes some additional fixes to ensure it scatters correctly.
1. Fix a bug in ProcessGroupNCCL::scatter, which is a similar issue to #75535.
2. Resize the shards to the same size before calling `dist.scatter`, and resize them back to the original layout after receiving from scatter.
ghstack-source-id: 154725305

Test Plan:
test_sharded_tensor
test_linear
test_megatron_prototype
test_embedding/embeddingbag

Reviewed By: pritamdamania87

Differential Revision: D35726920

fbshipit-source-id: d9bd0e44f47ef5b9e9add0dc66c5fda99e93943a
(cherry picked from commit ed11e5d)