Commit ff061ba

wanchaol authored and pytorchmergebot committed
[comm_mode] adding some initial c10d ops to CommDebugMode (#125475)
looks like we can make it work :)

Pull Request resolved: #125475
Approved by: https://github.com/awgu
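Concretely, CommDebugMode previously only counted functional collectives (torch.ops.c10d_functional / torch.ops._c10d_functional); with this change it also counts the four c10d ops registered in comm_mode.py below (allreduce_, _allgather_base_, _reduce_scatter_base_, broadcast_). A minimal usage sketch mirroring the new test, assuming an initialized NCCL process group and a CUDA device:

import torch
import torch.distributed as dist
from torch.distributed._tensor.debug.comm_mode import CommDebugMode

c10d_ops = torch.ops.c10d

# Assumes dist.init_process_group("nccl") has already been run on each rank.
inp = torch.rand(2, 8, 16).cuda()

comm_mode = CommDebugMode()
with comm_mode:
    dist.all_reduce(inp)    # non-functional c10d collective
    dist.broadcast(inp, 0)  # also counted after this change

comm_counts = comm_mode.get_comm_counts()
print(comm_counts[c10d_ops.allreduce_])  # expected: 1
print(comm_counts[c10d_ops.broadcast_])  # expected: 1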
1 parent d4727fd commit ff061ba

File tree

5 files changed: +48 −24 lines changed

test/distributed/_tensor/debug/test_comm_mode.py

Lines changed: 22 additions & 0 deletions
@@ -9,11 +9,13 @@
 
 from torch.distributed._tensor.debug.comm_mode import CommDebugMode
 from torch.distributed._tensor.placement_types import Shard
+from torch.testing._internal.common_distributed import requires_nccl
 from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 c10d_functional = torch.ops.c10d_functional
+c10d_ops = torch.ops.c10d
 
 
 class TestCommMode(TestCase):
@@ -79,6 +81,26 @@ def f(x, y):
         self.assertEqual(comm_counts[c10d_functional.all_gather_into_tensor], 1)
         self.assertEqual(comm_counts[c10d_functional.reduce_scatter_tensor], 0)
 
+    @requires_nccl()
+    def test_comm_mode_with_c10d(self):
+        world_pg = self.world_pg
+
+        inp = torch.rand(2, 8, 16).cuda()
+        all_gather_out = inp.new_empty(self.world_size * 2, 8, 16)
+
+        comm_mode = CommDebugMode()
+        with comm_mode:
+            dist.all_reduce(inp)
+            dist.all_gather_into_tensor(all_gather_out, inp)
+            dist.reduce_scatter_tensor(inp, all_gather_out)
+            dist.broadcast(inp, 0)
+
+        comm_counts = comm_mode.get_comm_counts()
+        self.assertEqual(comm_counts[c10d_ops.allreduce_], 1)
+        self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1)
+        self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1)
+        self.assertEqual(comm_counts[c10d_ops.broadcast_], 1)
+
 
 if __name__ == "__main__":
     run_tests()

test/distributed/_tensor/test_utils.py

Lines changed: 1 addition & 3 deletions
@@ -144,9 +144,8 @@ def test_fsdp1_tp_2d_dtensor_local_shards_and_offsets(self):
             global_tensor, tp_mesh, placements=[Shard(0)]
         )
         dtensor_2d = DTensor.from_local(
-            dtensor_tp.to_local(), mesh_2d, [Replicate(), Shard(0)]
+            dtensor_tp.to_local(), mesh_2d, [Replicate(), Shard(0)], run_check=False
         ).redistribute(mesh_2d, [Shard(0), Shard(0)])
-        self.assertEqual(len(comm_mode.get_comm_counts()), 1)
         self.assertEqual(
             comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
         )
@@ -196,7 +195,6 @@ def test_fsdp2_tp_2d_dtensor_local_shards_and_offsets(self):
             stride=global_tensor.stride(),
         )
 
-        self.assertEqual(len(comm_mode.get_comm_counts()), 0)
         self.assertEqual(
             comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 0
         )

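A likely reason for the run_check=False change (not stated in the commit message): DTensor.from_local with run_check=True can issue a collective to validate replicated placements, and such a collective now shows up in CommDebugMode's counts, so the test pins the exact op it cares about instead of asserting on the total number of distinct entries. A rough sketch of the updated assertion style, reusing the names from the diff above:

# Hypothetical sketch: assert on the specific collective of interest rather
# than on len() of the whole counter, which can grow as CommDebugMode learns
# to see more op namespaces.
with comm_mode:
    dtensor_2d = DTensor.from_local(
        dtensor_tp.to_local(), mesh_2d, [Replicate(), Shard(0)], run_check=False
    ).redistribute(mesh_2d, [Shard(0), Shard(0)])
self.assertEqual(
    comm_mode.get_comm_counts()[c10d_functional.all_gather_into_tensor], 1
)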
test/distributed/fsdp/test_fsdp_tp_integration.py

Lines changed: 7 additions & 2 deletions
@@ -201,7 +201,7 @@ def _get_grads_as_flattened(
         all_grads_as_flattened = torch.cat(
             [torch.empty_like(local_grads_as_flattened) for _ in range(fsdp_pg.size())]
         ).contiguous()
-        dist._all_gather_base(
+        dist.all_gather_into_tensor(
             all_grads_as_flattened, local_grads_as_flattened, group=fsdp_pg
         )
         if not uses_tp:
@@ -387,11 +387,16 @@ def forward(self, x):
         fsdp_2d_model(torch.rand(2, 10).cuda(self.rank)).sum().backward()
 
         funcol = torch.ops.c10d_functional
+        c10d_ops = torch.ops.c10d
         comm_counts = comm_mode.get_comm_counts()
-        self.assertEqual(comm_mode.get_total_counts(), 5)
+        self.assertEqual(comm_mode.get_total_counts(), 7)
+        # TP comms
         self.assertEqual(comm_counts[funcol.reduce_scatter_tensor], 2)
         self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 2)
         self.assertEqual(comm_counts[funcol.all_reduce], 1)
+        # FSDP comms
+        self.assertEqual(comm_counts[c10d_ops._allgather_base_], 1)
+        self.assertEqual(comm_counts[c10d_ops._reduce_scatter_base_], 1)
 
         grads = [p.grad for p in fsdp_2d_model.parameters() if p.grad is not None]
 

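The expected total rises from 5 to 7 because the five TP functional collectives are now joined by the two c10d collectives that the diff attributes to FSDP (the "# FSDP comms" assertions: one _allgather_base_ and one _reduce_scatter_base_), which CommDebugMode previously could not see. A small sketch of summarizing the two namespaces separately, using the same handles as in the diff:

funcol = torch.ops.c10d_functional
c10d_ops = torch.ops.c10d

comm_counts = comm_mode.get_comm_counts()
tp_ops = {funcol.reduce_scatter_tensor, funcol.all_gather_into_tensor, funcol.all_reduce}
fsdp_ops = {c10d_ops._allgather_base_, c10d_ops._reduce_scatter_base_}

# Split the total by op namespace: the functional collectives come from TP,
# the in-place c10d ops come from FSDP in this 2D setup.
tp_total = sum(comm_counts[op] for op in tp_ops)      # expected: 5
fsdp_total = sum(comm_counts[op] for op in fsdp_ops)  # expected: 2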
test/distributed/tensor/parallel/test_tp_style.py

Lines changed: 8 additions & 18 deletions
@@ -49,10 +49,8 @@ def test_colwise_parallel_style(self):
         model = nn.Linear(16, 16, device=self.device_type)
 
         default_col_parallel = ColwiseParallel()
+        colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
         with comm_mode:
-            colwise_mod = parallelize_module(
-                deepcopy(model), mesh, default_col_parallel
-            )
             out = colwise_mod(tensor)
             # ensure output shard on the last dim
             self.assertEqual(out.shape, (8, 16 // self.world_size))
@@ -65,10 +63,8 @@ def test_colwise_parallel_style(self):
         self.assertEqual(comm_mode.get_total_counts(), 1)
 
         sharded_col_parallel = ColwiseParallel(input_layouts=Shard(0))
+        colwise_mod = parallelize_module(deepcopy(model), mesh, sharded_col_parallel)
         with comm_mode:
-            colwise_mod = parallelize_module(
-                deepcopy(model), mesh, sharded_col_parallel
-            )
             out = colwise_mod(tensor)
             # ensure output shard on the last dim
             self.assertEqual(out.shape, (8 * self.world_size, 16 // self.world_size))
@@ -94,10 +90,8 @@ def test_colwise_parallel_embedding(self):
         model = nn.Embedding(16, 16, device=self.device_type)
 
         default_col_parallel = ColwiseParallel()
+        colwise_mod = parallelize_module(deepcopy(model), mesh, default_col_parallel)
         with comm_mode:
-            colwise_mod = parallelize_module(
-                deepcopy(model), mesh, default_col_parallel
-            )
             out = colwise_mod(tensor)
             # ensure output shard on the last dim
             self.assertEqual(out.shape, (4, 2, 16 // self.world_size))
@@ -119,10 +113,8 @@ def test_rowwise_parallel_style(self):
         model = nn.Linear(16, 16, device=self.device_type)
 
         default_row_parallel = RowwiseParallel()
+        rowwise_mod = parallelize_module(deepcopy(model), mesh, default_row_parallel)
         with comm_mode:
-            rowwise_mod = parallelize_module(
-                deepcopy(model), mesh, default_row_parallel
-            )
             out = rowwise_mod(tensor)
             # ensure output replicated
             self.assertEqual(out.shape, (8, 16))
@@ -135,10 +127,8 @@ def test_rowwise_parallel_style(self):
         self.assertEqual(comm_mode.get_total_counts(), 1)
 
         sharded_row_parallel = RowwiseParallel(output_layouts=Shard(0))
+        rowwise_mod = parallelize_module(deepcopy(model), mesh, sharded_row_parallel)
         with comm_mode:
-            rowwise_mod = parallelize_module(
-                deepcopy(model), mesh, sharded_row_parallel
-            )
             out = rowwise_mod(tensor)
             # ensure output replicated
             self.assertEqual(out.shape, (8 // self.world_size, 16))
@@ -163,10 +153,10 @@ def test_rowwise_parallel_embedding(self):
         tensor = torch.arange(8, device=self.device_type).reshape(4, 2)
         model = nn.Embedding(16, 16, device=self.device_type)
 
+        rowwise_mod = parallelize_module(
+            deepcopy(model), mesh, RowwiseParallel(input_layouts=Replicate())
+        )
         with comm_mode:
-            rowwise_mod = parallelize_module(
-                deepcopy(model), mesh, RowwiseParallel(input_layouts=Replicate())
-            )
             out = rowwise_mod(tensor)
             # ensure output shard on the last dim
             self.assertEqual(out.shape, (4, 2, 16))

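Hoisting parallelize_module out of the with comm_mode: block is presumably needed because parallelizing a module can itself run c10d collectives (for example broadcasting parameters across the mesh), which would now inflate counts that these tests pin to the forward pass alone. The resulting pattern, repeated across this file:

# Construct and parallelize the module first, outside the debug mode...
colwise_mod = parallelize_module(deepcopy(model), mesh, ColwiseParallel())
# ...then count only the collectives issued by the forward pass.
with comm_mode:
    out = colwise_mod(tensor)
self.assertEqual(comm_mode.get_total_counts(), 1)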
torch/distributed/_tensor/debug/comm_mode.py

Lines changed: 10 additions & 1 deletion
@@ -9,6 +9,7 @@
 funcol_native = torch.ops._c10d_functional
 funcol_py = torch.ops.c10d_functional
 funcol_autograd = torch.ops._c10d_functional_autograd
+c10d_ops = torch.ops.c10d
 
 NATIVE_TO_PY_MAPPING = {
     funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor,
@@ -22,6 +23,13 @@
     funcol_autograd.all_to_all_single: funcol_py.all_to_all_single,
 }
 
+c10d_collective_ops = {
+    c10d_ops.allreduce_,
+    c10d_ops._allgather_base_,
+    c10d_ops._reduce_scatter_base_,
+    c10d_ops.broadcast_,
+}
+
 
 class CommDebugMode(TorchDispatchMode):
     """
@@ -88,7 +96,8 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         # the need to modify all tests to accommodate the two implementations,
         # we make CommDebugMode translate native funcol ops into legacy funcol
         # ops until the migration finishes.
-        if func_packet in self.comm_registry:
+
+        if func_packet in self.comm_registry or func_packet in c10d_collective_ops:
             if func_packet in NATIVE_TO_PY_MAPPING:
                 func_packet = NATIVE_TO_PY_MAPPING[func_packet]
             self.comm_counts[func_packet] += 1

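The counting hook lives in __torch_dispatch__: an op is counted if its OpOverloadPacket is either a registered functional collective or one of the four c10d ops in c10d_collective_ops. A simplified, self-contained sketch of that mechanism follows; it is an illustration only, not the real CommDebugMode, which additionally translates native funcol ops to their legacy py-funcol equivalents as the comment in the diff explains. The class name TinyCommCounter is hypothetical.

from collections import defaultdict

import torch
from torch.utils._python_dispatch import TorchDispatchMode

c10d_ops = torch.ops.c10d
c10d_collective_ops = {
    c10d_ops.allreduce_,
    c10d_ops._allgather_base_,
    c10d_ops._reduce_scatter_base_,
    c10d_ops.broadcast_,
}


class TinyCommCounter(TorchDispatchMode):
    """Counts the c10d collectives dispatched while the mode is active."""

    def __init__(self):
        super().__init__()
        self.comm_counts = defaultdict(int)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # func is an OpOverload; the set above is keyed on its overload packet.
        if func.overloadpacket in c10d_collective_ops:
            self.comm_counts[func.overloadpacket] += 1
        return func(*args, **kwargs)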