67 changes: 67 additions & 0 deletions test/distributed/_composable/fsdp/test_fully_shard_init.py
@@ -725,5 +725,72 @@ def test_process_group_init(self):
self.assertEqual(param.grad, ref_param.grad)


class TestFullyShardHSDPBroadcast(FSDPTestMultiThread):
@property
def world_size(self) -> int:
return 4

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_hsdp_broadcast_across_replicas(self):
Collaborator Author:

One might wonder why HSDP does not always broadcast during init. The issue is that we only need to broadcast if we are initializing from scratch, not from a checkpoint. When initializing from a checkpoint, the replicas are already guaranteed to be identical, so broadcasting is wasteful and can slow down init time.
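To make that trade-off concrete, here is a rough user-side sketch; the maybe_broadcast_replicas helper and the loading_from_checkpoint flag are hypothetical, and the broadcast pattern simply mirrors the test below:

import itertools

import torch.distributed as dist
from torch.distributed._tensor import DTensor


def maybe_broadcast_replicas(model, replicate_group, loading_from_checkpoint):
    # A checkpoint load already guarantees identical replicas, so only
    # broadcast replica 0's states when initializing from scratch.
    if loading_from_checkpoint:
        return
    src_rank = dist.get_process_group_ranks(replicate_group)[0]
    for tensor in itertools.chain(model.parameters(), model.buffers()):
        dist.broadcast(
            tensor.to_local() if isinstance(tensor, DTensor) else tensor,
            src=src_rank,
            group=replicate_group,
        )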

shard_size, replicate_size = 2, 2
mesh = init_device_mesh(
"cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard")
)
model_args = ModelArgs()
model = Transformer(model_args)
# Add a buffer to show that this flow works for buffers too
model.register_buffer("buf", torch.randn((model_args.dim,)))
for module in model.modules():
if isinstance(module, TransformerBlock):
fully_shard(module, mesh=mesh)
fully_shard(model, mesh=mesh)

# Only preserve the model states on the replicate mesh's rank 0
if mesh.get_local_rank("replicate") > 0:
for tensor in itertools.chain(model.parameters(), model.buffers()):
tensor.detach().fill_(1337)

# Check that replicas are different
for tensor in itertools.chain(model.parameters(), model.buffers()):
local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
local_tensor_list = [
torch.empty_like(local_tensor) for _ in range(mesh["replicate"].size())
]
dist.all_gather(
local_tensor_list, local_tensor, group=mesh.get_group("replicate")
)
for other_local_tensor in local_tensor_list[1:]:
self.assertEqual(other_local_tensor.shape, local_tensor_list[0].shape)
self.assertNotEqual(other_local_tensor, local_tensor_list[0])

# Broadcast from replicate mesh's rank 0
replicate_group = mesh.get_group("replicate")
for tensor in itertools.chain(model.parameters(), model.buffers()):
# E.g. for mesh [[0, 1, 2, 3], [4, 5, 6, 7]] sharding on dim-1 and
# replicating on dim-0, broadcast with sources 0, 1, 2, 3
src_rank = dist.get_process_group_ranks(replicate_group)[0]
            torch.distributed.broadcast(
                tensor.to_local() if isinstance(tensor, DTensor) else tensor,
                src=src_rank,
                group=replicate_group,
            )

Collaborator Author:

Today, using the in-place c10d broadcast is preferred. If we want to use a functional broadcast instead:

  1. We need to verify the semantics. We may still need to compute src_rank as we do here, which is confusing since it is the rank with respect to the global process group, not the group passed to broadcast.
  2. We need to swap the newly broadcast tensor in. Since FSDP holds a reference to the parameter, we cannot simply setattr(module, param_name, broadcasted_param), since that would leave FSDP's reference stale. We could consider using swap_tensors, but we are blocked by the local-tensor padding issue: the broadcast parameter would not have padding and is actually a view into the padded local tensor.

Collaborator:

Yeah, I think an in-place broadcast makes sense here!
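For reference, a minimal sketch of the in-place pattern preferred here, with the functional alternative kept in comments since its exact semantics are the open question above (illustrative assumptions, not the PR's code):

import torch.distributed as dist


def broadcast_local_inplace(local_tensor, replicate_group):
    # In-place c10d broadcast mutates the existing storage, so FSDP's
    # reference to the padded local tensor stays valid and no swap is needed.
    src_rank = dist.get_process_group_ranks(replicate_group)[0]
    dist.broadcast(local_tensor, src=src_rank, group=replicate_group)


# A functional broadcast would instead return a new tensor that has to be
# swapped into the module, e.g. (illustrative only):
#   new_local = funcol.broadcast(local_tensor, src_rank, group=replicate_group)
#   torch.utils.swap_tensors(local_tensor, new_local)
# which runs into the stale-reference and padded-local-tensor issues above.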

# Check that replicas are the same
for tensor in itertools.chain(model.parameters(), model.buffers()):
local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
local_tensor_list = [
torch.empty_like(local_tensor) for _ in range(mesh["replicate"].size())
]
dist.all_gather(
local_tensor_list, local_tensor, group=mesh.get_group("replicate")
)
for other_local_tensor in local_tensor_list[1:]:
self.assertEqual(other_local_tensor, local_tensor_list[0])

# Check that we can run an iteration without erroring
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
model(inp).sum().backward()


if __name__ == "__main__":
run_tests()