Skip to content

Commit 3a3500f

Browse files
H-Huang authored and pytorchmergebot committed
[13/N] Update gather with CPU/CUDA implementations (#86409)
Differential Revision: [D40181612](https://our.internmc.facebook.com/intern/diff/D40181612)
Pull Request resolved: #86409
Approved by: https://github.com/kwen2501
1 parent 1af9b38 commit 3a3500f

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

test/distributed/test_c10d_common.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1466,7 +1466,7 @@ def _call_collective_with_varying_tensors(self, backend, collective, *args):
14661466
# multi tensor collectives
14671467
if collective == dist.barrier:
14681468
collective()
1469-
elif collective == dist.all_gather:
1469+
elif collective in (dist.all_gather, dist.gather):
14701470
collective([tensor], tensor, *args)
14711471
elif collective == dist.scatter:
14721472
collective(tensor, [tensor], *args)

torch/csrc/distributed/c10d/OpsImpl.cpp

Lines changed: 32 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -219,6 +219,30 @@ reduce_scatter_cuda_(
219219
output_tensors, work);
220220
}
221221

222+
// CPU implementation of the c10d "gather_" op; it is registered for the CPU
// dispatch key by the TORCH_LIBRARY_IMPL(c10d, CPU, m) block in this file.
//
// Forwards output_tensors and input_tensors to process_group->gather(),
// building GatherOptions from root_rank and from timeout wrapped in
// std::chrono::milliseconds, and returns whatever Work handle gather() yields.
//
// NOTE(review): the const_casts presumably exist because
// ProcessGroup::gather takes non-const references -- confirm against its
// declaration before removing them.
c10::intrusive_ptr<Work> gather_cpu_(
223+
const std::vector<std::vector<at::Tensor>>& output_tensors,
224+
const std::vector<at::Tensor>& input_tensors,
225+
const c10::intrusive_ptr<ProcessGroup>& process_group,
226+
int64_t root_rank,
227+
int64_t timeout) {
228+
return process_group->gather(
229+
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
230+
const_cast<std::vector<at::Tensor>&>(input_tensors),
231+
GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
232+
}
233+
234+
// CUDA implementation of the c10d "gather_" op; it is registered for the CUDA
// dispatch key by the TORCH_LIBRARY_IMPL(c10d, CUDA, m) block in this file.
//
// The body is the same as gather_cpu_: it forwards output_tensors and
// input_tensors to process_group->gather() with GatherOptions built from
// root_rank and timeout (wrapped in std::chrono::milliseconds), and returns
// the Work handle that gather() yields.
//
// NOTE(review): the const_casts presumably exist because
// ProcessGroup::gather takes non-const references -- confirm against its
// declaration before removing them.
c10::intrusive_ptr<Work> gather_cuda_(
235+
const std::vector<std::vector<at::Tensor>>& output_tensors,
236+
const std::vector<at::Tensor>& input_tensors,
237+
const c10::intrusive_ptr<ProcessGroup>& process_group,
238+
int64_t root_rank,
239+
int64_t timeout) {
240+
return process_group->gather(
241+
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
242+
const_cast<std::vector<at::Tensor>&>(input_tensors),
243+
GatherOptions{root_rank, std::chrono::milliseconds(timeout)});
244+
}
245+
222246
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> scatter_cpu_(
223247
const std::vector<at::Tensor>& output_tensors,
224248
const std::vector<std::vector<at::Tensor>>& input_tensors,
@@ -359,6 +383,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
359383
m.impl("reduce_scatter_", reduce_scatter_cuda_);
360384
}
361385

386+
// Registers gather_cpu_ as the implementation of the c10d op "gather_" for
// the CPU dispatch key.
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
387+
m.impl("gather_", gather_cpu_);
388+
}
389+
390+
// Registers gather_cuda_ as the implementation of the c10d op "gather_" for
// the CUDA dispatch key.
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
391+
m.impl("gather_", gather_cuda_);
392+
}
393+
362394
// Registers scatter_cpu_ as the implementation of the c10d op "scatter_" for
// the CPU dispatch key.
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
363395
m.impl("scatter_", scatter_cpu_);
364396
}

0 commit comments

Comments
 (0)