
Conversation

@rohan-varma (Contributor) commented Nov 30, 2022

Stack from ghstack (oldest at bottom):

Adds 2 new hybrid sharding strategies to FSDP:

  1. HYBRID_SHARD: applies ZeRO-3-style sharding within a node and data parallelism across nodes
  2. HYBRID_SHARD_ZERO2: applies ZeRO-2-style sharding within a node and data parallelism across nodes

These are useful for medium-sized models and aim to decrease communication volume. Tests and benchmarks will be run to understand which workloads perform best under which sharding strategy.

Hybrid sharding works by sharding the model with a process group within a single node and creating inter-node process groups for replication / data parallelism. The user either passes in a tuple of these process groups or None, in which case we generate the process groups appropriately (see the usage sketch below).

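A minimal usage sketch (illustrative only; the explicit group construction below assumes a 2-node x 8-GPU layout and a (shard group, replicate group) tuple ordering, and is not code from this PR). Passing `process_group=None` lets FSDP derive the intra-node sharding group and inter-node replication group itself; alternatively, the two groups can be built by hand and passed as a tuple:

```python
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Assumes dist.init_process_group("nccl") has already been called and each
# rank has set its CUDA device; 2 nodes x 8 GPUs = 16 ranks in total.
gpus_per_node = 8
rank = dist.get_rank()
world_size = dist.get_world_size()
num_nodes = world_size // gpus_per_node

# One sharding group per node: the ranks that live on the same node.
# new_group must be called by every rank for every group, so build them all.
shard_groups = [
    dist.new_group(list(range(n * gpus_per_node, (n + 1) * gpus_per_node)))
    for n in range(num_nodes)
]
# One replication group per local rank: the same local rank on every node.
replicate_groups = [
    dist.new_group(list(range(local, world_size, gpus_per_node)))
    for local in range(gpus_per_node)
]
shard_group = shard_groups[rank // gpus_per_node]
replicate_group = replicate_groups[rank % gpus_per_node]

model = nn.Linear(1024, 1024).cuda()
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    # Assumed ordering: (intra-node sharding group, inter-node replication group).
    process_group=(shard_group, replicate_group),
)
```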
**Acknowledgements**
- @awgu's excellent prototype: awgu@5ad3a16
- @liangluofb for ideation, feedback, and the initial implementation and experimentation

@pytorch-bot bot commented Nov 30, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89915

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 Failures

As of commit 4e9638a:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the release notes: distributed (fsdp) label Nov 30, 2022
rohan-varma added a commit that referenced this pull request Nov 30, 2022
ghstack-source-id: 93f5a35
Pull Request resolved: #89915
@rohan-varma changed the title from HSDP to Hybrid Sharded Data Parallel Dec 1, 2022
rohan-varma added a commit that referenced this pull request Dec 1, 2022
ghstack-source-id: aee66d4
Pull Request resolved: #89915
rohan-varma added a commit that referenced this pull request Dec 2, 2022
ghstack-source-id: 9d6258e
Pull Request resolved: #89915
@awgu (Collaborator) left a comment

I made an initial pass and left a lot of nitpicks. I will read the test code in a follow-up pass, possibly after you respond to some of the comments.

# FSDP module directly
submodule._fsdp_use_orig_params = use_orig_params

# Initializes self.process_group, along with rank and world size. This will
Collaborator

nit: Personally, I do not like explaining what a function/method call does inline like this since this creates redundancy, which can go stale if only one place is updated. The developer should read the docstring for _init_process_group_state.

Collaborator

[Easy] I recommend changing before landing

Contributor Author

I mostly want to emphasize the part a couple of lines later that mentions this is done before auto wrapping, and the reasoning for why, which I think is valuable.

Collaborator

Sounds good.

@awgu self-requested a review December 4, 2022 00:32
@awgu (Collaborator) commented Dec 4, 2022

Just as a heads up, it looks like the test failures are real:

======================================================================
ERROR [0.002s]: test_wrap_wrap_method_WrapMethod_WRAP_API (__main__.TestAutoWrap)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/conda/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 247, in instantiated_test
    test(self, **param_kwargs)
  File "/var/lib/jenkins/workspace/test/distributed/fsdp/test_wrap.py", line 319, in test_wrap
    layer = wrap(nn.Linear(5, 5))
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 295, in wrap
    return _wrap(
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/wrap.py", line 313, in _wrap
    return wrapper_cls(module, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 376, in __init__
    _init_process_group_state(self, process_group, sharding_strategy, auto_wrap_policy)
  File "/opt/conda/lib/python3.10/site-packages/torch/distributed/fsdp/_init_utils.py", line 110, in _init_process_group_state
    raise ValueError(
ValueError: process_group should be None or dist.ProcessGroup, but got <class 'torch.testing._internal.common_fsdp.DummyProcessGroup'>

I am not sure if we can make DummyProcessGroup inherit from dist.ProcessGroup as a fix.
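One possible direction, sketched purely as an illustration (not something this PR adopts): relax the strict type check to a duck-typed one, so test doubles such as DummyProcessGroup, which expose rank() and size(), would pass validation.

```python
# Illustrative only: a duck-typed alternative to requiring dist.ProcessGroup,
# which would also accept test doubles like DummyProcessGroup.
def _is_process_group_like(pg) -> bool:
    if pg is None:
        return True
    return callable(getattr(pg, "rank", None)) and callable(getattr(pg, "size", None))
```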

@rohan-varma requested a review from awgu December 7, 2022 20:00
nodes. This results in reduced communication volume as expensive all-gathers and
reduce-scatters are only done within a node, which can be more performant for
medium-sized models.
- ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across
Contributor Author

I guess we should omit this from the docstring for now.

Collaborator

Yes, I think that would be good to be safe.

@awgu (Collaborator) left a comment

I have a final pass of nits, and we can land after that.

    )
    else:
        state = _init_process_group_state_for_hybrid_shard(state, process_group)
        assert state.process_group is not None, "Expected to populate state.process_group for hybrid shard"
Collaborator

I think the asserts should be at the end of _init_process_group_state_for_hybrid_shard then, representing a post-condition. This is just my personal design preference and probably does not matter in this case. However, in general, you would have to enforce the post-condition upon each call to _init_process_group_state_for_hybrid_shard.

(This is also how I approach post-conditions and invariants in general -- they should be coupled to the method/function itself, not their usages.)
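As a concrete illustration of that suggestion (a sketch only; the real helper's internals differ), the assertion would live at the end of the helper so the post-condition is enforced for every caller:

```python
def _init_process_group_state_for_hybrid_shard(state, process_group):
    ...  # set up the intra-node sharding group and the inter-node replication group
    assert (
        state.process_group is not None
    ), "Expected _init_process_group_state_for_hybrid_shard to populate state.process_group"
    return state
```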

@rohan-varma requested a review from awgu December 7, 2022 21:48
@awgu (Collaborator) left a comment

Looks great to me! Awesome work, and thanks for fixing all of the nits. I am very excited to see the experiment results and downstream impact!

# Owner(s): ["oncall: distributed"]

import contextlib
import functools
Collaborator

Suggested change
import functools

@rohan-varma added the ciflow/trunk label Dec 7, 2022
rohan-varma added a commit that referenced this pull request Dec 8, 2022
ghstack-source-id: 29e4096
Pull Request resolved: #89915
rohan-varma added a commit that referenced this pull request Dec 8, 2022
ghstack-source-id: c063e71
Pull Request resolved: #89915
@rohan-varma (Contributor Author)

CI failures are related to autocast and are unrelated to this PR.

@rohan-varma (Contributor Author)

@pytorchbot merge -f "CI failures unrelated"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@jeffdaily (Collaborator)

This PR broke ROCm CI periodic jobs, where the distributed tests get run.

======================================================================
ERROR [4.530s]: test_fsdp_hybrid_shard_basic_setup (__main__.TestFSDPHybridShard)
Tests basic functionality of HYBRID_SHARD and _HYBRID_SHARD_ZERO2:
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 533, in wrapper
    self._join_processes(fn)
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 759, in _join_processes
    self._check_return_codes(elapsed_time)
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 804, in _check_return_codes
    raise RuntimeError(error)
RuntimeError: Process 0 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 657, in run_test
    getattr(self, test_name)()
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 535, in wrapper
    fn()
  File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 166, in wrapper
    return func(*args, **kwargs)
  File "/var/lib/jenkins/pytorch/test/distributed/fsdp/test_fsdp_hybrid_shard.py", line 226, in test_fsdp_hybrid_shard_basic_setup
    with (
AttributeError: __enter__

@jeffdaily (Collaborator)

Not sure how this PR only broke ROCm CI, unless this test is getting skipped on other platforms?

The test is using the context managers from earlier in the file, so not sure why __enter__ isn't an attribute?

https://github.com/pytorch/pytorch/pull/89915/files#diff-19a0c73c8366fbd896a68ad90a2f4b3515f0fa6dc0a61528a2780e418b3f5e92R44

@awgu (Collaborator) commented Dec 9, 2022

> Not sure how this PR only broke ROCm CI, unless this test is getting skipped on other platforms?
>
> The test is using the context managers from earlier in the file, so not sure why __enter__ isn't an attribute?
>
> https://github.com/pytorch/pytorch/pull/89915/files#diff-19a0c73c8366fbd896a68ad90a2f4b3515f0fa6dc0a61528a2780e418b3f5e92R44

This could be because of a Python versioning issue. Parenthesized context managers may have only been added in Python 3.10:
https://docs.python.org/3.10/whatsnew/3.10.html#parenthesized-context-managers
https://www.blog.pythonlibrary.org/2021/09/08/python-3-10-parenthesized-context-managers/

In other words, the following syntax was not permitted until 3.10:

with (
    patch_allreduce(patched_allreduce),
    patch_reduce_scatter(patched_reduce_scatter),
):

This could be why the __enter__ is not being detected properly.
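For reference, here are two equivalent forms that parse on Python 3.8/3.9 as well (a sketch reusing the patch helpers shown above; not necessarily how the follow-up fix is written):

```python
import contextlib

# Unparenthesized multi-manager form, valid before Python 3.10:
with patch_allreduce(patched_allreduce), patch_reduce_scatter(patched_reduce_scatter):
    ...

# Or contextlib.ExitStack, which enters the managers one by one:
with contextlib.ExitStack() as stack:
    stack.enter_context(patch_allreduce(patched_allreduce))
    stack.enter_context(patch_reduce_scatter(patched_reduce_scatter))
    ...
```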

@jeffdaily (Collaborator)

@awgu I just figured out the same. PR to fix: #90580

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Pull Request resolved: pytorch#89915
Approved by: https://github.com/awgu
pytorchmergebot pushed a commit that referenced this pull request Dec 12, 2022
Fixes PR #89915.  The following syntax was not permitted until 3.10:

```
with (
    patch_allreduce(patched_allreduce),
    patch_reduce_scatter(patched_reduce_scatter),
):
```

Pull Request resolved: #90580
Approved by: https://github.com/awgu
@facebook-github-bot deleted the gh/rohan-varma/617/head branch June 8, 2023 18:36

Labels

ciflow/trunk · Merged · release notes: distributed (fsdp)
