Commit 8015078

H-Huang authored and pytorchmergebot committed

[21/N] Add alltoall_base custom op with CPU/CUDA implementations (#89813)

Differential Revision: [D41812670](https://our.internmc.facebook.com/intern/diff/D41812670)
Pull Request resolved: #89813
Approved by: https://github.com/kwen2501

1 parent e65ee39 · commit 8015078
7 files changed: +119 -15 lines changed

test/distributed/test_c10d_common.py

Lines changed: 14 additions & 0 deletions

@@ -1522,6 +1522,20 @@ def _test_allreduce_coalesced(self, backend):
         for tensor in tensors:
             self.assertEqual(tensor, torch.ones(10, 10) * self.world_size)
 
+    def _test_all_to_all_single(self, backend):
+        store = dist.FileStore(self.file_name, self.world_size)
+        dist.init_process_group(
+            backend,
+            world_size=self.world_size,
+            rank=self.rank,
+            store=store,
+        )
+        device = "cuda" if backend == "nccl" else "cpu"
+        # test alltoall_base
+        input_tensor = torch.ones(2, 2, device=torch.device(device))
+        output_tensor = torch.zeros(2, 2, device=torch.device(device))
+        dist.all_to_all_single(output_tensor, input_tensor)
+
 class CompilerTest(MultiProcessTestCase):
     def setUp(self):
         super(CompilerTest, self).setUp()
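
The new helper drives the added alltoall_base path through the public torch.distributed API. For readers who want to try the collective outside the test harness, here is a minimal standalone sketch, assuming a two-process job launched with torchrun (the file name demo.py is hypothetical):

import torch
import torch.distributed as dist

# Launch with: torchrun --nproc_per_node=2 demo.py
dist.init_process_group("gloo")
rank = dist.get_rank()

# Each rank contributes a (2, 2) tensor filled with its rank id; with no
# split sizes given, all_to_all_single splits dim 0 evenly across ranks.
input_tensor = torch.full((2, 2), float(rank))
output_tensor = torch.zeros(2, 2)
dist.all_to_all_single(output_tensor, input_tensor)

# Row chunk i of the output came from rank i, so every rank now holds
# [[0., 0.], [1., 1.]].
print(f"rank {rank}: {output_tensor}")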

test/distributed/test_c10d_gloo.py

Lines changed: 4 additions & 0 deletions

@@ -2405,6 +2405,10 @@ def test_collectives(self):
     def test_allreduce_coalesced(self):
         self._test_allreduce_coalesced(backend="gloo")
 
+    @requires_gloo()
+    def test_all_to_all_single(self):
+        self._test_all_to_all_single(backend="gloo")
+
     @requires_gloo()
     def test_allgather_coalesced(self):
         store = dist.FileStore(self.file_name, self.world_size)

test/distributed/test_c10d_nccl.py

Lines changed: 5 additions & 0 deletions

@@ -2948,6 +2948,11 @@ def test_collectives(self):
     def test_allreduce_coalesced(self):
         self._test_allreduce_coalesced(backend="nccl")
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(1)
+    def test_all_to_all_single(self):
+        self._test_all_to_all_single(backend="nccl")
+
     @requires_nccl()
     @skip_if_lt_x_gpu(1)
     def test_allgather_base(self):

torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 43 additions & 0 deletions

@@ -173,6 +173,21 @@ c10::intrusive_ptr<Work> alltoall_(
       AllToAllOptions{std::chrono::milliseconds(timeout)});
 }
 
+c10::intrusive_ptr<Work> alltoall_base_(
+    at::Tensor& output,
+    at::Tensor& input,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    std::vector<int64_t> output_split_sizes,
+    std::vector<int64_t> input_split_sizes,
+    int64_t timeout) {
+  return process_group->alltoall_base(
+      output,
+      input,
+      output_split_sizes,
+      input_split_sizes,
+      AllToAllOptions{std::chrono::milliseconds(timeout)});
+}
+
 c10::intrusive_ptr<Work> barrier(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<int64_t>& device_ids,

@@ -271,6 +286,9 @@ TORCH_LIBRARY(c10d, m) {
   m.def(
       "alltoall_",
       dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_));
+  m.def(
+      "alltoall_base_",
+      dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_base_));
   m.def(
       "barrier",
       dispatch(c10::DispatchKey::CompositeExplicitAutograd, barrier));

@@ -523,6 +541,31 @@ c10::intrusive_ptr<Work> alltoall(
       output_tensors, input_tensors, process_group, opts.timeout.count());
 }
 
+c10::intrusive_ptr<Work> alltoall_base(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    at::Tensor& output,
+    at::Tensor& input,
+    std::vector<int64_t> output_split_sizes,
+    std::vector<int64_t> input_split_sizes,
+    const AllToAllOptions& opts) {
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("c10d::alltoall_base_", "")
+                       .typed<c10::intrusive_ptr<::c10d::Work>(
+                           at::Tensor&,
+                           at::Tensor&,
+                           const c10::intrusive_ptr<::c10d::ProcessGroup>&,
+                           std::vector<int64_t>,
+                           std::vector<int64_t>,
+                           int64_t)>();
+  return op.call(
+      output,
+      input,
+      process_group,
+      output_split_sizes,
+      input_split_sizes,
+      opts.timeout.count());
+}
+
 void monitored_barrier(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const BarrierOptions& opts,
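
Registering the schema inside TORCH_LIBRARY(c10d, m) makes alltoall_base_ a first-class dispatcher op, which is what the typed findSchemaOrThrow lookup above relies on. In principle the same op is also reachable from Python as torch.ops.c10d.alltoall_base_; the sketch below is a hypothetical direct call, assuming the bound ProcessGroup converts to the custom-class argument and that -1 matches the unset-timeout default in milliseconds:

import torch
import torch.distributed as dist

# Hypothetical direct dispatcher call; launch with torchrun.
dist.init_process_group("gloo")
pg = dist.distributed_c10d._get_default_group()  # private helper; an assumption

out = torch.zeros(2, 2)
inp = torch.ones(2, 2)
# Argument order mirrors the alltoall_base_ schema above:
# (output, input, process_group, output_split_sizes, input_split_sizes, timeout)
work = torch.ops.c10d.alltoall_base_(out, inp, pg, [], [], -1)
work.wait()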

torch/csrc/distributed/c10d/Ops.hpp

Lines changed: 8 additions & 0 deletions

@@ -73,6 +73,14 @@ TORCH_API c10::intrusive_ptr<Work> scatter(
     const std::vector<std::vector<at::Tensor>>& input_tensors,
     const ScatterOptions& opts = {});
 
+TORCH_API c10::intrusive_ptr<Work> alltoall_base(
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    at::Tensor& output,
+    at::Tensor& input,
+    const std::vector<int64_t> outputSplitSizes,
+    const std::vector<int64_t> inputSplitSizes,
+    const AllToAllOptions& opts = {});
+
 TORCH_API c10::intrusive_ptr<Work> alltoall(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     at::TensorList output_tensors,

torch/csrc/distributed/c10d/OpsImpl.cpp

Lines changed: 38 additions & 0 deletions

@@ -399,6 +399,36 @@ c10::intrusive_ptr<Work> alltoall_cuda_(
       AllToAllOptions{std::chrono::milliseconds(timeout)});
 }
 
+c10::intrusive_ptr<Work> alltoall_base_cpu_(
+    at::Tensor& output,
+    at::Tensor& input,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    std::vector<int64_t> output_split_sizes,
+    std::vector<int64_t> input_split_sizes,
+    int64_t timeout) {
+  return process_group->alltoall_base(
+      output,
+      input,
+      output_split_sizes,
+      input_split_sizes,
+      AllToAllOptions{std::chrono::milliseconds(timeout)});
+}
+
+c10::intrusive_ptr<Work> alltoall_base_cuda_(
+    at::Tensor& output,
+    at::Tensor& input,
+    const c10::intrusive_ptr<ProcessGroup>& process_group,
+    std::vector<int64_t> output_split_sizes,
+    std::vector<int64_t> input_split_sizes,
+    int64_t timeout) {
+  return process_group->alltoall_base(
+      output,
+      input,
+      output_split_sizes,
+      input_split_sizes,
+      AllToAllOptions{std::chrono::milliseconds(timeout)});
+}
+
 c10::intrusive_ptr<Work> barrier_cpu(
     const c10::intrusive_ptr<ProcessGroup>& process_group,
     const std::vector<int64_t>& device_ids,

@@ -558,6 +588,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
   m.impl("alltoall_", alltoall_cuda_);
 }
 
+TORCH_LIBRARY_IMPL(c10d, CPU, m) {
+  m.impl("alltoall_base_", alltoall_base_cpu_);
+}
+
+TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
+  m.impl("alltoall_base_", alltoall_base_cuda_);
+}
+
 TORCH_LIBRARY_IMPL(c10d, CPU, m) {
   m.impl("barrier", barrier_cpu);
 }
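
The CPU and CUDA bodies are intentionally identical: both forward to the virtual ProcessGroup::alltoall_base, and the concrete backend (Gloo or NCCL) does the actual data movement. What the two TORCH_LIBRARY_IMPL blocks add is device-keyed routing, so the dispatcher selects alltoall_base_cpu_ or alltoall_base_cuda_ based on the device of the tensor arguments. A sketch of the CUDA path, assuming a CUDA build, at least one visible GPU, and a torchrun launch:

import torch
import torch.distributed as dist

# NCCL requires CUDA tensors; share GPUs across ranks if there are few.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

inp = torch.ones(2, 2, device="cuda")
out = torch.zeros(2, 2, device="cuda")
# CUDA inputs make the dispatcher pick the CUDA registration above.
dist.all_to_all_single(out, inp)
torch.cuda.synchronize()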

torch/csrc/distributed/c10d/init.cpp

Lines changed: 7 additions & 15 deletions

@@ -1440,34 +1440,26 @@ that adds a prefix to each key inserted to the store.
 
           .def(
               "alltoall_base",
-              &::c10d::ProcessGroup::alltoall_base,
-              py::arg("output_tensor"),
-              py::arg("input_tensor"),
-              py::arg("output_split_sizes"),
-              py::arg("input_split_sizes"),
-              py::arg("opts") = ::c10d::AllToAllOptions(),
-              py::call_guard<py::gil_scoped_release>())
-
-          .def(
-              "alltoall_base",
-              [](::c10d::ProcessGroup& self,
+              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                  at::Tensor& output,
                  at::Tensor& input,
                  std::vector<int64_t> outputSplitSizes,
-                 std::vector<int64_t> inputSplitSizes) {
-                return self.alltoall_base(
+                 std::vector<int64_t> inputSplitSizes,
+                 const ::c10d::AllToAllOptions& opts) {
+                return ::c10d::ops::alltoall_base(
+                    self,
                     output,
                     input,
                     outputSplitSizes,
                     inputSplitSizes,
-                    ::c10d::AllToAllOptions());
+                    opts);
               },
               py::arg("output"),
               py::arg("input"),
               py::arg("output_split_sizes"),
               py::arg("input_split_sizes"),
+              py::arg("opts") = ::c10d::AllToAllOptions(),
               py::call_guard<py::gil_scoped_release>())
-
           .def(
               "alltoall",
               [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,