14 changes: 14 additions & 0 deletions test/distributed/test_c10d_common.py
@@ -1522,6 +1522,20 @@ def _test_allreduce_coalesced(self, backend):
        for tensor in tensors:
            self.assertEqual(tensor, torch.ones(10, 10) * self.world_size)

    def _test_all_to_all_single(self, backend):
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend,
            world_size=self.world_size,
            rank=self.rank,
            store=store,
        )
        device = "cuda" if backend == "nccl" else "cpu"
        # test alltoall_base
        input_tensor = torch.ones(2, 2, device=torch.device(device))
        output_tensor = torch.zeros(2, 2, device=torch.device(device))
        dist.all_to_all_single(output_tensor, input_tensor)

class CompilerTest(MultiProcessTestCase):
    def setUp(self):
        super(CompilerTest, self).setUp()
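Note on the new test: it only covers the even-split path (no split sizes passed). The same `all_to_all_single` call also accepts explicit per-rank split sizes, which go through the same `alltoall_base` path that this PR routes through the dispatcher. A minimal usage sketch follows (illustrative only, not part of this diff); it assumes a 2-rank process group is already initialized.

```python
# Illustrative only (not from this PR): uneven-split all_to_all_single.
# Assumes dist.init_process_group(...) has already run with world_size == 2.
import torch
import torch.distributed as dist

def uneven_all_to_all(rank: int) -> torch.Tensor:
    # splits[j] = number of elements this rank sends to / receives from rank j along dim 0
    splits = [1, 2] if rank == 0 else [2, 1]
    inp = torch.full((3,), float(rank))
    out = torch.empty(3)
    dist.all_to_all_single(out, inp, output_split_sizes=splits, input_split_sizes=splits)
    # rank 0 ends up with [0., 1., 1.]; rank 1 with [0., 0., 1.]
    return out
```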
4 changes: 4 additions & 0 deletions test/distributed/test_c10d_gloo.py
@@ -2417,6 +2417,10 @@ def test_collectives(self):
    def test_allreduce_coalesced(self):
        self._test_allreduce_coalesced(backend="gloo")

    @requires_gloo()
    def test_all_to_all_single(self):
        self._test_all_to_all_single(backend="gloo")

    @requires_gloo()
    def test_allgather_coalesced(self):
        store = dist.FileStore(self.file_name, self.world_size)
5 changes: 5 additions & 0 deletions test/distributed/test_c10d_nccl.py
@@ -2948,6 +2948,11 @@ def test_collectives(self):
    def test_allreduce_coalesced(self):
        self._test_allreduce_coalesced(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_all_to_all_single(self):
        self._test_all_to_all_single(backend="nccl")

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_allgather_base(self):
43 changes: 43 additions & 0 deletions torch/csrc/distributed/c10d/Ops.cpp
@@ -173,6 +173,21 @@ c10::intrusive_ptr<Work> alltoall_(
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> alltoall_base_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->alltoall_base(
      output,
      input,
      output_split_sizes,
      input_split_sizes,
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> barrier(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
@@ -271,6 +286,9 @@ TORCH_LIBRARY(c10d, m) {
  m.def(
      "alltoall_",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_));
  m.def(
      "alltoall_base_",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, alltoall_base_));
  m.def(
      "barrier",
      dispatch(c10::DispatchKey::CompositeExplicitAutograd, barrier));
@@ -523,6 +541,31 @@ c10::intrusive_ptr<Work> alltoall(
      output_tensors, input_tensors, process_group, opts.timeout.count());
}

c10::intrusive_ptr<Work> alltoall_base(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::Tensor& output,
    at::Tensor& input,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    const AllToAllOptions& opts) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("c10d::alltoall_base_", "")
                       .typed<c10::intrusive_ptr<::c10d::Work>(
                           at::Tensor&,
                           at::Tensor&,
                           const c10::intrusive_ptr<::c10d::ProcessGroup>&,
                           std::vector<int64_t>,
                           std::vector<int64_t>,
                           int64_t)>();
  return op.call(
      output,
      input,
      process_group,
      output_split_sizes,
      input_split_sizes,
      opts.timeout.count());
}

void monitored_barrier(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const BarrierOptions& opts,
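For readers following the op registration: the hunks above define the `c10d::alltoall_base_` schema under `CompositeExplicitAutograd` and add an `ops::alltoall_base` wrapper that resolves the schema with `findSchemaOrThrow` and calls through the dispatcher, mirroring the existing `alltoall_` plumbing. A hedged sketch of how one might confirm the registration from Python (assumption, not part of this PR; `_jit_get_schemas_for_operator` is a private helper):

```python
# Illustrative only: once TORCH_LIBRARY(c10d, ...) has registered "alltoall_base_",
# its schema should be discoverable from the Python side.
import torch
import torch.distributed  # ensure distributed support is present in this build

for schema in torch._C._jit_get_schemas_for_operator("c10d::alltoall_base_"):
    print(schema)
```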
8 changes: 8 additions & 0 deletions torch/csrc/distributed/c10d/Ops.hpp
@@ -73,6 +73,14 @@ TORCH_API c10::intrusive_ptr<Work> scatter(
    const std::vector<std::vector<at::Tensor>>& input_tensors,
    const ScatterOptions& opts = {});

TORCH_API c10::intrusive_ptr<Work> alltoall_base(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::Tensor& output,
    at::Tensor& input,
    const std::vector<int64_t> outputSplitSizes,
    const std::vector<int64_t> inputSplitSizes,
    const AllToAllOptions& opts = {});

TORCH_API c10::intrusive_ptr<Work> alltoall(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    at::TensorList output_tensors,
38 changes: 38 additions & 0 deletions torch/csrc/distributed/c10d/OpsImpl.cpp
@@ -399,6 +399,36 @@ c10::intrusive_ptr<Work> alltoall_cuda_(
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> alltoall_base_cpu_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->alltoall_base(
      output,
      input,
      output_split_sizes,
      input_split_sizes,
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> alltoall_base_cuda_(
    at::Tensor& output,
    at::Tensor& input,
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    std::vector<int64_t> output_split_sizes,
    std::vector<int64_t> input_split_sizes,
    int64_t timeout) {
  return process_group->alltoall_base(
      output,
      input,
      output_split_sizes,
      input_split_sizes,
      AllToAllOptions{std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> barrier_cpu(
    const c10::intrusive_ptr<ProcessGroup>& process_group,
    const std::vector<int64_t>& device_ids,
@@ -558,6 +588,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("alltoall_", alltoall_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("alltoall_base_", alltoall_base_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
  m.impl("alltoall_base_", alltoall_base_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
  m.impl("barrier", barrier_cpu);
}
22 changes: 7 additions & 15 deletions torch/csrc/distributed/c10d/init.cpp
@@ -1440,34 +1440,26 @@ that adds a prefix to each key inserted to the store.

          .def(
              "alltoall_base",
              &::c10d::ProcessGroup::alltoall_base,
              py::arg("output_tensor"),
              py::arg("input_tensor"),
              py::arg("output_split_sizes"),
              py::arg("input_split_sizes"),
              py::arg("opts") = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "alltoall_base",
              [](::c10d::ProcessGroup& self,
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
                 at::Tensor& output,
                 at::Tensor& input,
                 std::vector<int64_t> outputSplitSizes,
                 std::vector<int64_t> inputSplitSizes) {
                return self.alltoall_base(
                 std::vector<int64_t> inputSplitSizes,
                 const ::c10d::AllToAllOptions& opts) {
                return ::c10d::ops::alltoall_base(
                    self,
                    output,
                    input,
                    outputSplitSizes,
                    inputSplitSizes,
                    ::c10d::AllToAllOptions());
                    opts);
              },
              py::arg("output"),
              py::arg("input"),
              py::arg("output_split_sizes"),
              py::arg("input_split_sizes"),
              py::arg("opts") = ::c10d::AllToAllOptions(),
              py::call_guard<py::gil_scoped_release>())

          .def(
              "alltoall",
              [](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
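With the binding change above, the pybind lambda now takes the `intrusive_ptr` self plus an explicit `opts` argument and forwards to `::c10d::ops::alltoall_base`, so the Python-facing signature (`output`, `input`, `output_split_sizes`, `input_split_sizes`, `opts`) is unchanged. A minimal sketch of the direct `ProcessGroup.alltoall_base` call (illustrative only, not part of this diff; `_get_default_group` is a private helper, and a 2-rank group is assumed):

```python
# Illustrative only: calling alltoall_base directly on the ProcessGroup keeps
# the same signature after the binding change. Assumes the default group was
# initialized with world_size == 2.
import torch
import torch.distributed as dist

pg = dist.distributed_c10d._get_default_group()
inp = torch.ones(2, 2)
out = torch.zeros(2, 2)
work = pg.alltoall_base(out, inp, [], [])  # empty split lists => even split
work.wait()
```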