
Commit 54ee411

Author: Andrew Gu

Update on "[FSDP2] Added HSDP grad acc tests and some minor changes"

This adds HSDP to the existing gradient accumulation tests and includes some minor changes.

cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
2 parents 3190fd2 + f1f38c3 commit 54ee411

1 file changed: +24 -21 lines changed


test/distributed/_composable/fsdp/test_fully_shard_init.py

Lines changed: 24 additions & 21 deletions
@@ -738,51 +738,54 @@ def test_hsdp_broadcast_across_replicas(self):
         )
         model_args = ModelArgs()
         model = Transformer(model_args)
+        # Add a buffer to show that this flow works for buffers too
+        model.register_buffer("buf", torch.randn((model_args.dim,)))
         for module in model.modules():
             if isinstance(module, TransformerBlock):
                 fully_shard(module, mesh=mesh)
         fully_shard(model, mesh=mesh)

-        # Only preserve the model parameters on the replicate mesh's rank 0
+        # Only preserve the model states on the replicate mesh's rank 0
         if mesh.get_local_rank("replicate") > 0:
-            print(f"[Rank {self.rank}] filling with 1337")
-            for param in model.parameters():
-                param.detach().fill_(1337)
+            for tensor in itertools.chain(model.parameters(), model.buffers()):
+                tensor.detach().fill_(1337)

         # Check that replicas are different
-        for param_name, param in model.named_parameters():
-            local_param = param.to_local()
-            local_param_list = [
-                torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
+        for tensor in itertools.chain(model.parameters(), model.buffers()):
+            local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
+            local_tensor_list = [
+                torch.empty_like(local_tensor) for _ in range(mesh["replicate"].size())
             ]
             dist.all_gather(
-                local_param_list, local_param, group=mesh.get_group("replicate")
+                local_tensor_list, local_tensor, group=mesh.get_group("replicate")
             )
-            for other_local_param in local_param_list[1:]:
-                self.assertEqual(other_local_param.shape, local_param_list[0].shape)
-                self.assertNotEqual(other_local_param, local_param_list[0])
+            for other_local_tensor in local_tensor_list[1:]:
+                self.assertEqual(other_local_tensor.shape, local_tensor_list[0].shape)
+                self.assertNotEqual(other_local_tensor, local_tensor_list[0])

         # Broadcast from replicate mesh's rank 0
         replicate_group = mesh.get_group("replicate")
-        for param in model.parameters():
+        for tensor in itertools.chain(model.parameters(), model.buffers()):
             # E.g. for mesh [[0, 1, 2, 3], [4, 5, 6, 7]] sharding on dim-1 and
             # replicating on dim-0, broadcast with sources 0, 1, 2, 3
             src_rank = dist.get_process_group_ranks(replicate_group)[0]
             torch.distributed.broadcast(
-                param.to_local(), src=src_rank, group=replicate_group
+                tensor.to_local() if isinstance(tensor, DTensor) else tensor,
+                src=src_rank,
+                group=replicate_group,
             )

         # Check that replicas are the same
-        for param_name, param in model.named_parameters():
-            local_param = param.to_local()
-            local_param_list = [
-                torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
+        for tensor in itertools.chain(model.parameters(), model.buffers()):
+            local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
+            local_tensor_list = [
+                torch.empty_like(local_tensor) for _ in range(mesh["replicate"].size())
             ]
             dist.all_gather(
-                local_param_list, local_param, group=mesh.get_group("replicate")
+                local_tensor_list, local_tensor, group=mesh.get_group("replicate")
             )
-            for other_local_param in local_param_list[1:]:
-                self.assertEqual(other_local_param, local_param_list[0])
+            for other_local_tensor in local_tensor_list[1:]:
+                self.assertEqual(other_local_tensor, local_tensor_list[0])

         # Check that we can run an iteration without erroring
         inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")
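For reference, below is a minimal sketch of the pattern this test exercises: broadcasting every model state (parameters and buffers) from the replicate mesh's rank 0 to the other replicas under HSDP. It assumes torch.distributed is already initialized and that mesh is a 2D DeviceMesh with a "replicate" dimension; the helper name broadcast_states_across_replicas is illustrative, not a PyTorch API.

import itertools

import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DTensor


def broadcast_states_across_replicas(model: nn.Module, mesh) -> None:
    # Hypothetical helper distilled from the test above (not a PyTorch API).
    replicate_group = mesh.get_group("replicate")
    # The lowest global rank in the replicate group acts as the source.
    src_rank = dist.get_process_group_ranks(replicate_group)[0]
    for tensor in itertools.chain(model.parameters(), model.buffers()):
        # Sharded parameters are DTensors, while buffers stay plain tensors,
        # so only unwrap to the local shard when needed.
        local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
        dist.broadcast(local_tensor, src=src_rank, group=replicate_group)

The DTensor check matters because fully_shard turns parameters into DTensors but leaves buffers (like the buf registered in the diff) as regular tensors, so the collective must operate on the local shard for the former and the tensor itself for the latter.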
