Hybrid Sharded Data Parallel #89915
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89915
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 Failures. As of commit 4e9638a, the following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Adds 2 new hybrid sharding strategies to FSDP:

1. `HYBRID_SHARD`: applies zero-3-style sharding within a node and data parallelism across nodes.
2. `HYBRID_SHARD_ZERO2`: applies zero-2-style sharding within a node and data parallelism across nodes.

These are useful for medium-sized models and aim to decrease communication volume. Tests and benchmarks will be run to understand which workloads are optimal under which sharding strategy.

Hybrid sharding in general works by sharding the model using a process group within a single node, and creating inter-node process groups for replication / data parallelism. The user either needs to pass in a tuple of these process groups, or `None`, and we generate the process groups appropriately.

**Acknowledgements**

- @awgu's excellent prototype: awgu@5ad3a16
- @liangluofb for ideation, feedback, and the initial implementation and experimentation
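As a usage sketch (not taken from the PR itself; the model and launch details are illustrative), enabling hybrid sharding with auto-generated process groups might look like this:

```python
# Minimal sketch: run under torchrun with NCCL; the Linear layer is a
# stand-in for a real medium-sized model.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
model = nn.Linear(1024, 1024).cuda()

# With no process groups supplied, FSDP derives an intra-node group for
# sharding and an inter-node group for gradient replication.
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.HYBRID_SHARD)
```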
awgu left a comment:
I made an initial pass and left a lot of nitpicks. I will read the test code in a follow-up pass, possibly after you respond to some of the comments.
```python
# FSDP module directly
submodule._fsdp_use_orig_params = use_orig_params

# Initializes self.process_group, along with rank and world size. This will
```
nit: Personally, I do not like explaining what a function/method call does inline like this since this creates redundancy, which can go stale if only one place is updated. The developer should read the docstring for `_init_process_group_state`.
[Easy] I recommend changing before landing
I mostly want to emphasize the part a couple of lines later that mentions this is done before auto wrapping, and the logic for why, which I think is valuable.
Sounds good.
Just as a heads up, it looks like the test failures are real: I am not sure if we can make
```
nodes. This results in reduced communication volume as expensive all-gathers and
reduce-scatters are only done within a node, which can be more performant for medium
-sized models.
- ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across
```
I guess we should omit this from the docstring for now.
Yes, I think that would be good to be safe.
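For context on the intra-node (sharding) and inter-node (replication) groups that docstring describes, here is a hedged sketch of how the two groups could pair up if built by hand, assuming 4 GPUs per node and an already-initialized default group; this uses plain `torch.distributed` calls, not the FSDP internals:

```python
# Assumes dist.init_process_group(...) has already run on every rank.
import torch.distributed as dist

gpus_per_node = 4  # illustrative assumption
rank = dist.get_rank()
world_size = dist.get_world_size()
num_nodes = world_size // gpus_per_node

# Intra-node groups: ranks on the same node shard parameters (zero-3 / zero-2
# style). Every rank must call new_group() for every group, hence full loops.
intra_node_groups = [
    dist.new_group(list(range(n * gpus_per_node, (n + 1) * gpus_per_node)))
    for n in range(num_nodes)
]
# Inter-node groups: the same local rank across nodes replicates gradients.
inter_node_groups = [
    dist.new_group(list(range(r, world_size, gpus_per_node)))
    for r in range(gpus_per_node)
]
shard_pg = intra_node_groups[rank // gpus_per_node]
replicate_pg = inter_node_groups[rank % gpus_per_node]
# These could then be passed to FSDP as process_group=(shard_pg, replicate_pg).
```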
awgu left a comment:
I have a final pass of nits, and we can land after that.
```python
)
else:
    state = _init_process_group_state_for_hybrid_shard(state, process_group)
    assert state.process_group is not None, "Expected to populate state.process_group for hybrid shard"
```
I think the asserts should be at the end of `_init_process_group_state_for_hybrid_shard` then, representing a post-condition. This is just my personal design preference and probably does not matter in this case. However, in general, you would have to enforce the post-condition upon each call to `_init_process_group_state_for_hybrid_shard`.
(This is also how I approach post-conditions and invariants in general -- they should be coupled to the method/function itself, not their usages.)
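A hedged sketch of that pattern follows; the helper name mirrors the PR's, but the body and state class are schematic, not the actual implementation:

```python
class _State:
    """Schematic stand-in for FSDP's internal state object."""
    def __init__(self):
        self.process_group = None

def _init_process_group_state_for_hybrid_shard(state, process_group):
    # The real helper would derive intra-/inter-node groups here.
    state.process_group = process_group
    # Post-condition lives at the end of the function itself, so every
    # call site inherits the guarantee without repeating the assert.
    assert state.process_group is not None, (
        "Expected to populate state.process_group for hybrid shard"
    )
    return state
```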
awgu left a comment:
Looks great to me! Awesome work, and thanks for fixing all of the nits. I am very excited to see the experiment results and downstream impact!
```python
# Owner(s): ["oncall: distributed"]

import contextlib
import functools
```
Suggested change:
```python
import functools
```
CI failures are related to autocast and are unrelated to this PR.
@pytorchbot merge -f "CI failures unrelated"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR broke ROCm CI periodic jobs, where the distributed tests get run.
Not sure how this PR only broke ROCm CI, unless this test is getting skipped on other platforms? The test is using the context managers from earlier in the file, so not sure why
This could be because of a Python versioning issue: parenthesized context managers were only added in Python 3.10, so the `with (...)` form used by the test (shown in the fix for #90580 below) was not permitted on older interpreters. This could be why the ROCm jobs are the ones failing.
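For reference, a sketch of two spellings that work on Python < 3.10, shown with `contextlib.nullcontext` stand-ins (the real test's managers are `patch_allreduce` / `patch_reduce_scatter`):

```python
import contextlib

# 1. Chain the managers on a single `with` statement (no parentheses):
with contextlib.nullcontext(), contextlib.nullcontext():
    pass

# 2. Or use contextlib.ExitStack when the list grows long:
with contextlib.ExitStack() as stack:
    stack.enter_context(contextlib.nullcontext())
    stack.enter_context(contextlib.nullcontext())
```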
Pull Request resolved: pytorch#89915. Approved by: https://github.com/awgu
Fixes PR #89915. The following syntax was not permitted until 3.10:

```python
with (
    patch_allreduce(patched_allreduce),
    patch_reduce_scatter(patched_reduce_scatter),
):
```

Pull Request resolved: #90580. Approved by: https://github.com/awgu