Skip to content

Commit d39e9c1

Browse files
H-Huangpytorchmergebot
authored andcommitted
[6/N] [Dispatchable Collectives] Update recv with CPU / CUDA implementations (#83876)
* ### Changes - Updates for the recv collective ### Context #86225 Differential Revision: [D40044552](https://our.internmc.facebook.com/intern/diff/D40044552) Pull Request resolved: #83876 Approved by: https://github.com/kwen2501
1 parent d447eff commit d39e9c1

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

torch/csrc/distributed/c10d/OpsImpl.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,26 @@ c10::intrusive_ptr<Work> send_cuda(
2929
tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
3030
}
3131

32+
c10::intrusive_ptr<Work> recv_cpu_(
33+
at::TensorList tensors,
34+
const c10::intrusive_ptr<ProcessGroup>& process_group,
35+
int64_t srcRank,
36+
int64_t tag) {
37+
auto tensor_vec = tensors.vec();
38+
return process_group->recv(
39+
tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
40+
}
41+
42+
c10::intrusive_ptr<Work> recv_cuda_(
43+
at::TensorList tensors,
44+
const c10::intrusive_ptr<ProcessGroup>& process_group,
45+
int64_t srcRank,
46+
int64_t tag) {
47+
auto tensor_vec = tensors.vec();
48+
return process_group->recv(
49+
tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
50+
}
51+
3252
std::tuple<std::vector<at::Tensor>, c10::intrusive_ptr<Work>> broadcast_cpu_(
3353
at::TensorList tensors,
3454
const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -105,6 +125,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
105125
m.impl("send", send_cuda);
106126
}
107127

128+
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
129+
m.impl("recv_", recv_cpu_);
130+
}
131+
132+
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
133+
m.impl("recv_", recv_cuda_);
134+
}
135+
108136
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
109137
m.impl("broadcast_", broadcast_cpu_);
110138
}

0 commit comments

Comments
 (0)