Skip to content

Commit f70bd71

Browse files
Andrew Gupytorchmergebot
authored andcommitted
[FSDP2] Computed grad divide factors at runtime (#125484)
**Context** We are interested in supporting the case where HSDP reduce-scatters but does not all-reduce in a microbatch backward. This saves communication while still saving memory. Only on the last microbatch do we need to both reduce-scatter and all-reduce. This is not implemented yet and will hopefully come in a future PR. There is one notable part of doing this. On the last microbatch, we need to perform an accumulation step after reduce-scatter and before all-reduce. If not, then the preceding microbatch's gradients will not be contributed across the replica group. (In other words, we cannot simply accumulate _after_ all-reduce.) Consider 32 GPUs with 4-way replication and 8-way sharding and 2 microbatches, and focus on global rank 0. - After the first microbatch, rank 0 will have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)}$, where we define $S(0) = \{0, 1, \dots, 7\}$ to be the ranks in its shard group and we define the $(1)$ superscript to denote the first microbatch. - Upon the second microbatch, rank 0 after its reduce-scatter will additionally have its shard of $\frac{1}{8} \sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then this second microbatch's gradients become $\frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, so in total, rank 0 has $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{32} \sum_{i=0, 1, \dots, 31} g_i^{(2)}$, which is wrong. - Importantly, we must accumulate $\frac{1}{8} \sum_{i \in S(0)} g_i^{(1)} + \frac{1}{8} \sum_{i \in S(0)} g_i^{(2)} = \frac{1}{8}\sum_{i \in S(0)} (g_i^{(1)} + g_i^{(2)})$ first before all-reducing to get $\frac{1}{32} \sum_{i=0, 1, \dots, 31} (g_i^{(1)} + g_i^{(2)})$. Now, note how under this approach, we want a factor of $\frac{1}{8}$ only (i.e. reciprocal of the shard group size), not $\frac{1}{32}$, for the first microbatch's gradients. - For bf16/fp32, since we use `ReduceOp.AVG` and we only reduce-scatter on the first microbatch, we correctly have a factor of $\frac{1}{8}$ on the first microbatch. - For fp16, since we precompute the gradient divide factors at init time assuming always reducing over both shard and replica groups, we incorrectly have a factor of $\frac{1}{32}$ on the first microbatch, deviating from the bf16/fp32 case. We can address this issue by matching the bf16/fp32 vs. fp16 semantics by computing the divide factors at runtime based on which process groups were passed into the reduction function (`foreach_reduce`). **Additional Notes** How to implement the HSDP reduce-scatter but no all-reduce is not entirely clear yet. (What is the cleanest way to do this?) We need to store the partial reduce-scatter output and check for it upon the next backward. We should also be sure to error if the set of parameters receiving gradients changes, in which case we cannot support this easily. Anyway, we will implement this in a follow-up. Pull Request resolved: #125484 Approved by: https://github.com/wanchaol ghstack dependencies: #125431, #125479
1 parent dba689b commit f70bd71

File tree

3 files changed

+59
-64
lines changed

3 files changed

+59
-64
lines changed

test/distributed/_composable/fsdp/test_fully_shard_comm.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
OffloadPolicy,
1919
)
2020
from torch.distributed._composable.fsdp._fsdp_collectives import (
21+
_div_if_needed,
22+
_get_gradient_divide_factors,
2123
foreach_all_gather,
2224
foreach_all_gather_copy_out,
2325
foreach_reduce,
@@ -207,6 +209,18 @@ def test_reduce_scatter_fp32(self):
207209
reduce_scatter_dtype=torch.float32,
208210
)
209211

