Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/distributed/test_c10d_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,7 +1466,7 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args):
# multi tensor collectives
if collective == dist.barrier:
collective()
elif collective == dist.all_gather:
elif collective in (dist.all_gather, dist.gather):
collective([tensor], tensor, *args)
elif collective == dist.scatter:
collective(tensor, [tensor], *args)
Expand Down
32 changes: 32 additions & 0 deletions torch/csrc/distributed/c10d/OpsImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,30 @@ reduce_scatter_cuda_(
output_tensors, work);
}

c10::intrusive_ptr<Work> gather_cpu_(
const std::vector<std::vector<at::Tensor>>& output_tensors,
const std::vector<at::Tensor>& input_tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t root_rank,
int64_t timeout) {
return process_group->gather(
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
const_cast<std::vector<at::Tensor>&>(input_tensors),
GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
}

c10::intrusive_ptr<Work> gather_cuda_(
const std::vector<std::vector<at::Tensor>>& output_tensors,
const std::vector<at::Tensor>& input_tensors,
const c10::intrusive_ptr<ProcessGroup>& process_group,
int64_t root_rank,
int64_t timeout) {
return process_group->gather(
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
const_cast<std::vector<at::Tensor>&>(input_tensors),
GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
}

std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cpu_(
const std::vector<at::Tensor>& output_tensors,
const std::vector<std::vector<at::Tensor>>& input_tensors,
Expand Down Expand Up @@ -359,6 +383,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("reduce_scatter_", reduce_scatter_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("gather_", gather_cpu_);
}

TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
m.impl("gather_", gather_cuda_);
}

TORCH_LIBRARY_IMPL(c10d, CPU, m) {
m.impl("scatter_", scatter_cpu_);
}
Expand Down