
Commit dba689b

Revert "[FSDP2] Computed grad divide factors at runtime (#125484)"
This reverts commit 9aa7699. Reverted #125484 on behalf of https://github.com/huydhn to restore trunk from ROCm distributed test failures (see https://hud.pytorch.org/pytorch/pytorch/commit/9aa7699185e4ec39077e3046dfd63244dffa9ddb and the discussion on #125484).
1 parent 8a0529e commit dba689b

3 files changed: +64, -59 lines


test/distributed/_composable/fsdp/test_fully_shard_comm.py
Lines changed: 3 additions & 24 deletions

```diff
@@ -18,8 +18,6 @@
     OffloadPolicy,
 )
 from torch.distributed._composable.fsdp._fsdp_collectives import (
-    _div_if_needed,
-    _get_gradient_divide_factors,
     foreach_all_gather,
     foreach_all_gather_copy_out,
     foreach_reduce,
@@ -209,18 +207,6 @@ def test_reduce_scatter_fp32(self):
                 reduce_scatter_dtype=torch.float32,
             )

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_reduce_scatter_fp16(self):
-        param_sizes = self._get_param_sizes()
-        default_stream = torch.cuda.current_stream()
-        stream = torch.cuda.Stream()
-        for reduce_scatter_stream in (default_stream, stream):
-            self._test_reduce_scatter(
-                param_sizes,
-                reduce_scatter_stream=reduce_scatter_stream,
-                reduce_scatter_dtype=torch.float16,
-            )
-
     def _test_reduce_scatter(
         self,
         param_sizes: List[torch.Size],
@@ -252,24 +238,17 @@ def _test_reduce_scatter(
             orig_dtype=orig_params[0].dtype,
             reduce_dtype=reduce_scatter_dtype,
             device=self.device,
+            divide_factors=fsdp_param_group._grad_divide_factors,
             all_reduce_group=None,
             all_reduce_stream=all_reduce_stream,
         )
         torch.cuda.current_stream().wait_event(view_out_event)

         # Check reduce-scatter correctness
-        predivide_factor, postdivide_factor = _get_gradient_divide_factors(
-            group, None, reduce_scatter_dtype
-        )
         reduced_grads = [grad.detach().clone() for grad in unsharded_grads]
         for grad in reduced_grads:
-            _div_if_needed(grad, predivide_factor)
-            dist.all_reduce(
-                grad,
-                group=group,
-                op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM,
-            )
-            _div_if_needed(grad, postdivide_factor)
+            dist.all_reduce(grad, group=group)
+            grad /= self.world_size
         for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads):
            sharded_grad = fsdp_param.sharded_param.grad
            self.assertIsInstance(sharded_grad, DTensor)
```
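For context, the restored correctness check builds its reference gradients by all-reducing each unsharded gradient and dividing by the world size. Below is a minimal standalone sketch of that reference computation; the `reference_reduced_grads` helper is illustrative (not part of the test) and simulates ranks with plain tensor lists so it runs without a process group.

```python
import torch


def reference_reduced_grads(per_rank_grads):
    # per_rank_grads: one list of gradient tensors per simulated rank
    world_size = len(per_rank_grads)
    reduced = []
    for grads_for_one_param in zip(*per_rank_grads):
        # Mirrors dist.all_reduce(grad, group=group) followed by
        # grad /= self.world_size in the restored test
        total = torch.stack(grads_for_one_param).sum(dim=0)
        reduced.append(total / world_size)
    return reduced


if __name__ == "__main__":
    torch.manual_seed(0)
    grads = [[torch.randn(4, 4) for _ in range(3)] for _ in range(2)]  # 2 simulated ranks
    ref = reference_reduced_grads(grads)
    expected = [(a + b) / 2 for a, b in zip(*grads)]
    for r, e in zip(ref, expected):
        torch.testing.assert_close(r, e)
```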

torch/distributed/_composable/fsdp/_fsdp_collectives.py
Lines changed: 32 additions & 34 deletions

```diff
@@ -125,6 +125,7 @@ def foreach_reduce(
     orig_dtype: torch.dtype,
     reduce_dtype: Optional[torch.dtype],
     device: torch.device,
+    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
     all_reduce_group: Optional[dist.ProcessGroup],
     all_reduce_stream: torch.cuda.Stream,
 ) -> torch.cuda.Event:
@@ -141,9 +142,7 @@ def foreach_reduce(
     )
     grad_dtype = unsharded_grads[0].dtype
     reduce_dtype = reduce_dtype or grad_dtype
-    predivide_factor, postdivide_factor = _get_gradient_divide_factors(
-        reduce_scatter_group, all_reduce_group, reduce_dtype
-    )
+    predivide_factor, postdivide_factor = divide_factors
     world_size = reduce_scatter_group.size()
     padded_unsharded_sizes = tuple(
         _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
@@ -167,22 +166,18 @@ def foreach_reduce(
             (reduce_scatter_output_numel,)
         )
         _div_if_needed(reduce_scatter_input, predivide_factor)
-        dist.reduce_scatter_tensor(
-            output=post_reduce_output,
-            input=reduce_scatter_input,
-            group=reduce_scatter_group,
-            op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
+        _reduce_scatter(
+            post_reduce_output,
+            reduce_scatter_input,
+            reduce_scatter_group,
+            divide_factors,
         )
         view_out_stream = reduce_scatter_stream
         if all_reduce_group is not None:
             view_out_stream = all_reduce_stream
             all_reduce_stream.wait_stream(reduce_scatter_stream)
             with torch.cuda.stream(all_reduce_stream):
-                dist.all_reduce(
-                    post_reduce_output,
-                    group=all_reduce_group,
-                    op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
-                )
+                _all_reduce(post_reduce_output, all_reduce_group, divide_factors)
     with torch.cuda.stream(view_out_stream):
         _div_if_needed(post_reduce_output, postdivide_factor)
         post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype)
@@ -262,27 +257,30 @@ def _get_all_gather_input_metadatas(
     )


-def _get_gradient_divide_factors(
-    reduce_scatter_group: dist.ProcessGroup,
-    all_reduce_group: Optional[dist.ProcessGroup],
-    reduce_dtype: torch.dtype,
-) -> Union[Tuple[None, None], Tuple[float, float]]:
-    # For fp32/bf16, we do not need to worry about overflow/underflow, so we
-    # use NCCL's built-in division to avoid separate div kernels
-    if reduce_dtype in (torch.float32, torch.bfloat16):
-        return None, None
-    data_parallel_size = reduce_scatter_group.size()
-    if all_reduce_group is not None:
-        data_parallel_size *= all_reduce_group.size()
-    # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
-    # overflow/underflow. For N data parallel workers, each worker computes
-    # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
-    # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
-    factor: int = 1
-    while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
-        factor *= 2
-    factor = float(factor)
-    return (factor, data_parallel_size / factor)
+def _reduce_scatter(
+    output: torch.Tensor,
+    input: torch.Tensor,
+    group: dist.ProcessGroup,
+    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
+) -> None:
+    if divide_factors[0]:
+        dist.reduce_scatter_tensor(output, input, group=group)
+    else:
+        # Using NCCL's reduce-scatter to do the division by world size saves
+        # extra memory read/write from a separate division kernel
+        dist.reduce_scatter_tensor(output, input, op=ReduceOp.AVG, group=group)
+
+
+def _all_reduce(
+    tensor: torch.Tensor,
+    group: dist.ProcessGroup,
+    divide_factors: Union[Tuple[None, None], Tuple[float, float]],
+) -> None:
+    if divide_factors[0]:
+        dist.all_reduce(tensor, group=group)
+    else:
+        # saves extra memory read/write from a separate division kernel
+        dist.all_reduce(tensor, op=ReduceOp.AVG, group=group)


 def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
```
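The restored `_reduce_scatter`/`_all_reduce` helpers use NCCL's `AVG` op when no pre-divide factor is set (fp32/bf16), and a plain `SUM` otherwise, with `_div_if_needed` applying the pre- and post-division around the collective for fp16. The sketch below is not from the PR (the world size and gradient values are made up for illustration); it shows why fp16 benefits from splitting the division by N into two divisions by ~sqrt(N): summing first and dividing only at the end can overflow fp16's range (max ~65504).

```python
import torch

N = 64            # hypothetical data-parallel world size
predivide = 8.0   # ~sqrt(N); the restored code picks a power-of-two factor of N
postdivide = N / predivide
assert predivide * postdivide == N

# Simulate N ranks each holding the same large fp16 gradient shard
grads = [torch.full((1024,), 2000.0, dtype=torch.float16) for _ in range(N)]

# SUM first, divide once at the end: the running sum exceeds the fp16 max
naive_sum = sum(grads)  # stands in for a ReduceOp.SUM collective
assert torch.isinf(naive_sum).all()

# Pre-divide, SUM, post-divide: intermediate sums stay within fp16 range
safe = sum(g / predivide for g in grads) / postdivide
assert torch.isfinite(safe).all()

# Close to the true mean computed in fp32 (up to fp16 rounding error)
true_mean = torch.stack([g.float() for g in grads]).mean(dim=0)
torch.testing.assert_close(safe.float(), true_mean, rtol=5e-2, atol=0.0)
```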

torch/distributed/_composable/fsdp/_fsdp_param_group.py
Lines changed: 29 additions & 1 deletion

```diff
@@ -1,6 +1,6 @@
 import contextlib

-from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -164,6 +164,32 @@ def _init_mp_dtypes(self) -> None:
         )
         self._reduce_dtype = next(iter(reduce_dtypes))

+    def _init_grad_divide_factors(self):
+        data_parallel_world_size = 1
+        data_parallel_world_size *= self.mesh_info.shard_mesh_size
+        if self._is_hsdp:
+            data_parallel_world_size *= self.mesh_info.replicate_mesh_size
+        if self._reduce_dtype in (torch.float32, torch.bfloat16):
+            # Use NCCL's AVG op to divide after reduction since it is more
+            # performant and fp32 has sufficient precision
+            self._grad_divide_factors: Union[Tuple[None, None], Tuple[float, float]] = (
+                None,
+                None,
+            )
+            return
+        # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+        # overflow/underflow. For N data parallel workers, each worker computes
+        # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+        # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+        factor: int = 1
+        while (
+            data_parallel_world_size % factor == 0
+            and data_parallel_world_size / factor > factor
+        ):
+            factor *= 2
+        factor = float(factor)
+        self._grad_divide_factors = (factor, data_parallel_world_size / factor)
+
     def lazy_init(self):
         # Lazy init should be idempotent
         param_names_on_meta = [
@@ -181,6 +207,7 @@ def lazy_init(self):
         # Initialize mixed precision attributes lazily in case the user changes
         # the parameter dtypes after construction time but before forward
         self._init_mp_dtypes()
+        self._init_grad_divide_factors()
         self._register_state_dict_hooks()

     # Runtime #
@@ -319,6 +346,7 @@ def post_backward(self, *unused: Any):
             self._orig_dtype,
             self._reduce_dtype,
             self.device,
+            self._grad_divide_factors,
             self._all_reduce_process_group
             if self._is_hsdp and self.all_reduce_grads
             else None,
```
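For reference, a standalone sketch of the factor computation that the restored `_init_grad_divide_factors` performs for fp16: it doubles a power-of-two factor while the factor divides the data-parallel world size N and stays below sqrt(N), then splits the division by N into a pre-division by `factor` and a post-division by `N / factor`. The `grad_divide_factors` wrapper name here is illustrative only.

```python
from typing import Tuple


def grad_divide_factors(data_parallel_world_size: int) -> Tuple[float, float]:
    # Same loop as _init_grad_divide_factors in the hunk above
    factor: int = 1
    while (
        data_parallel_world_size % factor == 0
        and data_parallel_world_size / factor > factor
    ):
        factor *= 2
    return float(factor), data_parallel_world_size / factor


if __name__ == "__main__":
    for n in (2, 8, 16, 64, 128):
        pre, post = grad_divide_factors(n)
        # The two divisions together divide by the full data-parallel size
        assert pre * post == n
        print(f"N={n}: predivide={pre}, postdivide={post}")
```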