212+
@unittest.skipIf(not TEST_CUDA, "no cuda")
213+
def test_reduce_scatter_fp16(self):
214+
param_sizes = self._get_param_sizes()
215+
default_stream = torch.cuda.current_stream()
216+
stream = torch.cuda.Stream()
217+
for reduce_scatter_stream in (default_stream, stream):
218+
self._test_reduce_scatter(
219+
param_sizes,
220+
reduce_scatter_stream=reduce_scatter_stream,
221+
reduce_scatter_dtype=torch.float16,
222+
)
223+
210224
def _test_reduce_scatter(
211225
self,
212226
param_sizes: List[torch.Size],
@@ -238,17 +252,24 @@ def _test_reduce_scatter(
238252
orig_dtype=orig_params[0].dtype,
239253
reduce_dtype=reduce_scatter_dtype,
240254
device=self.device,
241-
divide_factors=fsdp_param_group._grad_divide_factors,
242255
all_reduce_group=None,
243256
all_reduce_stream=all_reduce_stream,
244257
)
245258
torch.cuda.current_stream().wait_event(view_out_event)
246259

247260
# Check reduce-scatter correctness
261+
predivide_factor, postdivide_factor = _get_gradient_divide_factors(
262+
group, None, reduce_scatter_dtype
263+
)
248264
reduced_grads = [grad.detach().clone() for grad in unsharded_grads]
249265
for grad in reduced_grads:
250-
dist.all_reduce(grad, group=group)
251-
grad /= self.world_size
266+
_div_if_needed(grad, predivide_factor)
267+
dist.all_reduce(
268+
grad,
269+
group=group,
270+
op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM,
271+
)
272+
_div_if_needed(grad, postdivide_factor)
252273
for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads):
253274
sharded_grad = fsdp_param.sharded_param.grad
254275
self.assertIsInstance(sharded_grad, DTensor)

torch/distributed/_composable/fsdp/_fsdp_collectives.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def foreach_reduce(
125125
orig_dtype: torch.dtype,
126126
reduce_dtype: Optional[torch.dtype],
127127
device: torch.device,
128-
divide_factors: Union[Tuple[None, None], Tuple[float, float]],
129128
all_reduce_group: Optional[dist.ProcessGroup],
130129
all_reduce_stream: torch.cuda.Stream,
131130
) -> torch.cuda.Event:
@@ -142,7 +141,9 @@ def foreach_reduce(
142141
)
143142
grad_dtype = unsharded_grads[0].dtype
144143
reduce_dtype = reduce_dtype or grad_dtype
145-
predivide_factor, postdivide_factor = divide_factors
144+
predivide_factor, postdivide_factor = _get_gradient_divide_factors(
145+
reduce_scatter_group, all_reduce_group, reduce_dtype
146+
)
146147
world_size = reduce_scatter_group.size()
147148
padded_unsharded_sizes = tuple(
148149
_get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
@@ -166,18 +167,22 @@ def foreach_reduce(
166167
(reduce_scatter_output_numel,)
167168
)
168169
_div_if_needed(reduce_scatter_input, predivide_factor)
169-
_reduce_scatter(
170-
post_reduce_output,
171-
reduce_scatter_input,
172-
reduce_scatter_group,
173-
divide_factors,
170+
dist.reduce_scatter_tensor(
171+
output=post_reduce_output,
172+
input=reduce_scatter_input,
173+
group=reduce_scatter_group,
174+
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
174175
)
175176
view_out_stream = reduce_scatter_stream
176177
if all_reduce_group is not None:
177178
view_out_stream = all_reduce_stream
178179
all_reduce_stream.wait_stream(reduce_scatter_stream)
179180
with torch.cuda.stream(all_reduce_stream):
180-
_all_reduce(post_reduce_output, all_reduce_group, divide_factors)
181+
dist.all_reduce(
182+
post_reduce_output,
183+
group=all_reduce_group,
184+
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
185+
)
181186
with torch.cuda.stream(view_out_stream):
182187
_div_if_needed(post_reduce_output, postdivide_factor)
183188
post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype)
@@ -257,30 +262,27 @@ def _get_all_gather_input_metadatas(
257262
)
258263

259264

260-
def _reduce_scatter(
261-
output: torch.Tensor,
262-
input: torch.Tensor,
263-
group: dist.ProcessGroup,
264-
divide_factors: Union[Tuple[None, None], Tuple[float, float]],
265-
) -> None:
266-
if divide_factors[0]:
267-
dist.reduce_scatter_tensor(output, input, group=group)
268-
else:
269-
# Using NCCL's reduce-scatter to do the division by world size saves
270-
# extra memory read/write from a separate division kernel
271-
dist.reduce_scatter_tensor(output, input, op=ReduceOp.AVG, group=group)
272-
273-
274-
def _all_reduce(
275-
tensor: torch.Tensor,
276-
group: dist.ProcessGroup,
277-
divide_factors: Union[Tuple[None, None], Tuple[float, float]],
278-
) -> None:
279-
if divide_factors[0]:
280-
dist.all_reduce(tensor, group=group)
281-
else:
282-
# saves extra memory read/write from a separate division kernel
283-
dist.all_reduce(tensor, op=ReduceOp.AVG, group=group)
265+
def _get_gradient_divide_factors(
266+
reduce_scatter_group: dist.ProcessGroup,
267+
all_reduce_group: Optional[dist.ProcessGroup],
268+
reduce_dtype: torch.dtype,
269+
) -> Union[Tuple[None, None], Tuple[float, float]]:
270+
# For fp32/bf16, we do not need to worry about overflow/underflow, so we
271+
# use NCCL's built-in division to avoid separate div kernels
272+
if reduce_dtype in (torch.float32, torch.bfloat16):
273+
return None, None
274+
data_parallel_size = reduce_scatter_group.size()
275+
if all_reduce_group is not None:
276+
data_parallel_size *= all_reduce_group.size()
277+
# Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
278+
# overflow/underflow. For N data parallel workers, each worker computes
279+
# g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
280+
# overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
281+
factor: int = 1
282+
while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
283+
factor *= 2
284+
factor = float(factor)
285+
return (factor, data_parallel_size / factor)
284286

285287

286288
def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:

torch/distributed/_composable/fsdp/_fsdp_param_group.py

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import contextlib
22

3-
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
3+
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple
44

55
import torch
66
import torch.distributed as dist
@@ -164,32 +164,6 @@ def _init_mp_dtypes(self) -> None:
164164
)
165165
self._reduce_dtype = next(iter(reduce_dtypes))
166166

