Skip to content

Commit df1df9d

Browse files
H-Huangpytorchmergebot
authored andcommitted
[16/N] Add _allgather_base custom op with CPU/CUDA implementation (#88889)
Differential Revision: [D41227739](https://our.internmc.facebook.com/intern/diff/D41227739) Pull Request resolved: #88889 Approved by: https://github.com/kwen2501
1 parent 3765621 commit df1df9d

File tree

5 files changed

+77
-1
lines changed

5 files changed

+77
-1
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2958,6 +2958,23 @@ def test_collectives(self):
29582958
def test_allreduce_coalesced(self):
29592959
self._test_allreduce_coalesced(backend="nccl")
29602960

2961+
@requires_nccl()
2962+
@skip_if_lt_x_gpu(1)
2963+
def test_allgather_base(self):
2964+
store = dist.FileStore(self.file_name, self.world_size)
2965+
dist.init_process_group(
2966+
"nccl",
2967+
world_size=self.world_size,
2968+
rank=self.rank,
2969+
store=store,
2970+
)
2971+
device = "cuda"
2972+
tensor = torch.ones(10, 10, device=torch.device(device))
2973+
output_tensor = torch.zeros(10, 10, device=torch.device(device))
2974+
dist.all_gather_into_tensor(output_tensor, tensor)
2975+
self.assertEqual(output_tensor, tensor)
2976+
2977+
29612978
if __name__ == "__main__":
29622979
assert (
29632980
not torch.cuda._initialized

torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ allgather_(
8888
output_tensors, work);
8989
}
9090

91+
c10::intrusive_ptr<Work> _allgather_base_(
92+
at::Tensor& output_tensor,
93+
at::Tensor& input_tensor,
94+
const c10::intrusive_ptr<ProcessGroup>& process_group) {
95+
return process_group->_allgather_base(output_tensor, input_tensor);
96+
}
97+
9198
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> reduce_scatter_(
9299
const std::vector<at::Tensor>& output_tensors,
93100
const std::vector<std::vector<at::Tensor>>& input_tensors,
@@ -197,6 +204,9 @@ TORCH_LIBRARY(c10d, m) {
197204
m.def(
198205
"allgather_",
199206
dispatch(c10::DispatchKey::CompositeExplicitAutograd, allgather_));
207+
m.def(
208+
"_allgather_base_",
209+
dispatch(c10::DispatchKey::CompositeExplicitAutograd, _allgather_base_));
200210
m.def(
201211
"reduce_scatter_",
202212
dispatch(c10::DispatchKey::CompositeExplicitAutograd, reduce_scatter_));
@@ -303,6 +313,21 @@ c10::intrusive_ptr<Work> allgather(
303313
output_tensors, input_tensors, process_group, opts.timeout.count()));
304314
}
305315

316+
c10::intrusive_ptr<Work> _allgather_base(
317+
const c10::intrusive_ptr<ProcessGroup>& process_group,
318+
at::Tensor& output_tensor,
319+
at::Tensor& input_tensor,
320+
const AllgatherOptions& opts) {
321+
static auto op = c10::Dispatcher::singleton()
322+
.findSchemaOrThrow("c10d::_allgather_base_", "")
323+
.typed<c10::intrusive_ptr<Work>(
324+
at::Tensor&,
325+
at::Tensor&,
326+
const c10::intrusive_ptr<::c10d::ProcessGroup>&)>();
327+
328+
return op.call(output_tensor, input_tensor, process_group);
329+
}
330+
306331
c10::intrusive_ptr<Work> reduce_scatter(
307332
const c10::intrusive_ptr<ProcessGroup>& process_group,
308333
const std::vector<at::Tensor>& output_tensors,

torch/csrc/distributed/c10d/Ops.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ TORCH_API c10::intrusive_ptr<Work> allgather(
3232
const std::vector<at::Tensor>& input_tensors,
3333
const AllgatherOptions& opts = {});
3434

35+
TORCH_API c10::intrusive_ptr<Work> _allgather_base(
36+
const c10::intrusive_ptr<ProcessGroup>& process_group,
37+
at::Tensor& outputTensor,
38+
at::Tensor& inputTensor,
39+
const AllgatherOptions& opts = {});
40+
3541
TORCH_API c10::intrusive_ptr<Work> reduce_scatter(
3642
const c10::intrusive_ptr<ProcessGroup>& process_group,
3743
const std::vector<at::Tensor>& output_tensors,

torch/csrc/distributed/c10d/OpsImpl.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,20 @@ allgather_cuda_(
211211
output_tensors, work);
212212
}
213213

214+
c10::intrusive_ptr<Work> _allgather_base_cpu_(
215+
at::Tensor& output_tensor,
216+
at::Tensor& input_tensor,
217+
const c10::intrusive_ptr<ProcessGroup>& process_group) {
218+
return process_group->_allgather_base(output_tensor, input_tensor);
219+
}
220+
221+
c10::intrusive_ptr<Work> _allgather_base_cuda_(
222+
at::Tensor& output_tensor,
223+
at::Tensor& input_tensor,
224+
const c10::intrusive_ptr<ProcessGroup>& process_group) {
225+
return process_group->_allgather_base(output_tensor, input_tensor);
226+
}
227+
214228
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>>
215229
reduce_scatter_cpu_(
216230
const std::vector<at::Tensor>& output_tensors,
@@ -409,6 +423,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
409423
m.impl("allgather_", allgather_cuda_);
410424
}
411425

426+
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
427+
m.impl("_allgather_base_", _allgather_base_cpu_);
428+
}
429+
430+
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
431+
m.impl("_allgather_base_", _allgather_base_cuda_);
432+
}
433+
412434
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
413435
m.impl("reduce_scatter_", reduce_scatter_cpu_);
414436
}

torch/csrc/distributed/c10d/init.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1187,7 +1187,13 @@ that adds a prefix to each key inserted to the store.
11871187

11881188
.def(
11891189
"_allgather_base",
1190-
&::c10d::ProcessGroup::_allgather_base,
1190+
[](const c10::intrusive_ptr<::c10d::ProcessGroup>& self,
1191+
at::Tensor& output_tensor,
1192+
at::Tensor& input_tensor,
1193+
const ::c10d::AllgatherOptions& opts) {
1194+
return ::c10d::ops::_allgather_base(
1195+
self, output_tensor, input_tensor, opts);
1196+
},
11911197
py::arg("output"),
11921198
py::arg("input"),
11931199
py::arg("opts") = ::c10d::AllgatherOptions(),

0 commit comments

Comments
 (0)