
Commit 882d71a

H-Huang authored and pytorchmergebot committed
[1/N] [Dispatchable Collectives] Create Backend class (#83679)
### Changes:

- Create a new Backend class which contains collectives similar to those of https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/ProcessGroup.hpp.

#### Motivation

In future PRs, the existing ProcessGroupNCCL/Gloo/UCC will be migrated to derive from this Backend class. The idea is that we will repurpose ProcessGroup to instead contain a list of Backends (ProcessGroupNCCL/Gloo/UCC) and perform dispatching to them based on tensor type.

Differential Revision: [D38839213](https://our.internmc.facebook.com/intern/diff/D38839213)

Pull Request resolved: #83679

Approved by: https://github.com/kwen2501
1 parent 7834f55 commit 882d71a
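
To make the planned dispatching concrete, here is a hedged sketch of the idea described in the motivation. None of this code is from the PR: `Tensor`, `Backend`, and `ProcessGroup` below are simplified stand-ins for `at::Tensor`, `c10d::Backend`, and the repurposed `c10d::ProcessGroup`, and the `setBackend` registration method is invented for illustration.

```cpp
// Illustrative only: a process group that owns one backend per device
// type and routes each collective to the backend matching the tensors.
#include <map>
#include <memory>
#include <stdexcept>
#include <vector>

enum class DeviceType { CPU, CUDA };

struct Tensor { // stand-in for at::Tensor
  DeviceType device;
};

class Backend { // stand-in for c10d::Backend
 public:
  virtual ~Backend() = default;
  virtual void allreduce(std::vector<Tensor>& tensors) = 0;
};

class ProcessGroup { // stand-in for the repurposed c10d::ProcessGroup
 public:
  // Hypothetical registration hook: associate a backend with a device type.
  void setBackend(DeviceType device, std::shared_ptr<Backend> backend) {
    backends_[device] = std::move(backend);
  }

  // Dispatch: inspect the input tensors' device type and forward the
  // collective to the backend registered for that device type.
  void allreduce(std::vector<Tensor>& tensors) {
    auto it = backends_.find(tensors.at(0).device);
    if (it == backends_.end()) {
      throw std::runtime_error("no backend registered for this device type");
    }
    it->second->allreduce(tensors);
  }

 private:
  std::map<DeviceType, std::shared_ptr<Backend>> backends_;
};
```

Under this scheme a single ProcessGroup could hold, for example, a Gloo backend for CPU tensors and an NCCL backend for CUDA tensors, and callers would no longer need to pick a backend per collective.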

File tree

3 files changed: +294 -0 lines changed

- build_variables.bzl
- torch/csrc/distributed/c10d/Backend.cpp
- torch/csrc/distributed/c10d/Backend.hpp

build_variables.bzl (1 addition, 0 deletions)

```diff
@@ -459,6 +459,7 @@ libtorch_core_sources = sorted(
 
 # These files are the only ones that are supported on Windows.
 libtorch_distributed_base_sources = [
+    "torch/csrc/distributed/c10d/Backend.cpp",
     "torch/csrc/distributed/c10d/FileStore.cpp",
     "torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
     "torch/csrc/distributed/c10d/Ops.cpp",
```
torch/csrc/distributed/c10d/Backend.cpp (17 additions, 0 deletions)

```cpp
#include <c10/util/Logging.h>
#include <c10d/Backend.hpp>
#include <fmt/format.h>

namespace c10d {

Backend::Backend(int rank, int size) : rank_(rank), size_(size) {
  C10_LOG_API_USAGE_ONCE("c10d.backend");
}

Backend::~Backend() {}

void Backend::init() {
  C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
}

} // namespace c10d
```
torch/csrc/distributed/c10d/Backend.hpp (276 additions, 0 deletions)

```cpp
#pragma once

#include <condition_variable>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <c10d/ProcessGroup.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#include <c10d/debug.h>
#include <c10d/sequence_num.hpp>

constexpr auto kDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

// Options is a base struct that defines the basic options
// when constructing a Backend. Each Backend subclass should
// extend this struct and define its options if it wants to provide more
// config options (beyond basic ones defined here) to end user.
struct TORCH_API Options : torch::CustomClassHolder {
  explicit Options(
      std::string backend,
      std::chrono::milliseconds timeout = kDefaultTimeout)
      : timeout(timeout), backend(backend) {}
  virtual ~Options() = default;

  std::chrono::milliseconds timeout;

  // backend name
  const std::string backend;
};

class TORCH_API Backend : public torch::CustomClassHolder {
 public:
  explicit Backend(int rank, int size);
  virtual ~Backend() = 0;

  // Subclasses must override this method to return the backend name
  virtual const std::string getBackendName() const {
    TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
  };

  virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& /* tensors */,
      const BroadcastOptions& /* opts */ = BroadcastOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), "does not support broadcast"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), "does not support allreduce"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceCoalescedOptions& /* opts */ =
          AllreduceCoalescedOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            "does not support allreduce_coalesced"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& /* tensors */,
      const ReduceOptions& /* opts */ = ReduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), "does not support reduce"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), "does not support allgather"));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contigious collection of size inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in near future.
  virtual c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), "does not support _allgather_base"));
  }

  // This function is deprecated and will be moved out of Backend to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your Backend, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            "does not support allgather_coalesced"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const GatherOptions& /* opts */ = GatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), "does not support gather"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ScatterOptions& /* opts */ = ScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), "does not support scatter"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), "does not support reduce_scatter"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            "does not support _reduce_scatter_base"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      std::vector<int64_t>& /* outputSplitSizes */,
      std::vector<int64_t>& /* inputSplitSizes */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), "does not support alltoall_base"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllToAllOptions& opts = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), "does not support alltoall"));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */,
      bool /* unused */ = false) {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not support monitoredBarrier, only GLOO supports monitored barrier."));
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only implemented
  // for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& /* tensors */,
      int /* dstRank */,
      int /* tag */) {
    TORCH_CHECK(
        false, c10::str("Backend ", getBackendName(), "does not support send"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& /* tensors */,
      int /* srcRank */,
      int /* tag */) {
    TORCH_CHECK(
        false, c10::str("Backend ", getBackendName(), "does not support recv"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& /* tensors */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), "does not support recvAnysource"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& /* opts */ = BarrierOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), "does not support barrier"));
  }

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  // Optional sequence number structure for matching collectives.
  c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
  const int rank_;
  const int size_;
};

} // namespace c10d
```
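
To see how the interface above is meant to be consumed, here is a minimal sketch, not part of this commit, of a backend deriving from `c10d::Backend`. The `DummyBackend` name is invented for illustration: it overrides `getBackendName()` (which subclasses must provide) and calls the protected `init()` from its constructor, as the comment in the header instructs. Any collective it does not override falls through to the base class's `TORCH_CHECK` failure.

```cpp
// Hypothetical subclass, assuming the Backend.hpp shown above is on the
// include path as <c10d/Backend.hpp>.
#include <c10d/Backend.hpp>

namespace c10d {

class DummyBackend : public Backend {
 public:
  DummyBackend(int rank, int size) : Backend(rank, size) {
    // Triggers the "c10d.backend_dummy" API-usage log in Backend::init().
    init();
  }

  // Required hook: without this override, getBackendName() asserts.
  const std::string getBackendName() const override {
    return "dummy";
  }
};

} // namespace c10d

// Any collective DummyBackend does not override (allreduce, broadcast, ...)
// raises the base class's "does not support" error when called.
```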
