
Commit c8e785b

Author: Andrew Gu
Parent: 0347872

[FSDP2] Computed grad divide factors at runtime

ghstack-source-id: 659d132
Pull Request resolved: #125484

File tree: 3 files changed (+59, -64 lines)

test/distributed/_composable/fsdp/test_fully_shard_comm.py

Lines changed: 24 additions & 3 deletions
@@ -18,6 +18,8 @@
     OffloadPolicy,
 )
 from torch.distributed._composable.fsdp._fsdp_collectives import (
+    _div_if_needed,
+    _get_gradient_divide_factors,
     foreach_all_gather,
     foreach_all_gather_copy_out,
     foreach_reduce,
@@ -207,6 +209,18 @@ def test_reduce_scatter_fp32(self):
                 reduce_scatter_dtype=torch.float32,
             )
 
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_reduce_scatter_fp16(self):
+        param_sizes = self._get_param_sizes()
+        default_stream = torch.cuda.current_stream()
+        stream = torch.cuda.Stream()
+        for reduce_scatter_stream in (default_stream, stream):
+            self._test_reduce_scatter(
+                param_sizes,
+                reduce_scatter_stream=reduce_scatter_stream,
+                reduce_scatter_dtype=torch.float16,
+            )
+
     def _test_reduce_scatter(
         self,
         param_sizes: List[torch.Size],
@@ -238,17 +252,24 @@ def _test_reduce_scatter(
             orig_dtype=orig_params[0].dtype,
             reduce_dtype=reduce_scatter_dtype,
             device=self.device,
-            divide_factors=fsdp_param_group._grad_divide_factors,
             all_reduce_group=None,
             all_reduce_stream=all_reduce_stream,
         )
         torch.cuda.current_stream().wait_event(view_out_event)
 
         # Check reduce-scatter correctness
+        predivide_factor, postdivide_factor = _get_gradient_divide_factors(
+            group, None, reduce_scatter_dtype
+        )
         reduced_grads = [grad.detach().clone() for grad in unsharded_grads]
         for grad in reduced_grads:
-            dist.all_reduce(grad, group=group)
-            grad /= self.world_size
+            _div_if_needed(grad, predivide_factor)
+            dist.all_reduce(
+                grad,
+                group=group,
+                op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM,
+            )
+            _div_if_needed(grad, postdivide_factor)
         for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads):
            sharded_grad = fsdp_param.sharded_param.grad
            self.assertIsInstance(sharded_grad, DTensor)
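
The updated reference check relies on the identity that pre-dividing by predivide_factor, SUM-reducing across ranks, and then post-dividing by postdivide_factor equals averaging over the world size, because the two factors multiply to the world size. A minimal standalone sanity sketch of that identity, using hypothetical values (world_size=8 and synthetic fp16 gradients) rather than the test's actual tensors:

import torch

# Hypothetical world size and per-rank gradients, only to illustrate the identity
world_size = 8
predivide_factor, postdivide_factor = 4.0, 2.0  # 4.0 * 2.0 == world_size
grads = [torch.full((4,), float(rank + 1), dtype=torch.float16) for rank in range(world_size)]

# Plain average (what ReduceOp.AVG or the old `grad /= world_size` computes)
avg = torch.stack(grads).sum(dim=0) / world_size
# Pre-divide, SUM-reduce, post-divide (what the fp16 path computes)
pre_post = torch.stack([g / predivide_factor for g in grads]).sum(dim=0) / postdivide_factor
torch.testing.assert_close(avg, pre_post)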

torch/distributed/_composable/fsdp/_fsdp_collectives.py

Lines changed: 34 additions & 32 deletions
@@ -125,7 +125,6 @@ def foreach_reduce(
     orig_dtype: torch.dtype,
     reduce_dtype: Optional[torch.dtype],
     device: torch.device,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
     all_reduce_group: Optional[dist.ProcessGroup],
     all_reduce_stream: torch.cuda.Stream,
 ) -> torch.cuda.Event:
@@ -142,7 +141,9 @@ def foreach_reduce(
         )
     grad_dtype = unsharded_grads[0].dtype
     reduce_dtype = reduce_dtype or grad_dtype
-    predivide_factor, postdivide_factor = divide_factors
+    predivide_factor, postdivide_factor = _get_gradient_divide_factors(
+        reduce_scatter_group, all_reduce_group, reduce_dtype
+    )
     world_size = reduce_scatter_group.size()
     padded_unsharded_sizes = tuple(
         _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
@@ -166,18 +167,22 @@ def foreach_reduce(
             (reduce_scatter_output_numel,)
         )
         _div_if_needed(reduce_scatter_input, predivide_factor)
-        _reduce_scatter(
-            post_reduce_output,
-            reduce_scatter_input,
-            reduce_scatter_group,
-            divide_factors,
+        dist.reduce_scatter_tensor(
+            output=post_reduce_output,
+            input=reduce_scatter_input,
+            group=reduce_scatter_group,
+            op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
         )
         view_out_stream = reduce_scatter_stream
         if all_reduce_group is not None:
             view_out_stream = all_reduce_stream
             all_reduce_stream.wait_stream(reduce_scatter_stream)
             with torch.cuda.stream(all_reduce_stream):
-                _all_reduce(post_reduce_output, all_reduce_group, divide_factors)
+                dist.all_reduce(
+                    post_reduce_output,
+                    group=all_reduce_group,
+                    op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
+                )
         with torch.cuda.stream(view_out_stream):
             _div_if_needed(post_reduce_output, postdivide_factor)
             post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype)
@@ -257,30 +262,27 @@ def _get_all_gather_input_metadatas(
     )
 
 
-def _reduce_scatter(
-    output: torch.Tensor,
-    input: torch.Tensor,
-    group: dist.ProcessGroup,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
-) -> None:
-    if divide_factors[0]:
-        dist.reduce_scatter_tensor(output, input, group=group)
-    else:
-        # Using NCCL's reduce-scatter to do the division by world size saves
-        # extra memory read/write from a separate division kernel
-        dist.reduce_scatter_tensor(output, input, op=ReduceOp.AVG, group=group)
-
-
-def _all_reduce(
-    tensor: torch.Tensor,
-    group: dist.ProcessGroup,
-    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
-) -> None:
-    if divide_factors[0]:
-        dist.all_reduce(tensor, group=group)
-    else:
-        # saves extra memory read/write from a separate division kernel
-        dist.all_reduce(tensor, op=ReduceOp.AVG, group=group)
+def _get_gradient_divide_factors(
+    reduce_scatter_group: dist.ProcessGroup,
+    all_reduce_group: Optional[dist.ProcessGroup],
+    reduce_dtype: torch.dtype,
+) -> Union[Tuple[None, None], Tuple[float, float]]:
+    # For fp32/bf16, we do not need to worry about overflow/underflow, so we
+    # use NCCL's built-in division to avoid separate div kernels
+    if reduce_dtype in (torch.float32, torch.bfloat16):
+        return None, None
+    data_parallel_size = reduce_scatter_group.size()
+    if all_reduce_group is not None:
+        data_parallel_size *= all_reduce_group.size()
+    # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+    # overflow/underflow. For N data parallel workers, each worker computes
+    # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+    # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+    factor: int = 1
+    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
+        factor *= 2
+    factor = float(factor)
+    return (factor, data_parallel_size / factor)
 
 
 def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
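
For intuition, here is a minimal standalone sketch of the factor split that the new _get_gradient_divide_factors performs. The helper name split_divide_factor and the direct integer argument are illustrative only; the real function takes process groups and a dtype:

from typing import Tuple

def split_divide_factor(data_parallel_size: int) -> Tuple[float, float]:
    # Split N into predivide * postdivide with predivide close to sqrt(N),
    # so fp16 partial sums stay in range both before and after the reduction.
    factor = 1
    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
        factor *= 2
    return float(factor), data_parallel_size / factor

assert split_divide_factor(8) == (4.0, 2.0)
assert split_divide_factor(16) == (4.0, 4.0)
assert split_divide_factor(64) == (8.0, 8.0)

In the fp32/bf16 case the helper instead returns (None, None), and both collectives fall back to NCCL's ReduceOp.AVG, which avoids an extra memory read/write from a separate division kernel.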

torch/distributed/_composable/fsdp/_fsdp_param_group.py

Lines changed: 1 addition & 29 deletions
@@ -1,6 +1,6 @@
 import contextlib
 
-from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple
 
 import torch
 import torch.distributed as dist
@@ -164,32 +164,6 @@ def _init_mp_dtypes(self) -> None:
         )
         self._reduce_dtype = next(iter(reduce_dtypes))
 
-    def _init_grad_divide_factors(self):
-        data_parallel_world_size = 1
-        data_parallel_world_size *= self.mesh_info.shard_mesh_size
-        if self._is_hsdp:
-            data_parallel_world_size *= self.mesh_info.replicate_mesh_size
-        if self._reduce_dtype in (torch.float32, torch.bfloat16):
-            # Use NCCL's AVG op to divide after reduction since it is more
-            # performant and fp32 has sufficient precision
-            self._grad_divide_factors: Union[Tuple[None, None], Tuple[float, float]] = (
-                None,
-                None,
-            )
-            return
-        # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
-        # overflow/underflow. For N data parallel workers, each worker computes
-        # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
-        # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
-        factor: int = 1
-        while (
-            data_parallel_world_size % factor == 0
-            and data_parallel_world_size / factor > factor
-        ):
-            factor *= 2
-        factor = float(factor)
-        self._grad_divide_factors = (factor, data_parallel_world_size / factor)
-
     def lazy_init(self):
         # Lazy init should be idempotent
         param_names_on_meta = [
@@ -207,7 +181,6 @@ def lazy_init(self):
         # Initialize mixed precision attributes lazily in case the user changes
         # the parameter dtypes after construction time but before forward
         self._init_mp_dtypes()
-        self._init_grad_divide_factors()
         self._register_state_dict_hooks()
 
     # Runtime #
@@ -346,7 +319,6 @@ def post_backward(self, *unused: Any):
            self._orig_dtype,
            self._reduce_dtype,
            self.device,
-           self._grad_divide_factors,
            self._all_reduce_process_group
            if self._is_hsdp and self.all_reduce_grads
            else None,
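
With _init_grad_divide_factors removed, the HSDP contribution previously captured via replicate_mesh_size now enters through the all-reduce group passed to foreach_reduce, since _get_gradient_divide_factors multiplies the two group sizes. A rough sketch of the resulting factors under assumed mesh sizes (shard=8, replicate=4; hypothetical numbers, not from this PR):

# Hypothetical HSDP mesh sizes, only to show the combined data-parallel size
shard_group_size, replicate_group_size = 8, 4
data_parallel_size = shard_group_size * replicate_group_size  # 32

factor = 1
while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
    factor *= 2
predivide, postdivide = float(factor), data_parallel_size / factor
assert (predivide, postdivide) == (8.0, 4.0)  # a ~sqrt(32) split: 8 * 4 == 32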
