27 changes: 24 additions & 3 deletions test/distributed/_composable/fsdp/test_fully_shard_comm.py
@@ -18,6 +18,8 @@
OffloadPolicy,
)
from torch.distributed._composable.fsdp._fsdp_collectives import (
_div_if_needed,
_get_gradient_divide_factors,
foreach_all_gather,
foreach_all_gather_copy_out,
foreach_reduce,
@@ -207,6 +209,18 @@ def test_reduce_scatter_fp32(self):
reduce_scatter_dtype=torch.float32,
)

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_reduce_scatter_fp16(self):
param_sizes = self._get_param_sizes()
default_stream = torch.cuda.current_stream()
stream = torch.cuda.Stream()
for reduce_scatter_stream in (default_stream, stream):
self._test_reduce_scatter(
param_sizes,
reduce_scatter_stream=reduce_scatter_stream,
reduce_scatter_dtype=torch.float16,
)

def _test_reduce_scatter(
self,
param_sizes: List[torch.Size],
@@ -238,17 +252,24 @@ def _test_reduce_scatter(
orig_dtype=orig_params[0].dtype,
reduce_dtype=reduce_scatter_dtype,
device=self.device,
divide_factors=fsdp_param_group._grad_divide_factors,
all_reduce_group=None,
all_reduce_stream=all_reduce_stream,
)
torch.cuda.current_stream().wait_event(view_out_event)

# Check reduce-scatter correctness
predivide_factor, postdivide_factor = _get_gradient_divide_factors(
group, None, reduce_scatter_dtype
)
reduced_grads = [grad.detach().clone() for grad in unsharded_grads]
for grad in reduced_grads:
dist.all_reduce(grad, group=group)
grad /= self.world_size
_div_if_needed(grad, predivide_factor)
dist.all_reduce(
grad,
group=group,
op=dist.ReduceOp.AVG if predivide_factor is None else dist.ReduceOp.SUM,
)
_div_if_needed(grad, postdivide_factor)
for fsdp_param, reduced_grad in zip(fsdp_params, reduced_grads):
sharded_grad = fsdp_param.sharded_param.grad
self.assertIsInstance(sharded_grad, DTensor)
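The updated check mirrors FSDP's fp16 reduction on the reference side: predivide, SUM all-reduce, then postdivide, rather than a single division by the world size. A minimal single-process sketch of that arithmetic (no process group; `pre` and `post` are hypothetical values for 8 ranks, standing in for what `_get_gradient_divide_factors` returns):

```python
import torch

# Hypothetical setup: 8 data-parallel ranks, with pre * post == world_size.
world_size, pre, post = 8, 4.0, 2.0
local_grads = [torch.randn(16) for _ in range(world_size)]

# fp16-style path: divide by `pre`, SUM across ranks, then divide by `post`.
reduced = sum(g / pre for g in local_grads) / post
# fp32/bf16-style path: a single AVG reduction.
averaged = torch.stack(local_grads).mean(dim=0)

torch.testing.assert_close(reduced, averaged)  # equal up to rounding
```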
66 changes: 34 additions & 32 deletions torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -125,7 +125,6 @@ def foreach_reduce(
orig_dtype: torch.dtype,
reduce_dtype: Optional[torch.dtype],
device: torch.device,
divide_factors: Union[Tuple[None, None], Tuple[float, float]],
all_reduce_group: Optional[dist.ProcessGroup],
all_reduce_stream: torch.cuda.Stream,
) -> torch.cuda.Event:
@@ -142,7 +141,9 @@
)
grad_dtype = unsharded_grads[0].dtype
reduce_dtype = reduce_dtype or grad_dtype
predivide_factor, postdivide_factor = divide_factors
predivide_factor, postdivide_factor = _get_gradient_divide_factors(
reduce_scatter_group, all_reduce_group, reduce_dtype
)
world_size = reduce_scatter_group.size()
padded_unsharded_sizes = tuple(
_get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads
@@ -166,18 +167,22 @@
(reduce_scatter_output_numel,)
)
_div_if_needed(reduce_scatter_input, predivide_factor)
_reduce_scatter(
post_reduce_output,
reduce_scatter_input,
reduce_scatter_group,
divide_factors,
dist.reduce_scatter_tensor(
output=post_reduce_output,
input=reduce_scatter_input,
group=reduce_scatter_group,
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
)
view_out_stream = reduce_scatter_stream
if all_reduce_group is not None:
view_out_stream = all_reduce_stream
all_reduce_stream.wait_stream(reduce_scatter_stream)
with torch.cuda.stream(all_reduce_stream):
_all_reduce(post_reduce_output, all_reduce_group, divide_factors)
dist.all_reduce(
post_reduce_output,
group=all_reduce_group,
op=ReduceOp.AVG if predivide_factor is None else ReduceOp.SUM,
)
with torch.cuda.stream(view_out_stream):
_div_if_needed(post_reduce_output, postdivide_factor)
post_reduce_output = _to_dtype_if_needed(post_reduce_output, orig_dtype)
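For HSDP, the postdivide now happens only after the optional all-reduce, so dividing by the predivide factor before the reduce-scatter and by the postdivide factor afterward still averages over all shard-group × replicate-group ranks. A single-process sketch of that ordering (plain sums stand in for the SUM collectives; sizes and factor values are made up):

```python
import torch

shard_size, replicate_size = 4, 2                     # data-parallel size N = 8
pre, post = 4.0, 2.0                                  # hypothetical factors, pre * post == N
grads = torch.randn(replicate_size, shard_size, 16)   # [replica, shard rank, grad element]

scattered = (grads / pre).sum(dim=1)   # reduce-scatter (SUM) over the shard group
all_reduced = scattered.sum(dim=0)     # all-reduce (SUM) over the replicate group
result = all_reduced / post            # postdivide in the view-out stream
torch.testing.assert_close(result, grads.mean(dim=(0, 1)))
```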
@@ -257,30 +262,27 @@ def _get_all_gather_input_metadatas(
)


def _reduce_scatter(
output: torch.Tensor,
input: torch.Tensor,
group: dist.ProcessGroup,
divide_factors: Union[Tuple[None, None], Tuple[float, float]],
) -> None:
if divide_factors[0]:
dist.reduce_scatter_tensor(output, input, group=group)
else:
# Using NCCL's reduce-scatter to do the division by world size saves
# extra memory read/write from a separate division kernel
dist.reduce_scatter_tensor(output, input, op=ReduceOp.AVG, group=group)


def _all_reduce(
tensor: torch.Tensor,
group: dist.ProcessGroup,
divide_factors: Union[Tuple[None, None], Tuple[float, float]],
) -> None:
if divide_factors[0]:
dist.all_reduce(tensor, group=group)
else:
# saves extra memory read/write from a separate division kernel
dist.all_reduce(tensor, op=ReduceOp.AVG, group=group)
def _get_gradient_divide_factors(
reduce_scatter_group: dist.ProcessGroup,
all_reduce_group: Optional[dist.ProcessGroup],
reduce_dtype: torch.dtype,
) -> Union[Tuple[None, None], Tuple[float, float]]:
# For fp32/bf16, we do not need to worry about overflow/underflow, so we
# use NCCL's built-in division to avoid separate div kernels
if reduce_dtype in (torch.float32, torch.bfloat16):
return None, None
data_parallel_size = reduce_scatter_group.size()
if all_reduce_group is not None:
data_parallel_size *= all_reduce_group.size()
# Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
# overflow/underflow. For N data parallel workers, each worker computes
# g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
# overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
factor: int = 1
while data_parallel_size % factor == 0 and data_parallel_size / factor > factor:
factor *= 2
Comment on lines +281 to +283 (Collaborator Author):
This computation should be pretty fast on CPU (and it is $O(\log N)$ for $N$ data-parallel GPUs anyway).

factor = float(factor)
return (factor, data_parallel_size / factor)


def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None:
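As a standalone illustration of the factor split above (a re-implementation of the loop for clarity, not an import of the library helper), both factors land near sqrt(N) and always multiply back to the data-parallel size N:

```python
from typing import Tuple


def grad_divide_factors(n: int) -> Tuple[float, float]:
    # Same loop as _get_gradient_divide_factors: grow a power-of-two predivide
    # factor until the remaining postdivide factor is no larger than it (or no
    # longer divides n evenly).
    factor = 1
    while n % factor == 0 and n / factor > factor:
        factor *= 2
    return float(factor), n / factor


for n in (2, 8, 16, 24, 64):
    pre, post = grad_divide_factors(n)
    assert pre * post == n  # dividing by pre then post divides by N overall
    print(n, (pre, post))
# 2 (2.0, 1.0)   8 (4.0, 2.0)   16 (4.0, 4.0)   24 (8.0, 3.0)   64 (8.0, 8.0)
```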
30 changes: 1 addition & 29 deletions torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -1,6 +1,6 @@
import contextlib

from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple, Union
from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple

import torch
import torch.distributed as dist
@@ -164,32 +164,6 @@ def _init_mp_dtypes(self) -> None:
)
self._reduce_dtype = next(iter(reduce_dtypes))

def _init_grad_divide_factors(self):
data_parallel_world_size = 1
data_parallel_world_size *= self.mesh_info.shard_mesh_size
if self._is_hsdp:
data_parallel_world_size *= self.mesh_info.replicate_mesh_size
if self._reduce_dtype in (torch.float32, torch.bfloat16):
# Use NCCL's AVG op to divide after reduction since it is more
# performant and fp32 has sufficient precision
self._grad_divide_factors: Union[Tuple[None, None], Tuple[float, float]] = (
None,
None,
)
return
# Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
# overflow/underflow. For N data parallel workers, each worker computes
# g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
# overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
factor: int = 1
while (
data_parallel_world_size % factor == 0
and data_parallel_world_size / factor > factor
):
factor *= 2
factor = float(factor)
self._grad_divide_factors = (factor, data_parallel_world_size / factor)

def lazy_init(self):
# Lazy init should be idempotent
param_names_on_meta = [
@@ -207,7 +181,6 @@ def lazy_init(self):
# Initialize mixed precision attributes lazily in case the user changes
# the parameter dtypes after construction time but before forward
self._init_mp_dtypes()
self._init_grad_divide_factors()
self._register_state_dict_hooks()

# Runtime #
Expand Down Expand Up @@ -346,7 +319,6 @@ def post_backward(self, *unused: Any):
self._orig_dtype,
self._reduce_dtype,
self.device,
self._grad_divide_factors,
self._all_reduce_process_group
if self._is_hsdp and self.all_reduce_grads
else None,