
Commit bccac39

Author: Andrew Gu (committed)
Update on "[FSDP2] Computed grad divide factors at runtime"
**Context**

We are interested in supporting the case where HSDP reduce-scatters but does not all-reduce in a microbatch backward. This saves communication while still saving memory. Only on the last microbatch do we need to both reduce-scatter and all-reduce. This is not implemented yet and will hopefully come in a future PR.

There is one notable subtlety. On the last microbatch, we need to perform an accumulation step after the reduce-scatter and before the all-reduce. If we do not, then the preceding microbatches' gradients will not be contributed across the replica group. (In other words, we cannot simply accumulate _after_ the all-reduce.) Consider 32 GPUs with 4-way replication, 8-way sharding, and 2 microbatches, and focus on global rank 0.
- After the first microbatch, rank 0 will have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)}$, where we define $S(0) = \{0, 1, \dots, 7\}$ to be the ranks in its shard group and the $(1)$ superscript to denote the first microbatch.
- Upon the second microbatch, rank 0 after its reduce-scatter will newly have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then the second microbatch's gradients become $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, so in total, rank 0 has $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, which is wrong.
- Importantly, we must accumulate $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{8} \sum_{i \in S(0)} g_i^{(2)} = \frac{1}{8}\sum_{i \in S(0)} (g_i^{(1)} + g_i^{(2)})$ first before all-reducing to get $\frac{1}{32} \sum_{i=0, 1, \dots, 31} (g_i^{(1)} + g_i^{(2)})$.

Note how under this approach, the first microbatch's reduce-scatter should use a factor of $\frac{1}{8}$ only (i.e. the reciprocal of the shard group size), not $\frac{1}{32}$.
- For bf16/fp32, since we use `ReduceOp.AVG` and we only reduce-scatter on the first microbatch, we correctly have a factor of $\frac{1}{8}$ on the first microbatch.
- For fp16, since we precompute the gradient divide factors at init time assuming that we always reduce over both the shard and replica groups, we incorrectly have a factor of $\frac{1}{32}$ on the first microbatch, deviating from the bf16/fp32 case.

We can address this issue and match the fp16 semantics to the bf16/fp32 semantics by computing the divide factors at runtime based on which process groups were passed into the reduction function (`foreach_reduce`).

**Additional Notes**

How to implement the HSDP reduce-scatter without all-reduce is not entirely clear yet. (What is the cleanest way to do this?) We need to store the partial reduce-scatter output and check for it upon the next backward. We should also be sure to error out if the set of parameters receiving gradients changes, since we cannot easily support that case. We will implement this in a follow-up.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
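For concreteness, here is a minimal sketch of what computing the divide factors at runtime could look like, in the spirit of how the test diff below calls `_get_gradient_divide_factors(group, None, reduce_scatter_dtype)` and `_div_if_needed`. The helper names, the power-of-two pre-divide choice, and the pre/post split for fp16 are illustrative assumptions, not the exact implementation in this PR.

```python
from typing import Optional, Tuple

import torch
import torch.distributed as dist


def get_gradient_divide_factors(  # hypothetical sketch, not the PR's code
    reduce_scatter_group: dist.ProcessGroup,
    all_reduce_group: Optional[dist.ProcessGroup],
    reduce_dtype: torch.dtype,
) -> Tuple[Optional[float], Optional[float]]:
    # bf16/fp32 can reduce with ReduceOp.AVG directly, so no manual factors.
    if reduce_dtype in (torch.float32, torch.bfloat16):
        return None, None
    # Only the groups actually passed in contribute to the divide factor. If
    # the all-reduce is skipped (e.g. a non-final microbatch under HSDP), the
    # factor is the shard group size (1/8 in the example above), not the full
    # data-parallel size (1/32).
    data_parallel_size = reduce_scatter_group.size()
    if all_reduce_group is not None:
        data_parallel_size *= all_reduce_group.size()
    # For fp16, split the division into a pre-divide before the reduction and
    # a post-divide after it (roughly sqrt(N) each) to reduce overflow and
    # underflow risk in the low-precision sum.
    factor = 1
    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
        factor *= 2
    return float(factor), data_parallel_size / factor


def div_if_needed(grad: torch.Tensor, div_factor: Optional[float]) -> None:
    # In-place divide; a no-op on the bf16/fp32 path where the factor is None.
    if div_factor is not None and div_factor > 1:
        grad.div_(div_factor)
```

In the 32-GPU example above, a microbatch that only reduce-scatters would pass just the 8-rank shard group and get a total divide factor of 8 (e.g. pre-divide 4, post-divide 2), matching the bf16/fp32 `ReduceOp.AVG` behavior over the shard group; the last microbatch would pass both groups and get a factor of 32.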
2 parents 05f5697 + c2c3c6b commit bccac39

File tree: 1 file changed (+24, -2 lines)


test/distributed/_composable/fsdp/test_fully_shard_comm.py
Lines changed: 24 additions & 2 deletions

@@ -18,6 +18,8 @@
     OffloadPolicy,
 )
 from torch.distributed._composable.fsdp._fsdp_collectives import (
+    _div_if_needed,
+    _get_gradient_divide_factors,
     foreach_all_gather,
     foreach_all_gather_copy_out,
     foreach_reduce,
@@ -207,6 +209,18 @@ def test_reduce_scatter_fp32(self):
                 reduce_scatter_dtype=torch.float32,
             )
 
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_reduce_scatter_fp16(self):
+        param_sizes = self._get_param_sizes()
+        default_stream = torch.cuda.current_stream()
+        stream = torch.cuda.Stream()
+        for reduce_scatter_stream in (default_stream, stream):
+            self._test_reduce_scatter(
+                param_sizes,
+                reduce_scatter_stream=reduce_scatter_stream,
+                reduce_scatter_dtype=torch.float16,
+            )
+
     def _test_reduce_scatter(
         self,
         param_sizes: List[torch.Size],
@@ -244,10 +258,18 @@ def _test_reduce_scatter(
         torch.cuda.current_stream().wait_event(view_out_event)
 
         # Check reduce-scatter correctness
+        predivide_factor, postdivide_factor = _get_gradient_divide_factors(
+            group, None, reduce_scatter_dtype
+        )
         reduced_grads = [grad.detach().clone() for grad in unsharded_grads]
         for grad in reduced_grads:
-            dist.all_reduce(grad, group=group)
-            grad /= self.world_size
+            _div_if_needed(grad, predivide_factor)
+            dist.all_reduce(
+                grad,
+                group=group,
+                op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM,
+            )
+            _div_if_needed(grad, postdivide_factor)
         for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads):
             sharded_grad = fsdp_param.sharded_param.grad
             self.assertIsInstance(sharded_grad, DTensor)
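As a quick sanity check on the reference semantics the updated test exercises for fp16 (pre-divide, SUM all-reduce, post-divide), the following single-process snippet simulates the reduction over a hypothetical 32-rank group by summing per-rank tensors directly and confirms it matches a plain average. The factor split of 8 and 4 for 32 ranks is an assumption carried over from the sketch above.

```python
import torch

# Hypothetical setup: 32 ranks, each contributing one fp16 gradient tensor.
world_size = 32
grads = [torch.full((4,), float(rank + 1), dtype=torch.float16) for rank in range(world_size)]

# bf16/fp32-style semantics: ReduceOp.AVG, i.e. sum then divide by world size.
avg = torch.stack([g.float() for g in grads]).sum(dim=0) / world_size

# fp16-style semantics: pre-divide, sum (standing in for the SUM all-reduce),
# then post-divide, with predivide * postdivide == world_size.
predivide, postdivide = 8.0, 4.0  # assumed split for 32 ranks
summed = torch.stack([g / predivide for g in grads]).sum(dim=0)
result = summed / postdivide

torch.testing.assert_close(result.float(), avg)
print("pre/post-divide path matches AVG semantics")
```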
