
Conversation

@awgu
Collaborator

@awgu awgu commented May 3, 2024

Stack from ghstack (oldest at bottom):

Context
We are interested in supporting the case where HSDP reduce-scatters but does not all-reduce in a microbatch backward. This reduces communication while still saving memory. Only on the last microbatch do we need to both reduce-scatter and all-reduce. This is not implemented yet and will hopefully come in a future PR.

There is one subtlety to doing this: on the last microbatch, we need to perform an accumulation step after the reduce-scatter and before the all-reduce. Otherwise, the preceding microbatches' gradients are never reduced across the replica group. (In other words, we cannot simply accumulate after the all-reduce.)

Consider 32 GPUs with 4-way replication, 8-way sharding, and 2 microbatches, and focus on global rank 0.

  • After the first microbatch, rank 0 will have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)}$, where we define $S(0) = \{0, 1, \dots, 7\}$ to be the ranks in its shard group and we define the $(1)$ superscript to denote the first microbatch.
  • Upon the second microbatch, rank 0 after its reduce-scatter will additionally have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then this second microbatch's gradients become $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, so in total, rank 0 has $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, which is wrong.
  • Importantly, we must accumulate $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{8} \sum_{i \in S(0)} g_i^{(2)} = \frac{1}{8}\sum_{i \in S(0)} (g_i^{(1)} + g_i^{(2)})$ first before all-reducing to get $\frac{1}{32} \sum_{i=0, 1, \dots, 31} (g_i^{(1)} + g_i^{(2)})$. (The numerical sketch after this list illustrates the two orderings.)
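
To make the ordering concrete, here is a minimal single-process numerical sketch (plain PyTorch tensors simulating the 32 ranks; this is not FSDP code, and the layout and variable names are only for illustration). It checks that accumulating the shard-group averages before averaging over the replica group recovers the correct $\frac{1}{32} \sum_i (g_i^{(1)} + g_i^{(2)})$, whereas accumulating after does not.

```python
# Single-process sketch of the HSDP accumulation order (not FSDP code).
import torch

torch.manual_seed(0)
num_shard_groups, shard_group_size = 4, 8          # 4-way replication, 8-way sharding
world_size = num_shard_groups * shard_group_size   # 32 simulated ranks

# Per-rank gradients for microbatches 1 and 2, shape [world_size, D].
g1 = torch.randn(world_size, 16)
g2 = torch.randn(world_size, 16)

# Reference: average of the accumulated gradient over all 32 ranks.
reference = (g1 + g2).mean(dim=0)

# Shard group r holds ranks r*8 .. r*8+7; its reduce-scatter (with AVG)
# produces (1/8) * sum over that shard group.
g1_shard_avg = g1.view(num_shard_groups, shard_group_size, -1).mean(dim=1)
g2_shard_avg = g2.view(num_shard_groups, shard_group_size, -1).mean(dim=1)

# Wrong order (rank 0's view): keep the first microbatch's shard-group average
# and only all-reduce (average over the replica group) the second microbatch's
# reduce-scatter output.
wrong = g1_shard_avg[0] + g2_shard_avg.mean(dim=0)

# Right order: accumulate the two shard-group averages first, then average
# over the replica group.
right = (g1_shard_avg + g2_shard_avg).mean(dim=0)

print(torch.allclose(right, reference))  # True
print(torch.allclose(wrong, reference))  # False (in general)
```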

Now, note how under this approach, we want a factor of $\frac{1}{8}$ only (i.e. reciprocal of the shard group size), not $\frac{1}{32}$, for the first microbatch's gradients.

  • For bf16/fp32, since we use ReduceOp.AVG and we only reduce-scatter on the first microbatch, we correctly have a factor of $\frac{1}{8}$ on the first microbatch.
  • For fp16, since we precompute the gradient divide factors at init time assuming always reducing over both shard and replica groups, we incorrectly have a factor of $\frac{1}{32}$ on the first microbatch, deviating from the bf16/fp32 case.

We can address this issue and match the fp16 semantics to the bf16/fp32 semantics by computing the divide factors at runtime based on which process groups are passed into the reduction function (`foreach_reduce`).
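
As a rough sketch of what computing the divide factors at runtime could look like (the function name, signature, and pre/post split below are illustrative assumptions, not the actual FSDP2 API), the idea is simply to derive the world size from exactly the process groups handed to the reduction for this microbatch:

```python
# Hedged sketch: derive fp16 gradient divide factors from the groups actually
# used for this reduction. Names and signature are hypothetical.
from typing import Optional, Tuple

import torch
import torch.distributed as dist


def _get_grad_divide_factors(
    reduce_scatter_group: dist.ProcessGroup,
    all_reduce_group: Optional[dist.ProcessGroup],  # None when skipping all-reduce
    reduce_dtype: torch.dtype,
) -> Tuple[Optional[float], Optional[float]]:
    if reduce_dtype in (torch.float32, torch.bfloat16):
        # ReduceOp.AVG already divides by the right world size; no factors needed.
        return None, None
    world_size = reduce_scatter_group.size()
    if all_reduce_group is not None:
        world_size *= all_reduce_group.size()
    # Split the division into pre-/post-reduction factors of roughly
    # sqrt(world_size) to reduce fp16 overflow/underflow risk.
    factor = 1
    while world_size % factor == 0 and world_size // factor > factor:
        factor *= 2
    return float(factor), world_size / factor
```

With the 32-GPU example above, a reduce-scatter-only microbatch would divide by $8$ in total (matching the bf16/fp32 behavior), while the final microbatch, which also passes the replica group, would divide by $32$.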

Additional Notes
How to implement the HSDP reduce-scatter without all-reduce is not entirely clear yet. (What is the cleanest way to do this?) We need to store the partial reduce-scatter output and check for it upon the next backward. We should also be sure to error if the set of parameters receiving gradients changes, since we cannot easily support that case. We will implement this in a follow-up; a rough sketch of the required bookkeeping is below.
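
For reference, one possible shape for that bookkeeping (purely hypothetical names and structure, sketching the follow-up rather than describing existing code):

```python
# Hypothetical bookkeeping for "reduce-scatter but no all-reduce" microbatches.
from dataclasses import dataclass, field
from typing import Optional, Set

import torch


@dataclass
class _PartialReduceState:
    # Accumulated (sharded) reduce-scatter output from earlier microbatches.
    partial_grad: Optional[torch.Tensor] = None
    # Parameters that received gradients when partial_grad was produced.
    param_names: Set[str] = field(default_factory=set)

    def accumulate(self, reduce_scatter_out: torch.Tensor, names: Set[str]) -> None:
        if self.partial_grad is None:
            self.partial_grad = reduce_scatter_out.clone()
            self.param_names = set(names)
            return
        if names != self.param_names:
            # The flat reduce-scatter outputs would no longer line up.
            raise RuntimeError(
                "Cannot accumulate: the set of parameters receiving gradients changed"
            )
        self.partial_grad += reduce_scatter_out
```

On the last microbatch, the accumulated `partial_grad` would be added to that microbatch's reduce-scatter output before the all-reduce over the replica group, matching the accumulate-then-all-reduce math above.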

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

@pytorch-bot

pytorch-bot bot commented May 3, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125484

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit bccac39 with merge base b03fb49:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ci-td-distributed, oncall: distributed, and release notes: distributed (fsdp) labels May 3, 2024
@awgu awgu added the release notes: distributed (fsdp2) label and removed the release notes: distributed (fsdp) label May 3, 2024
awgu pushed a commit that referenced this pull request May 3, 2024
ghstack-source-id: 659d132
Pull Request resolved: #125484
@awgu
Collaborator Author

awgu commented May 3, 2024

We do not have a unit test that can capture the difference between the fp32/bf16 and fp16 division factors yet. It might be simpler to test this once we implement reduce-scatter without all-reduce for HSDP.

I think the important thing is that we should be able to see a difference if we precompute the gradient divide factors purely based on whether we are using FSDP or HSDP, rather than based on which process groups are actually used for the reduction on a given microbatch. (A minimal sketch of that discrepancy follows.)
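
A crude sketch of the discrepancy such a test should eventually catch (sizes only, no real process groups; the helper is hypothetical):

```python
# Hypothetical check: init-time precomputation over both groups vs. a runtime
# computation over only the groups used on a non-final microbatch.
def effective_divisor(group_sizes):
    """Total divisor applied to fp16 gradients for a reduction over these groups."""
    n = 1
    for size in group_sizes:
        n *= size
    return n

shard_group_size, replica_group_size = 8, 4
precomputed_at_init = effective_divisor([shard_group_size, replica_group_size])  # 32
computed_at_runtime = effective_divisor([shard_group_size])                      # 8
assert precomputed_at_init != computed_at_runtime
```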

@awgu awgu marked this pull request as ready for review May 3, 2024 19:13
@awgu awgu requested review from wanchaol and weifengpy May 3, 2024 19:13
Comment on lines +281 to +283
factor: int = 1
while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
    factor *= 2
Collaborator Author

This computation should be pretty fast on CPU (plus, it is $O(\log N)$ for $N$ data-parallel GPUs anyway).
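
For a quick sense of what the loop picks, here is a standalone copy for illustration; the pairing of `factor` as the pre-divide factor and `data_parallel_size / factor` as the post-divide factor follows the existing FSDP fp16 gradient pre/post-division convention, so treat that pairing as an assumption here.

```python
# Standalone copy of the loop above, just to show which factor it chooses.
def split_factor(data_parallel_size: int) -> int:
    factor = 1
    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
        factor *= 2
    return factor

for n in (8, 32, 64, 128):
    pre = split_factor(n)
    post = n / pre
    print(n, pre, post)
# 8   -> pre=4,  post=2.0
# 32  -> pre=8,  post=4.0
# 64  -> pre=8,  post=8.0
# 128 -> pre=16, post=8.0
```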

@wanchaol
Collaborator

wanchaol commented May 3, 2024

  • Upon the second microbatch, rank 0 after its reduce-scatter will additionally have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then this second microbatch's gradients become $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, so in total, rank 0 has $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, which is wrong.

hmm a question on this: shouldn't we also allreduce the first microbatch's gradients? So rank0 should also have $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(1)}$ by the time it allreduces the second microbatch's gradient?

@awgu
Collaborator Author

awgu commented May 3, 2024

  • Upon the second microbatch, rank 0 after its reduce-scatter will additionally have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then this second microbatch's gradients become $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, so in total, rank 0 has $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, which is wrong.

hmm a question on this: shouldn't we also allreduce the first microbatch's gradients? So rank0 should also have $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(1)}$ by the time it allreduces the second microbatch's gradient?

Sorry, the point of what we are trying to do is to not all-reduce the first microbatch's gradients. This is to save communication. Just reduce-scattering is enough to save memory. This is a trick we want to do for HSDP but have not implemented. This PR is to prepare.

@wanchaol
Collaborator

wanchaol commented May 3, 2024

Sorry, the point of what we are trying to do is to not all-reduce the first microbatch's gradients. This is to save communication. Just reduce-scattering is enough to save memory. This is a trick we want to do for HSDP but have not implemented. This PR is to prepare.

Oh I see, makes sense, so the reason we need to do reduce_scatter is to save gradient memory (so we have to do it), but we would want to all-reduce only at the end of all microbatches to save communication?

@awgu
Collaborator Author

awgu commented May 3, 2024

Sorry, the point of what we are trying to do is to not all-reduce the first microbatch's gradients. This is to save communication. Just reduce-scattering is enough to save memory. This is a trick we want to do for HSDP but have not implemented. This PR is to prepare.

Oh I see, makes sense, so the reason we need to do reduce_scatter is to save gradient memory (so we have to do it), but we would want to all-reduce only at the end of all microbatches to save communication?

Yep!

Collaborator

@wanchaol wanchaol left a comment

sgtm!

@awgu awgu added the ciflow/trunk label May 3, 2024
@awgu
Collaborator Author

awgu commented May 3, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@huydhn huydhn added the ciflow/periodic label May 5, 2024
@huydhn
Contributor

huydhn commented May 5, 2024

@pytorchbot revert -m 'Sorry for reverting your change, I am trying to restore ROCm distributed failures in trunk https://hud.pytorch.org/pytorch/pytorch/commit/9aa7699185e4ec39077e3046dfd63244dffa9ddb' -c weird

I'm not entirely sure if the failure is related, so I'll reland the change if it proves not to be the case:

  • Add ciflow/periodic to run distributed ROCm jobs
  • Remove the ci-td-distributed label, as this label enables TD on distributed jobs and could miss that failed test (cc @clee2000)

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@awgu your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request May 5, 2024
This reverts commit 9aa7699.

Reverted #125484 on behalf of https://github.com/huydhn due to Sorry for reverting your change, I am trying to restore ROCm distributed failures in trunk https://hud.pytorch.org/pytorch/pytorch/commit/9aa7699185e4ec39077e3046dfd63244dffa9ddb ([comment](#125484 (comment)))
@ezyang
Contributor

ezyang commented May 5, 2024

@pytorchbot merge -f "revert broke stuff"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@github-actions github-actions bot deleted the gh/awgu/580/head branch June 5, 2024 01:55