1 change: 1 addition & 0 deletions .ci/pytorch/test.sh
@@ -322,6 +322,7 @@ test_inductor_distributed() {
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_mlp
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_hsdp
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_2d_transformer_checkpoint_resume
pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_gradient_accumulation
pytest test/distributed/_composable/fsdp/test_fully_shard_frozen.py
pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype
pytest test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype
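The added CI line simply runs the new gradient-accumulation subtest; a rough local equivalent from Python (a hypothetical reproduction, assuming a machine with at least 2 GPUs since the test is gated by `skip_if_lt_x_gpu(2)`):

```python
# Hypothetical local reproduction of the CI line above; run from the repo root.
import pytest

pytest.main([
    "test/distributed/_composable/fsdp/test_fully_shard_training.py",
    "-k", "test_gradient_accumulation",
])
```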
47 changes: 30 additions & 17 deletions test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -41,6 +41,7 @@
FSDPTestMultiThread,
MLP,
patch_all_gather,
patch_all_reduce,
patch_reduce_scatter,
test_compiled_fsdp,
)
@@ -649,16 +650,21 @@ def _test_train_shared_params(
class TestFullyShardGradientAccumulation(FSDPTest):
@property
def world_size(self) -> int:
return min(2, torch.cuda.device_count())
return min(4, torch.cuda.device_count())

@skip_if_lt_x_gpu(2)
def test_gradient_accumulation(self):
"""
Tests gradient accumulation with/without gradient reduction and
with/without resharding after backward.
"""
meshes = [init_device_mesh("cuda", (self.world_size,))] # always test FSDP
if self.world_size == 4: # test HSDP too if enough GPUs
shard_size, replicate_size = 2, 2
meshes.append(init_device_mesh("cuda", (replicate_size, shard_size)))
self.run_subtests(
{
"mesh": meshes,
"reshard_after_forward": [True, False, 2],
# "all": disable reduce-scatter for all modules
# "root_only": disable reduce-scatter for root's linear only
@@ -672,6 +678,7 @@ def test_gradient_accumulation(self):

def _test_gradient_accumulation(
self,
mesh: DeviceMesh,
reshard_after_forward: Union[bool, int],
mode: str,
reshard_after_backward: bool,
@@ -691,15 +698,13 @@ def _test_gradient_accumulation(
global_batch_size = local_batch_size * self.world_size
if mode == "some_mlps":
num_mlps_to_disable_reduce_scatter = 2
model = nn.Sequential(
*(
[nn.Linear(lin_dim, lin_dim)]
+ [MLP(lin_dim, torch.device("cpu")) for _ in range(num_mlps)]
)
)
modules = [nn.Linear(lin_dim, lin_dim)]
modules.extend(MLP(lin_dim) for _ in range(num_mlps))
model = nn.Sequential(*modules)
ref_model = copy.deepcopy(model).cuda()
fully_shard_fn = functools.partial(
fully_shard,
mesh=mesh,
reshard_after_forward=reshard_after_forward,
offload_policy=offload_policy,
)
@@ -709,10 +714,11 @@ def _test_gradient_accumulation(
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# TODO: Migrate to `CommDebugMode` once it supports c10d collectives.
orig_all_gather = dist.all_gather_into_tensor
all_gather_count = 0
orig_reduce_scatter = dist.reduce_scatter_tensor
reduce_scatter_count = 0
orig_all_reduce = dist.all_reduce
all_gather_count, reduce_scatter_count, all_reduce_count = 0, 0, 0

def all_gather_with_count(*args, **kwargs):
nonlocal all_gather_count
@@ -724,11 +730,16 @@ def reduce_scatter_with_count(*args, **kwargs):
reduce_scatter_count += 1
return orig_reduce_scatter(*args, **kwargs)

def all_reduce_with_count(*args, **kwargs):
nonlocal all_reduce_count
all_reduce_count += 1
return orig_all_reduce(*args, **kwargs)

torch.manual_seed(1) # same on all ranks
for iter_idx in range(5):
with patch_all_gather(all_gather_with_count), patch_reduce_scatter(
reduce_scatter_with_count
):
), patch_all_reduce(all_reduce_with_count):
for microbatch_idx in range(num_microbatches):
is_last_microbatch = microbatch_idx == num_microbatches - 1
if mode == "all":
@@ -756,10 +767,7 @@ def reduce_scatter_with_count(*args, **kwargs):
* local_batch_size
].detach()
losses: List[torch.Tensor] = []
for _model, _optim, inp in (
(ref_model, ref_optim, global_inp),
(model, optim, local_inp),
):
for _model, inp in ((ref_model, global_inp), (model, local_inp)):
losses.append(_model(inp).sum())
losses[-1].backward()
dist.all_reduce(losses[1]) # partial -> replicated
@@ -778,7 +786,13 @@ def reduce_scatter_with_count(*args, **kwargs):
# Expect additional reduce-scatters for all MLPs
expected_reduce_scatter_count += (num_mlps) * (num_microbatches - 1)
self.assertEqual(reduce_scatter_count, expected_reduce_scatter_count)
reduce_scatter_count = 0
# Exclude the loss all-reduce per microbatch in our training loop
all_reduce_count -= num_microbatches
if mesh.ndim == 2:
self.assertEqual(all_reduce_count, expected_reduce_scatter_count)
else:
self.assertEqual(all_reduce_count, 0)
reduce_scatter_count = all_reduce_count = 0

# Expect one all-gather per MLP plus one for the root's linear in
# the first microbatch's forward
@@ -872,8 +886,7 @@ def _test_1f1b_microbatching(
ref_losses.append(ref_model(inp).sum())
ref_losses[-1].backward()
for param in ref_model.parameters():
dist.all_reduce(param.grad)
param.grad.detach().div_(self.world_size)
dist.all_reduce(param.grad, op=dist.ReduceOp.AVG)

for loss, ref_loss in zip(losses, ref_losses):
self.assertEqual(loss, ref_loss)
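For context on the new `mesh` subtest dimension above: a 1D device mesh keeps the existing pure-FSDP behavior, while a 2D `(replicate, shard)` mesh turns the same `fully_shard` calls into HSDP, which is why the test now also counts all-reduces over the replicate group. A minimal standalone sketch of that setup, with module sizes and launch details as illustrative assumptions rather than values taken from the test (assumes `torchrun --nproc_per_node=4`):

```python
import os
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard

# Assumes launch via `torchrun --nproc_per_node=4 this_script.py`.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
replicate_size, shard_size = 2, 2
mesh = init_device_mesh("cuda", (replicate_size, shard_size))  # dim 0 replicates, dim 1 shards

model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32)).cuda()
for layer in model:
    fully_shard(layer, mesh=mesh)  # shard each layer's parameters over the shard dim
fully_shard(model, mesh=mesh)  # root wrap: a 2D mesh selects HSDP instead of plain FSDP

loss = model(torch.randn(8, 32, device="cuda")).sum()
loss.backward()  # gradients: reduce-scatter over the shard dim, then all-reduce over the replicate dim
```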
21 changes: 10 additions & 11 deletions torch/distributed/_composable/fsdp/_fsdp_param_group.py
@@ -167,7 +167,7 @@ def _init_mp_dtypes(self) -> None:
def _init_grad_divide_factors(self):
data_parallel_world_size = 1
data_parallel_world_size *= self.mesh_info.shard_mesh_size
if isinstance(self.mesh_info, HSDPMeshInfo):
if self._is_hsdp:
data_parallel_world_size *= self.mesh_info.replicate_mesh_size
if self._reduce_dtype in (torch.float32, torch.bfloat16):
# Use NCCL's AVG op to divide after reduction since it is more
@@ -348,7 +348,7 @@ def post_backward(self, *unused: Any):
self.device,
self._grad_divide_factors,
self._all_reduce_process_group
if self._should_all_reduce_grads()
if self._is_hsdp and self.all_reduce_grads
else None,
self.comm_ctx.all_reduce_stream,
)
@@ -481,6 +481,10 @@ def _use_post_forward_mesh(self) -> bool:
and self.mesh_info != self.post_forward_mesh_info
)

@property
def _is_hsdp(self) -> bool:
return isinstance(self.mesh_info, HSDPMeshInfo)

@property
def _all_gather_process_group(self) -> dist.ProcessGroup:
mesh_info = (
@@ -493,18 +497,13 @@ def _all_gather_process_group(self) -> dist.ProcessGroup:

@property
def _reduce_scatter_process_group(self) -> dist.ProcessGroup:
mesh_info = self.mesh_info
assert isinstance(mesh_info, FSDPMeshInfo)
return mesh_info.shard_process_group
assert isinstance(self.mesh_info, FSDPMeshInfo)
return self.mesh_info.shard_process_group

@property
def _all_reduce_process_group(self) -> dist.ProcessGroup:
mesh_info = self.mesh_info
assert isinstance(mesh_info, HSDPMeshInfo)
return mesh_info.replicate_process_group

def _should_all_reduce_grads(self) -> bool:
return isinstance(self.mesh_info, HSDPMeshInfo) and self.all_reduce_grads
assert isinstance(self.mesh_info, HSDPMeshInfo)
return self.mesh_info.replicate_process_group


def _get_param_module_infos(
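To make the `post_backward` change above easier to follow, here is a hedged, standalone paraphrase of HSDP's two-hop gradient reduction. The function and argument names are mine, not the real `FSDPParamGroup` internals (which also handle reduce dtypes, pre/post-divide factors, and separate CUDA streams): the replicate-group all-reduce only runs when a replicate process group is passed, and passing `None`, which `post_backward` now does unless `self._is_hsdp and self.all_reduce_grads`, skips it and leaves partially reduced gradients for accumulation.

```python
from typing import Optional

import torch
import torch.distributed as dist


def reduce_gradient(
    unsharded_grad: torch.Tensor,
    shard_pg: dist.ProcessGroup,
    replicate_pg: Optional[dist.ProcessGroup],  # None -> skip the HSDP all-reduce hop
) -> torch.Tensor:
    """Sketch of HSDP's two-hop reduction: reduce-scatter over the shard group,
    then (only when a replicate group is given) all-reduce the resulting shard
    over the replicate group. AVG on both hops divides by the full
    data-parallel world size (shard_size * replicate_size), matching
    _init_grad_divide_factors. Assumes the leading dim divides evenly."""
    shard_world_size = dist.get_world_size(shard_pg)
    sharded_grad = torch.empty(
        unsharded_grad.shape[0] // shard_world_size,
        *unsharded_grad.shape[1:],
        dtype=unsharded_grad.dtype,
        device=unsharded_grad.device,
    )
    dist.reduce_scatter_tensor(
        sharded_grad, unsharded_grad, op=dist.ReduceOp.AVG, group=shard_pg
    )
    if replicate_pg is not None:
        dist.all_reduce(sharded_grad, op=dist.ReduceOp.AVG, group=replicate_pg)
    return sharded_grad
```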
12 changes: 12 additions & 0 deletions torch/testing/_internal/common_fsdp.py
@@ -907,6 +907,18 @@ def patch_reduce_scatter(new_reduce_scatter_tensor: Callable):
dist.reduce_scatter_tensor = orig_reduce_scatter


@contextlib.contextmanager
def patch_all_reduce(new_all_reduce: Callable):
Collaborator (inline review comment): hopefully after #125475 lands, lots of these would be simplified!

orig_all_reduce = dist.all_reduce
dist.barrier()
dist.all_reduce = new_all_reduce
try:
yield
finally:
dist.barrier()
dist.all_reduce = orig_all_reduce


@no_type_check
@contextlib.contextmanager
def patch_unshard(new_unshard: Callable):
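A small usage sketch for the new `patch_all_reduce` helper (assumes an already-initialized NCCL process group, e.g. inside one of the FSDP test classes; the counter wiring mirrors what `_test_gradient_accumulation` does, written here as a module-level script):

```python
import torch
import torch.distributed as dist
from torch.testing._internal.common_fsdp import patch_all_reduce

all_reduce_count = 0
orig_all_reduce = dist.all_reduce


def all_reduce_with_count(*args, **kwargs):
    global all_reduce_count
    all_reduce_count += 1
    return orig_all_reduce(*args, **kwargs)


# Inside the context, every dist.all_reduce call is routed through the wrapper.
with patch_all_reduce(all_reduce_with_count):
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t)  # counted

assert all_reduce_count == 1
dist.all_reduce(t)  # outside the context the original function is restored, so this is not counted
```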