[FSDP2] Added test to show rank 0 broadcast for HSDP replicas #125431
Conversation
[ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125431
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7505a71 with merge base b03fb49.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```
# E.g. for mesh [[0, 1, 2, 3], [4, 5, 6, 7]] sharding on dim-1 and
# replicating on dim-0, broadcast with sources 0, 1, 2, 3
src_rank = dist.get_process_group_ranks(replicate_group)[0]
torch.distributed.broadcast(
```
Today, in-place c10d broadcast is preferred.
If we want to use functional broadcast:
- We need to verify the semantics. We may still need to get the `src_rank` like we do here, which is confusing since it is the rank with respect to the global process group, not the one passed to broadcast.
- We need to swap the newly broadcasted tensor in. Since FSDP holds a reference, we cannot just `setattr(module, param_name, broadcasted_param)` since that would leave FSDP's reference stale. We may consider using `swap_tensors`, but we are blocked by the local tensor padding issue since the broadcasted parameter would not have padding and is actually a view into the padded local tensor.
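To make the contrast concrete, here is a minimal sketch (assuming `param` is a `DTensor` parameter and `replicate_group`/`src_rank` are set up as in the snippet above):
```
import torch.distributed as dist

# In-place c10d broadcast: we mutate the existing local shard, which is a view
# into the (padded) local tensor, so FSDP's reference to the parameter stays
# valid and no tensor swap is needed.
local_shard = param.to_local()
dist.broadcast(local_shard, src=src_rank, group=replicate_group)

# A functional broadcast would instead return a new tensor, which would have to
# be swapped in (e.g. setattr(module, param_name, new_param) or swap_tensors),
# running into the stale-reference and padding issues described above.
```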
Yeah, I think an in-place broadcast makes sense here!
```
        return 4

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_hsdp_broadcast_across_replicas(self):
```
One might wonder: why not always have HSDP broadcast at init time? The issue is that we only need to broadcast if we are initializing from scratch (not from a checkpoint). If we are initializing from a checkpoint, then the replicas are already guaranteed to be the same, and broadcasting is wasteful and can slow down init.
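For example, a user-side gate could look like this (a minimal sketch reusing the utility from the PR description; `loading_from_checkpoint` is a hypothetical flag):
```
import torch.distributed as dist

if not loading_from_checkpoint:
    # Replicas can only differ when initializing from scratch, so only then do
    # we pay for the broadcast across the replicate dimension.
    replicate_group = mesh.get_group("replicate")
    src_rank = dist.get_process_group_ranks(replicate_group)[0]
    for param in model.parameters():
        dist.broadcast(param.to_local(), src=src_rank, group=replicate_group)
```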
Nice! I feel good about having the example of getting the global rank from `replicate_group`. Hopefully people copy from here instead of using the local rank in c10d.
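For reference, the distinction between the global rank and the group-local rank looks like this (an illustrative sketch, not part of this PR; `param` and `replicate_group` are assumed as in the snippet above):
```
import torch.distributed as dist

# For mesh [[0, 1, 2, 3], [4, 5, 6, 7]], global rank 5 belongs to the
# replicate group {1, 5}.
group_ranks = dist.get_process_group_ranks(replicate_group)  # global ranks, e.g. [1, 5]
src_rank = group_ranks[0]  # global rank 1: what c10d broadcast expects as `src`
rank_in_group = dist.get_rank(replicate_group)  # group-local rank, e.g. 1 on global rank 5
# Passing rank_in_group as `src` would broadcast from the wrong rank.
dist.broadcast(param.to_local(), src=src_rank, group=replicate_group)
```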
…cas"
This PR shows a simple utility to broadcast the parameters across replicas for HSDP:
```
replicate_group = mesh.get_group("replicate")
for param in model.parameters():
    # E.g. for mesh [[0, 1, 2, 3], [4, 5, 6, 7]] sharding on dim-1 and
    # replicating on dim-0, broadcast with sources 0, 1, 2, 3
    src_rank = dist.get_process_group_ranks(replicate_group)[0]
    torch.distributed.broadcast(
        param.to_local(), src=src_rank, group=replicate_group
    )
```
cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k
[ghstack-poisoned]
By the way, I am open to the idea of some (1) is because it is a waste to broadcast if loading from a state dict.
(1) feels easy to understand.
Today, FSDP1 has a
wanchaol left a comment:
lgtm
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This adds HSDP to the existing gradient accumulation tests and includes some minor changes to simplify things a tiny bit.
Pull Request resolved: #125479
Approved by: https://github.com/wanchaol
ghstack dependencies: #125431
**Context**
We are interested in supporting the case where HSDP reduce-scatters but does not all-reduce in a microbatch backward. This saves communication while still saving memory. Only on the last microbatch do we need to both reduce-scatter and all-reduce. This is not implemented yet and will hopefully come in a future PR.
There is one notable subtlety. On the last microbatch, we need to perform an accumulation step after the reduce-scatter and before the all-reduce. Otherwise, the preceding microbatches' gradients will never be reduced across the replica group. (In other words, we cannot simply accumulate _after_ the all-reduce.)
Consider 32 GPUs with 4-way replication and 8-way sharding and 2 microbatches, and focus on global rank 0.
- After the first microbatch, rank 0 will have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)}$, where we define $S(0) = \{0, 1, \dots, 7\}$ to be the ranks in its shard group and we define the $(1)$ superscript to denote the first microbatch.
- Upon the second microbatch, rank 0 after its reduce-scatter will additionally have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then this second microbatch's gradients become $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, so in total, rank 0 has $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, which is wrong.
- Importantly, we must accumulate $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{8} \sum_{i \in S(0)} g_i^{(2)} = \frac{1}{8}\sum_{i \in S(0)} (g_i^{(1)} + g_i^{(2)})$ first before all-reducing to get $\frac{1}{32} \sum_{i=0, 1, \dots, 31} (g_i^{(1)} + g_i^{(2)})$.
Now, note how under this approach, we want a factor of $\frac{1}{8}$ only (i.e. reciprocal of the shard group size), not $\frac{1}{32}$, for the first microbatch's gradients.
- For bf16/fp32, since we use `ReduceOp.AVG` and we only reduce-scatter on the first microbatch, we correctly have a factor of $\frac{1}{8}$ on the first microbatch.
- For fp16, since we precompute the gradient divide factors at init time assuming always reducing over both shard and replica groups, we incorrectly have a factor of $\frac{1}{32}$ on the first microbatch, deviating from the bf16/fp32 case.
We can address this issue and match the bf16/fp32 and fp16 semantics by computing the divide factors at runtime based on which process groups are passed to the reduction function (`foreach_reduce`).
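A rough sketch of what that runtime computation could look like (the function name and structure are illustrative, not the actual `foreach_reduce` internals):
```
import torch

def _get_grad_divide_factors(reduce_scatter_group, all_reduce_group, reduce_dtype):
    # Divide over the ranks actually reduced over in this call, not over a
    # world size fixed at init time.
    world_size = reduce_scatter_group.size()
    if all_reduce_group is not None:
        world_size *= all_reduce_group.size()
    if reduce_dtype != torch.float16:
        # bf16/fp32 use ReduceOp.AVG, which already divides by world_size.
        return None, None
    # fp16: split the division into pre-/post-divide factors to reduce the
    # risk of overflow/underflow during the reduction.
    pre_factor = 1
    while world_size % pre_factor == 0 and world_size // pre_factor > pre_factor:
        pre_factor *= 2
    return pre_factor, world_size // pre_factor
```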
**Additional Notes**
How to implement HSDP reduce-scatter without all-reduce is not entirely clear yet. (What is the cleanest way to do this?) We need to store the partial reduce-scatter output and check for it in the next backward. We should also be sure to error if the set of parameters receiving gradients changes, since we cannot easily support that case. We will implement this in a follow-up; one possible shape for the bookkeeping is sketched below.
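This is a hedged sketch with assumed names (the design is not settled), showing only the accumulate-then-all-reduce ordering discussed above:
```
from typing import Optional

import torch
import torch.distributed as dist

def accumulate_and_maybe_all_reduce(
    new_sharded_grad: torch.Tensor,
    partial_sharded_grad: Optional[torch.Tensor],
    replicate_group: dist.ProcessGroup,
    is_last_microbatch: bool,
) -> Optional[torch.Tensor]:
    # Accumulate this microbatch's reduce-scatter output into the partial
    # gradient shard carried over from earlier microbatches.
    if partial_sharded_grad is not None:
        new_sharded_grad += partial_sharded_grad
    if not is_last_microbatch:
        # Defer the all-reduce: return the partial shard so it can be stored
        # and checked for in the next backward.
        return new_sharded_grad
    # Last microbatch: all-reduce the accumulated shard across replicas,
    # adding the 1/(replica group size) factor on top of the reduce-scatter's
    # 1/(shard group size) factor.
    dist.all_reduce(new_sharded_grad, op=dist.ReduceOp.AVG, group=replicate_group)
    return None
```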
Pull Request resolved: #125484
Approved by: https://github.com/wanchaol
ghstack dependencies: #125431, #125479
Stack from ghstack (oldest at bottom):
This PR shows a simple utility to broadcast the parameters across replicas for HSDP:
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k