Commit ed26b4f

Update on "[comm_mode] adding some initial c10d ops to CommDebugMode"
looks like we can make it work :) cc mrshenli pritamdamania87 zhaojuanmao satgera gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k [ghstack-poisoned]
1 parent: 1f5b128 · commit: ed26b4f
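
For context, CommDebugMode is the DTensor debugging mode the tests below exercise: it counts every collective issued under it, keyed by the op overload. After this change it also records the non-functional c10d ops that FSDP issues, not just the torch.ops.c10d_functional collectives. A minimal sketch of the usage pattern — the process-group setup and the exact import path are assumptions, not part of this commit:

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor.debug import CommDebugMode

    # assumes a default process group is already initialized (e.g. via torchrun)
    comm_mode = CommDebugMode()
    inp = torch.ones(4, device="cuda")
    out = torch.empty(4 * dist.get_world_size(), device="cuda")

    with comm_mode:
        # dispatches to torch.ops.c10d._allgather_base_, now visible to the mode
        dist.all_gather_into_tensor(out, inp)

    print(comm_mode.get_total_counts())                                  # expect 1
    print(comm_mode.get_comm_counts()[torch.ops.c10d._allgather_base_])  # expect 1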

2 files changed: 15 additions, 20 deletions
test/distributed/fsdp/test_fsdp_tp_integration.py (7 additions, 2 deletions)

@@ -201,7 +201,7 @@ def _get_grads_as_flattened(
         all_grads_as_flattened = torch.cat(
             [torch.empty_like(local_grads_as_flattened) for _ in range(fsdp_pg.size())]
         ).contiguous()
-        dist._all_gather_base(
+        dist.all_gather_into_tensor(
            all_grads_as_flattened, local_grads_as_flattened, group=fsdp_pg
         )
         if not uses_tp:

@@ -387,11 +387,16 @@ def forward(self, x):
         fsdp_2d_model(torch.rand(2, 10).cuda(self.rank)).sum().backward()

         funcol = torch.ops.c10d_functional
+        c10d_ops = torch.ops.c10d
         comm_counts = comm_mode.get_comm_counts()
-        self.assertEqual(comm_mode.get_total_counts(), 5)
+        self.assertEqual(comm_mode.get_total_counts(), 7)
+        # TP comms
         self.assertEqual(comm_counts[funcol.reduce_scatter_tensor], 2)
         self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 2)
         self.assertEqual(comm_counts[funcol.all_reduce], 1)
+        # FSDP comms
+        self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1)
+        self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1)

         grads = [p.grad for p in fsdp_2d_model.parameters() if p.grad is not None]
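
Two things change in this file. First, the deprecated private collective dist._all_gather_base is replaced by its public successor dist.all_gather_into_tensor; both spellings reach the same underlying ProcessGroup collective (a sketch, assuming output_tensor, input_tensor, and fsdp_pg are defined as in the test):

    # equivalent calls; the underscore-prefixed one is deprecated
    dist._all_gather_base(output_tensor, input_tensor, group=fsdp_pg)
    dist.all_gather_into_tensor(output_tensor, input_tensor, group=fsdp_pg)

Second, the expected total rises from 5 to 7 because CommDebugMode now also records the two c10d collectives FSDP issues during this step (one _allgather_base_ and one _reduce_scatter_base_) on top of the five functional TP collectives.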

test/distributed/tensor/parallel/test_tp_style.py (8 additions, 18 deletions)

@@ -49,10 +49,8 @@ def test_colwise_parallel_style(self):
         model = nn.Linear(16, 16, device=self.device_type)

         default_col_parallel = ColwiseParallel()
+        colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
         with comm_mode:
-            colwise_mod = parallelize_module(
-                deepcopy(model), mesh, default_col_parallel
-            )
             out = colwise_mod(tensor)
             # ensure output shard on the last dim
             self.assertEqual(out.shape, (8, 16 // self.world_size))

@@ -65,10 +63,8 @@ def test_colwise_parallel_style(self):
         self.assertEqual(comm_mode.get_total_counts(), 1)

         sharded_col_parallel = ColwiseParallel(input_layouts=Shard(0))
+        colwise_mod = parallelize_module(deepcopy(model), mesh, sharded_col_parallel)
         with comm_mode:
-            colwise_mod = parallelize_module(
-                deepcopy(model), mesh, sharded_col_parallel
-            )
             out = colwise_mod(tensor)
             # ensure output shard on the last dim
             self.assertEqual(out.shape, (8 * self.world_size, 16 // self.world_size))

@@ -94,10 +90,8 @@ def test_colwise_parallel_embedding(self):
         model = nn.Embedding(16, 16, device=self.device_type)

         default_col_parallel = ColwiseParallel()
+        colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
         with comm_mode:
-            colwise_mod = parallelize_module(
-                deepcopy(model), mesh, default_col_parallel
-            )
             out = colwise_mod(tensor)
             # ensure output shard on the last dim
             self.assertEqual(out.shape, (4, 2, 16 // self.world_size))

@@ -119,10 +113,8 @@ def test_rowwise_parallel_style(self):
         model = nn.Linear(16, 16, device=self.device_type)

         default_row_parallel = RowwiseParallel()
+        rowwise_mod = parallelize_module(deepcopy(model), mesh, default_row_parallel)
         with comm_mode:
-            rowwise_mod = parallelize_module(
-                deepcopy(model), mesh, default_row_parallel
-            )
             out = rowwise_mod(tensor)
             # ensure output replicated
             self.assertEqual(out.shape, (8, 16))

@@ -135,10 +127,8 @@ def test_rowwise_parallel_style(self):
         self.assertEqual(comm_mode.get_total_counts(), 1)

         sharded_row_parallel = RowwiseParallel(output_layouts=Shard(0))
+        rowwise_mod = parallelize_module(deepcopy(model), mesh, sharded_row_parallel)
         with comm_mode:
-            rowwise_mod = parallelize_module(
-                deepcopy(model), mesh, sharded_row_parallel
-            )
             out = rowwise_mod(tensor)
             # ensure output replicated
             self.assertEqual(out.shape, (8 // self.world_size, 16))

@@ -163,10 +153,10 @@ def test_rowwise_parallel_embedding(self):
         tensor = torch.arange(8, device=self.device_type).reshape(4, 2)
         model = nn.Embedding(16, 16, device=self.device_type)

+        rowwise_mod = parallelize_module(
+            deepcopy(model), mesh, RowwiseParallel(input_layouts=Replicate())
+        )
         with comm_mode:
-            rowwise_mod = parallelize_module(
-                deepcopy(model), mesh, RowwiseParallel(input_layouts=Replicate())
-            )
             out = rowwise_mod(tensor)
             # ensure output shard on the last dim
             self.assertEqual(out.shape, (4, 2, 16))
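
The repeated edit in this file is the same everywhere: parallelize_module is hoisted out of the comm_mode context. Plausibly this is because distributing a module's parameters can itself issue c10d collectives, and now that CommDebugMode counts c10d ops those would inflate the totals the tests assert on; scoping the context to the forward pass keeps the counts to just the TP communication under test. The resulting pattern, as a sketch:

    # setup outside the mode, measurement inside it
    colwise_mod = parallelize_module(deepcopy(model), mesh, ColwiseParallel())
    with comm_mode:
        out = colwise_mod(tensor)
    self.assertEqual(comm_mode.get_total_counts(), 1)  # only the forward's collective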
