
Commit 6e5f736

H-Huang authored and pytorchmergebot committed
[15/N] Add allreduce_coalesced custom op with CPU/CUDA implementations (#88846)
Differential Revision: [D41227740](https://our.internmc.facebook.com/intern/diff/D41227740)

Pull Request resolved: #88846

Approved by: https://github.com/kwen2501
1 parent ae2c668 commit 6e5f736

File tree

7 files changed: +102 −3 lines changed
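Before the per-file diffs, here is a minimal sketch of the user-facing call that this commit reroutes through the new custom op. The script is illustrative only (it is not part of the commit) and assumes a gloo rendezvous via environment variables, e.g. launched with `torchrun --nproc_per_node=2`:

```python
import torch
import torch.distributed as dist

def main():
    # gloo keeps the sketch CPU-only; the commit adds a CUDA path for NCCL too.
    dist.init_process_group("gloo")
    world_size = dist.get_world_size()

    # Several tensors reduced in one coalesced call rather than one
    # all_reduce round trip per tensor.
    tensors = [torch.ones(10, 10), torch.ones(5)]
    dist.all_reduce_coalesced(tensors, dist.ReduceOp.SUM)

    # Each tensor now holds the elementwise sum across all ranks.
    for t in tensors:
        assert torch.equal(t, torch.ones_like(t) * world_size)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```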

test/distributed/test_c10d_common.py

Lines changed: 15 additions & 0 deletions
```diff
@@ -1503,6 +1503,21 @@ def _test_collectives(self, backend):
                 with self.subTest(collective=collective, args=args):
                     self._call_collective_with_varying_tensors(backend, collective, *args)
 
+    def _test_allreduce_coalesced(self, backend):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend,
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        # TODO: this will be updated in the future to not be backend specific
+        device = "cuda" if backend == "nccl" else "cpu"
+        tensors = [torch.ones(10, 10, device=torch.device(device))]
+        dist.all_reduce_coalesced(tensors, dist.ReduceOp.SUM)
+        for tensor in tensors:
+            self.assertEqual(tensor, torch.ones(10, 10) * self.world_size)
+
 class CompilerTest(MultiProcessTestCase):
     def setUp(self):
         super(CompilerTest, self).setUp()
```
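For reference, the semantics the test above asserts can be stated as a plain loop: a coalesced allreduce must produce the same values as reducing each tensor individually, with the batching being purely an efficiency concern. A hypothetical helper (not from this commit) that spells that out:

```python
import torch.distributed as dist

def allreduce_coalesced_reference(tensors, op=dist.ReduceOp.SUM):
    """Semantically equivalent loop: one collective per tensor.

    all_reduce_coalesced produces the same values but issues a single
    batched operation instead of len(tensors) separate ones.
    """
    for t in tensors:
        dist.all_reduce(t, op=op)
```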

test/distributed/test_c10d_gloo.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -2363,6 +2363,10 @@ class GlooProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro
     def test_collectives(self):
         self._test_collectives(backend="gloo")
 
+    @requires_gloo()
+    def test_allreduce_coalesced(self):
+        self._test_allreduce_coalesced(backend="gloo")
+
 class CompilerTest(test_c10d_common.CompilerTest):
 
     @property
```

test/distributed/test_c10d_nccl.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -2953,6 +2953,11 @@ class NcclProcessGroupWithDispatchedCollectivesTests(test_c10d_common.ProcessGro
     def test_collectives(self):
         self._test_collectives(backend="nccl")
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(1)
+    def test_allreduce_coalesced(self):
+        self._test_allreduce_coalesced(backend="nccl")
+
 if __name__ == "__main__":
     assert (
         not torch.cuda._initialized
```

torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 36 additions & 0 deletions
```diff
@@ -40,6 +40,19 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_(
       std::move(tensor_vec), work);
 }
 
+c10::intrusive_ptr<Work> allreduce_coalesced_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
+  opts.reduceOp = *reduce_op.get();
+  opts.timeout = std::chrono::milliseconds(timeout);
+
+  return process_group->allreduce_coalesced(tensor_vec, opts);
+}
+
 c10::intrusive_ptr<Work> reduce_(
     at::TensorList tensors,
     const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -177,6 +190,10 @@ TORCH_LIBRARY(c10d, m) {
   m.def(
       "allreduce_",
       dispatch(c10::DispatchKey::CompositeExplicitAutograd, allreduce_));
+  m.def(
+      "allreduce_coalesced_",
+      dispatch(
+          c10::DispatchKey::CompositeExplicitAutograd, allreduce_coalesced_));
   m.def(
       "allgather_",
       dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_));
@@ -249,6 +266,25 @@ c10::intrusive_ptr<Work> allreduce(
       opts.timeout.count()));
 }
 
+c10::intrusive_ptr<Work> allreduce_coalesced(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    at::TensorList tensors,
+    const AllreduceCoalescedOptions& opts) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("c10d::allreduce_coalesced_", "")
+                       .typed<c10::intrusive_ptr<::c10d::Work>(
+                           at::TensorList,
+                           const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+                           const c10::intrusive_ptr<::c10d::ReduceOp>&,
+                           int64_t)>();
+
+  return op.call(
+      tensors,
+      process_group,
+      c10::make_intrusive<ReduceOp>(opts.reduceOp),
+      opts.timeout.count());
+}
+
 c10::intrusive_ptr<Work> allgather(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<std::vector<at::Tensor>>& output_tensors,
```
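Once this `m.def` registration has run (importing `torch.distributed` loads the c10d library), the schema should be discoverable from Python through `torch.ops`. A small sketch, stated as an assumption about dispatcher behavior rather than anything this commit adds:

```python
import torch
import torch.distributed  # loading c10d registers the ops above

# The overload packet for the schema defined with
# m.def("allreduce_coalesced_", ...).
op = torch.ops.c10d.allreduce_coalesced_
print(op)
```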

torch/csrc/distributed/c10d/Ops.hpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -21,6 +21,11 @@ TORCH_API c10::intrusive_ptr<Work> allreduce(
     at::TensorList tensors,
     const AllreduceOptions& opts = {});
 
+TORCH_API c10::intrusive_ptr<Work> allreduce_coalesced(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    at::TensorList tensors,
+    const AllreduceCoalescedOptions& opts = {});
+
 TORCH_API c10::intrusive_ptr<Work> allgather(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<std::vector<at::Tensor>>& output_tensors,
```

torch/csrc/distributed/c10d/OpsImpl.cpp

Lines changed: 34 additions & 0 deletions
```diff
@@ -149,6 +149,32 @@ std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> allreduce_cuda_(
       std::move(tensor_vec), work);
 }
 
+c10::intrusive_ptr<Work> allreduce_coalesced_cpu_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
+  opts.reduceOp = *reduce_op.get();
+  opts.timeout = std::chrono::milliseconds(timeout);
+
+  return process_group->allreduce_coalesced(tensor_vec, opts);
+}
+
+c10::intrusive_ptr<Work> allreduce_coalesced_cuda_(
+    at::TensorList tensors,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    const c10::intrusive_ptr<ReduceOp>& reduce_op,
+    int64_t timeout) {
+  auto tensor_vec = tensors.vec();
+  AllreduceCoalescedOptions opts = AllreduceCoalescedOptions{};
+  opts.reduceOp = *reduce_op.get();
+  opts.timeout = std::chrono::milliseconds(timeout);
+
+  return process_group->allreduce_coalesced(tensor_vec, opts);
+}
+
 std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>
 allgather_cpu_(
     const std::vector<std::vector<at::Tensor>>& output_tensors,
@@ -367,6 +393,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("allreduce_", allreduce_cuda_);
 }
 
+TORCH_LIBRARY_IMPL(c10d, CPU, m) {
+  m.impl("allreduce_coalesced_", allreduce_coalesced_cpu_);
+}
+
+TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
+  m.impl("allreduce_coalesced_", allreduce_coalesced_cuda_);
+}
+
 TORCH_LIBRARY_IMPL(c10d, CPU, m) {
   m.impl("allgather_", allgather_cpu_);
 }
```
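The pattern above, one `TORCH_LIBRARY` schema plus one `TORCH_LIBRARY_IMPL` kernel per dispatch key, has a Python analogue in `torch.library`. A rough sketch under a made-up `mylib` namespace with a trivial op (assumptions: the namespace and the `scale` op are illustrative, not part of c10d):

```python
import torch
from torch.library import Library

# Analogue of TORCH_LIBRARY(c10d, m) { m.def(...); }: define the schema once.
lib = Library("mylib", "DEF")  # "mylib" is an illustrative namespace
lib.define("scale(Tensor x, float alpha) -> Tensor")

# Analogue of TORCH_LIBRARY_IMPL(c10d, CPU/CUDA, m): one kernel per backend.
def scale_cpu(x, alpha):
    return x * alpha

def scale_cuda(x, alpha):
    return x * alpha

lib.impl("scale", scale_cpu, "CPU")
lib.impl("scale", scale_cuda, "CUDA")

# The dispatcher picks the kernel from the input tensor's device.
y = torch.ops.mylib.scale(torch.ones(3), 2.0)
```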

torch/csrc/distributed/c10d/init.cpp

Lines changed: 3 additions & 3 deletions
```diff
@@ -1134,10 +1134,10 @@ that adds a prefix to each key inserted to the store.
 
           .def(
               "allreduce_coalesced",
-              [](::c10d::ProcessGroup& self,
-                 std::vector<at::Tensor>& xs,
+              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
+                 const std::vector<at::Tensor>& xs,
                  ::c10d::AllreduceCoalescedOptions opts) {
-                return self.allreduce_coalesced(xs, opts);
+                return ::c10d::ops::allreduce_coalesced(self, xs, opts);
               },
               py::arg("tensors"),
               py::arg("opts") = ::c10d::AllreduceCoalescedOptions(),
```
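With the binding rerouted, the `ProcessGroup.allreduce_coalesced` method and `dist.all_reduce_coalesced` funnel into the same dispatcher op. A short sketch of the method-level call, assuming an initialized default group (e.g. under `torchrun`):

```python
import torch
import torch.distributed as dist

dist.init_process_group("gloo")  # assumes env:// rendezvous (e.g. torchrun)
pg = dist.group.WORLD            # default ProcessGroup bound above

tensors = [torch.ones(4), torch.ones(2, 2)]
work = pg.allreduce_coalesced(tensors)  # default AllreduceCoalescedOptions
work.wait()                             # block until the reduction completes
```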
