Skip to content

Commit 05f5697

Browse files
author
Andrew Gu
committed
Update on "[FSDP2] Computed grad divide factors at runtime"
cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
2 parents 4824891 + 1ff0231 commit 05f5697

File tree

1 file changed

+24
-21
lines changed

1 file changed

+24
-21
lines changed

test/distributed/_composable/fsdp/test_fully_shard_init.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -738,51 +738,54 @@ def test_hsdp_broadcast_across_replicas(self):
738738
)
739739
model_args = ModelArgs()
740740
model = Transformer(model_args)
741+
# Add a buffer to show that this flow works for buffers too
742+
model.register_buffer("buf", torch.randn((model_args.dim,)))
741743
for module in model.modules():
742744
if isinstance(module, TransformerBlock):
743745
fully_shard(module, mesh=mesh)
744746
fully_shard(model, mesh=mesh)
745747

746-
# Only preserve the model parameters on the replicate mesh's rank 0
748+
# Only preserve the model states on the replicate mesh's rank 0
747749
if mesh.get_local_rank("replicate") > 0:
748-
print(f"[Rank {self.rank}] filling with 1337")
749-
for param in model.parameters():
750-
param.detach().fill_(1337)
750+
for tensor in itertools.chain(model.parameters(), model.buffers()):
751+
tensor.detach().fill_(1337)
751752

752753
# Check that replicas are different
753-
for param_name, param in model.named_parameters():
754-
local_param = param.to_local()
755-
local_param_list = [
756-
torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
754+
for tensor in itertools.chain(model.parameters(), model.buffers()):
755+
local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
756+
local_tensor_list = [
757+
torch.empty_like(local_tensor) for _ in range(mesh["replicate"].size())
757758
]
758759
dist.all_gather(
759-
local_param_list, local_param, group=mesh.get_group("replicate")
760+
local_tensor_list, local_tensor, group=mesh.get_group("replicate")
760761
)
761-
for other_local_param in local_param_list[1:]:
762-
self.assertEqual(other_local_param.shape, local_param_list[0].shape)
763-
self.assertNotEqual(other_local_param, local_param_list[0])
762+
for other_local_tensor in local_tensor_list[1:]:
763+
self.assertEqual(other_local_tensor.shape, local_tensor_list[0].shape)
764+
self.assertNotEqual(other_local_tensor, local_tensor_list[0])
764765

765766
# Broadcast from replicate mesh's rank 0
766767
replicate_group = mesh.get_group("replicate")
767-
for param in model.parameters():
768+
for tensor in itertools.chain(model.parameters(), model.buffers()):
768769
# E.g. for mesh [[0, 1, 2, 3], [4, 5, 6, 7]] sharding on dim-1 and
769770
# replicating on dim-0, broadcast with sources 0, 1, 2, 3
770771
src_rank = dist.get_process_group_ranks(replicate_group)[0]
771772
torch.distributed.broadcast(
772-
param.to_local(), src=src_rank, group=replicate_group
773+
tensor.to_local() if isinstance(tensor, DTensor) else tensor,
774+
src=src_rank,
775+
group=replicate_group,
773776
)
774777

775778
# Check that replicas are the same
776-
for param_name, param in model.named_parameters():
777-
local_param = param.to_local()
778-
local_param_list = [
779-
torch.empty_like(local_param) for _ in range(mesh["replicate"].size())
779+
for tensor in itertools.chain(model.parameters(), model.buffers()):
780+
local_tensor = tensor.to_local() if isinstance(tensor, DTensor) else tensor
781+
local_tensor_list = [
782+
torch.empty_like(local_tensor) for _ in range(mesh["replicate"].size())
780783
]
781784
dist.all_gather(
782-
local_param_list, local_param, group=mesh.get_group("replicate")
785+
local_tensor_list, local_tensor, group=mesh.get_group("replicate")
783786
)
784-
for other_local_param in local_param_list[1:]:
785-
self.assertEqual(other_local_param, local_param_list[0])
787+
for other_local_tensor in local_tensor_list[1:]:
788+
self.assertEqual(other_local_tensor, local_tensor_list[0])
786789

787790
# Check that we can run an iteration without erroring
788791
inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda")

0 commit comments

Comments
 (0)