167-
def _init_grad_divide_factors(self):
168-
data_parallel_world_size = 1
169-
data_parallel_world_size *= self.mesh_info.shard_mesh_size
170-
if self._is_hsdp:
171-
data_parallel_world_size *= self.mesh_info.replicate_mesh_size
172-
if self._reduce_dtype in (torch.float32, torch.bfloat16):
173-
# Use NCCL's AVG op to divide after reduction since it is more
174-
# performant and fp32 has sufficient precision
175-
self._grad_divide_factors: Union[Tuple[None, None], Tuple[float, float]] = (
176-
None,
177-
None,
178-
)
179-
return
180-
# Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
181-
# overflow/underflow. For N data parallel workers, each worker computes
182-
# g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
183-
# overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
184-
factor: int = 1
185-
while (
186-
data_parallel_world_size % factor == 0
187-
and data_parallel_world_size / factor > factor
188-
):
189-
factor *= 2
190-
factor = float(factor)
191-
self._grad_divide_factors = (factor, data_parallel_world_size / factor)
192-
193167
def lazy_init(self):
194168
# Lazy init should be idempotent
195169
param_names_on_meta = [
@@ -207,7 +181,6 @@ def lazy_init(self):
207181
# Initialize mixed precision attributes lazily in case the user changes
208182
# the parameter dtypes after construction time but before forward
209183
self._init_mp_dtypes()
210-
self._init_grad_divide_factors()
211184
self._register_state_dict_hooks()
212185

213186
# Runtime #
@@ -346,7 +319,6 @@ def post_backward(self, *unused: Any):
346319
self._orig_dtype,
347320
self._reduce_dtype,
348321
self.device,
349-
self._grad_divide_factors,
350322
self._all_reduce_process_group
351323
if self._is_hsdp and self.all_reduce_grads
352324
else None,

0 commit comments

Comments
 (0)