
Commit 85b3fff
Author: Andrew Gu

[FSDP2] Added test to show rank 0 broadcast for HSDP replicas

ghstack-source-id: 74bc2a0
Pull Request resolved: #125431

1 parent: b03fb49

1 file changed: +64 -0 lines changed


test/distributed/_composable/fsdp/test_fully_shard_init.py

Lines changed: 64 additions & 0 deletions
@@ -725,5 +725,69 @@ def test_process_group_init(self):
             self.assertEqual(param.grad, ref_param.grad)
 
 
+class TestFullyShardHSDPBroadcast(FSDPTestMultiThread):
+    @property
+    def world_size(self) -> int:
+        return 4
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_hsdp_broadcast_across_replicas(self):
+        shard_size, replicate_size = 2, 2
+        mesh = init_device_mesh(
+            "cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard")
+        )
+        model_args = ModelArgs()
+        model = Transformer(model_args)
+        for module in model.modules():
+            if isinstance(module, TransformerBlock):
+                fully_shard(module, mesh=mesh)
+        fully_shard(model, mesh=mesh)
+
+        # Only preserve the model parameters on the replicate mesh's rank 0
+        if mesh.get_local_rank("replicate") > 0:
+            print(f"[Rank {self.rank}] filling with 1337")
+            for param in model.parameters():
+                param.detach().fill_(1337)
+
+        # Check that replicas are different
+        for param_name, param in model.named_parameters():
+            local_param = param.to_local()
+            local_param_list = [
+                torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
+            ]
+            dist.all_gather(
+                local_param_list, local_param, group=mesh.get_group("replicate")
+            )
+            for other_local_param in local_param_list[1:]:
+                self.assertEqual(other_local_param.shape, local_param_list[0].shape)
+                self.assertNotEqual(other_local_param, local_param_list[0])
+
+        # Broadcast from replicate mesh's rank 0
+        replicate_group = mesh.get_group("replicate")
+        for param in model.parameters():
+            # E.g. for mesh [[0, 1, 2, 3], [4, 5, 6, 7]] sharding on dim-1 and
+            # replicating on dim-0, broadcast with sources 0, 1, 2, 3
+            src_rank = dist.get_process_group_ranks(replicate_group)[0]
+            torch.distributed.broadcast(
+                param.to_local(), src=src_rank, group=replicate_group
+            )
+
+        # Check that replicas are the same
+        for param_name, param in model.named_parameters():
+            local_param = param.to_local()
+            local_param_list = [
+                torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
+            ]
+            dist.all_gather(
+                local_param_list, local_param, group=mesh.get_group("replicate")
+            )
+            for other_local_param in local_param_list[1:]:
+                self.assertEqual(other_local_param, local_param_list[0])
+
+        # Check that we can run an iteration without erroring
+        inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
+        model(inp).sum().backward()
+
+
 if __name__ == "__main__":
     run_tests()
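
The pattern this test exercises applies outside the test suite as well: after fully_shard is applied on a 2D ("replicate", "shard") device mesh, a model whose parameters are only valid on the replicate mesh's rank 0 (for example, after loading a checkpoint on that rank alone) can be synchronized by broadcasting each parameter's local shard over the replicate process group. Below is a minimal sketch of such a helper, using only the calls that appear in the diff above; the function name and its packaging as a standalone helper are illustrative, not part of this commit.

import torch.distributed as dist


def broadcast_params_from_replicate_rank0(model, mesh) -> None:
    # Illustrative helper (not in the commit): sync HSDP replicas from the
    # replicate mesh's rank 0, mirroring what the test does inline.
    replicate_group = mesh.get_group("replicate")
    # The first global rank in the replicate group is the broadcast source.
    src_rank = dist.get_process_group_ranks(replicate_group)[0]
    for param in model.parameters():
        # Each parameter is a DTensor; broadcast its local shard in place so
        # every replica ends up with the same values for that shard.
        dist.broadcast(param.to_local(), src=src_rank, group=replicate_group)

Broadcasting local shards rather than full parameters keeps per-rank communication at roughly 1/shard_size of the model size, and the test's before/after all_gather checks confirm that broadcasting the local shards is enough to make the replicas identical.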
