Skip to content

Commit 23a5fd0

Browse files
committed
[6/N] [Dispatchable Collectives] Update recv with CPU / CUDA implementations
ghstack-source-id: 6c91a29 Pull Request resolved: #83876
1 parent 92fafef commit 23a5fd0

File tree

4 files changed

+43
-11
lines changed

4 files changed

+43
-11
lines changed

test/distributed/test_c10d_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,6 +1358,7 @@ def _test_collectives(self, backend):
13581358
)
13591359
collectives_and_args = [
13601360
(dist.send, self.rank),
1361+
(dist.recv,),
13611362
(dist.broadcast, self.rank),
13621363
(dist.all_reduce,)
13631364
]

torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,6 @@ c10::intrusive_ptr<Work> barrier(
9393
BarrierOptions{device_ids, std::chrono::milliseconds(timeout)});
9494
}
9595

96-
c10::intrusive_ptr<Work> recv_(
97-
at::TensorList tensors,
98-
const c10::intrusive_ptr<ProcessGroup>& process_group,
99-
int64_t srcRank,
100-
int64_t tag) {
101-
auto tensor_vec = tensors.vec();
102-
return process_group->recv(
103-
tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
104-
}
105-
10696
TORCH_LIBRARY(c10d, m) {
10797
// The following ProcessGroup and Work definitions are more like declarations.
10898
// They don't expose the details of the two classes into TorchScript.
@@ -113,6 +103,8 @@ TORCH_LIBRARY(c10d, m) {
113103
// __torch_dispatch__.
114104
m.def(
115105
"send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int dstRank, int tag) -> __torch__.torch.classes.c10d.Work");
106+
m.def(
107+
"recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int srcRank, int tag) -> __torch__.torch.classes.c10d.Work");
116108
m.def(
117109
"broadcast_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, int root_rank, int root_tensor, int timeout) -> __torch__.torch.classes.c10d.Work");
118110
m.def(
@@ -138,7 +130,6 @@ TORCH_LIBRARY(c10d, m) {
138130
m.def(
139131
"barrier",
140132
dispatch(c10::DispatchKey::CompositeExplicitAutograd, barrier));
141-
m.def("recv_", dispatch(c10::DispatchKey::CompositeExplicitAutograd, recv_));
142133
}
143134
} // namespace
144135

torch/csrc/distributed/c10d/OpsImpl.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,26 @@ c10::intrusive_ptr<Work> send_cuda(
2727
tensor_vec, static_cast<int>(dstRank), static_cast<int>(tag));
2828
}
2929

30+
// CPU implementation of the c10d "recv_" op (registered under the CPU
// dispatch key via TORCH_LIBRARY_IMPL in this file).
//
// Copies the incoming tensor list into a local container with `tensors.vec()`
// and forwards it to `process_group->recv(...)` together with the source rank
// and tag. `srcRank` and `tag` arrive as int64_t and are narrowed to int with
// static_cast; no range check is performed here.
//
// Returns whatever Work handle `process_group->recv` returns.
c10::intrusive_ptr<Work> recv_cpu_(
31+
at::TensorList tensors,
32+
const c10::intrusive_ptr<ProcessGroup>& process_group,
33+
int64_t srcRank,
34+
int64_t tag) {
35+
auto tensor_vec = tensors.vec();
36+
return process_group->recv(
37+
tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
38+
}
39+
40+
// CUDA implementation of the c10d "recv_" op (registered under the CUDA
// dispatch key via TORCH_LIBRARY_IMPL in this file).
//
// The body is identical to recv_cpu_: it copies the tensor list with
// `tensors.vec()` and forwards it to `process_group->recv(...)` with
// `srcRank` and `tag` narrowed from int64_t to int via static_cast (no range
// check is performed here).
//
// Returns whatever Work handle `process_group->recv` returns.
c10::intrusive_ptr<Work> recv_cuda_(
41+
at::TensorList tensors,
42+
const c10::intrusive_ptr<ProcessGroup>& process_group,
43+
int64_t srcRank,
44+
int64_t tag) {
45+
auto tensor_vec = tensors.vec();
46+
return process_group->recv(
47+
tensor_vec, static_cast<int>(srcRank), static_cast<int>(tag));
48+
}
49+
3050
c10::intrusive_ptr<Work> broadcast_cpu_(
3151
at::TensorList tensors,
3252
const c10::intrusive_ptr<ProcessGroup>& process_group,
@@ -89,6 +109,14 @@ TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
89109
m.impl("send", send_cuda);
90110
}
91111

112+
// Registers recv_cpu_ as the CPU-dispatch-key implementation of the
// "recv_" op in the c10d library.
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
113+
m.impl("recv_", recv_cpu_);
114+
}
115+
116+
// Registers recv_cuda_ as the CUDA-dispatch-key implementation of the
// "recv_" op in the c10d library.
TORCH_LIBRARY_IMPL(c10d, CUDA, m) {
117+
m.impl("recv_", recv_cuda_);
118+
}
119+
92120
TORCH_LIBRARY_IMPL(c10d, CPU, m) {
93121
m.impl("broadcast_", broadcast_cpu_);
94122
}

torch/csrc/distributed/c10d/OpsImpl.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,18 @@ c10::intrusive_ptr<Work> send_cuda(
1919
int64_t dstRank,
2020
int64_t tag);
2121

22+
// Declaration of the CPU implementation of the c10d "recv_" op.
// Takes the tensor list to receive into, the process group to call recv on,
// the source rank, and a tag; returns a Work handle.
c10::intrusive_ptr<Work> recv_cpu_(
23+
at::TensorList tensors,
24+
const c10::intrusive_ptr<ProcessGroup>& process_group,
25+
int64_t srcRank,
26+
int64_t tag);
27+
28+
// Declaration of the CUDA implementation of the c10d "recv_" op.
// Same parameter list and return type as recv_cpu_.
c10::intrusive_ptr<Work> recv_cuda_(
29+
at::TensorList tensors,
30+
const c10::intrusive_ptr<ProcessGroup>& process_group,
31+
int64_t srcRank,
32+
int64_t tag);
33+
2234
c10::intrusive_ptr<Work> broadcast_cpu_(
2335
at::TensorList tensors,
2436
const c10::intrusive_ptr<ProcessGroup>& process_group,

0 commit comments

Comments
 (0)