Skip to content

Commit 649e89a

Browse files
committed
Allow Process Group to support multiple backends and move PG specific implementations to backend class
ghstack-source-id: 58d6d12 Pull Request resolved: #88330
1 parent fddac50 commit 649e89a

20 files changed

+678
-199
lines changed

test/forward_backward_compatibility/check_forward_backward_compatibility.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,7 @@
285285
("aten::_flash_scaled_dot_product_attention", datetime.date(2022, 11, 1)),
286286
("aten::_scaled_dot_product_attention", datetime.date(2022, 11, 1)),
287287
# Distributed c10d ops are all going to be updated
288-
("c10d::.*", datetime.date(2022, 10, 31)),
289-
("c10d::allgather_", datetime.date(2022, 10, 1)),
288+
("c10d::.*", datetime.date(2022, 12, 31)),
290289
("aten::to_padded_tensor", datetime.date(2022, 10, 1)),
291290
("aten::nested_to_padded_tensor", datetime.date(2022, 10, 1)),
292291
("aten::nested_tensor", datetime.date(2022, 10, 15)),

torch/csrc/distributed/c10d/Backend.hpp

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#include <ATen/ATen.h>
1111
#include <c10/macros/Macros.h>
1212

13-
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
1413
#include <torch/csrc/distributed/c10d/Work.hpp>
1514
#include <torch/csrc/distributed/c10d/Types.hpp>
1615
#include <torch/csrc/distributed/c10d/Utils.hpp>
@@ -22,28 +21,28 @@ constexpr auto kDefaultTimeout =
2221

2322
namespace c10d {
2423

25-
// Options is a base struct that defines the basic options
26-
// when constructing a Backend. Each Backend subclass should
27-
// extend this struct and define its options if it wants to provide more
28-
// config options (beyond basic ones defined here) to end user.
29-
struct TORCH_API Options : torch::CustomClassHolder {
30-
explicit Options(
31-
std::string backend,
32-
std::chrono::milliseconds timeout = kDefaultTimeout)
33-
: timeout(timeout), backend(backend) {}
34-
virtual ~Options() = default;
35-
36-
std::chrono::milliseconds timeout;
37-
38-
// backend name
39-
const std::string backend;
40-
};
41-
4224
class TORCH_API Backend : public torch::CustomClassHolder {
4325
public:
4426
explicit Backend(int rank, int size);
4527
virtual ~Backend() = 0;
4628

29+
int getRank() const {
30+
return rank_;
31+
}
32+
33+
int getSize() const {
34+
return size_;
35+
}
36+
37+
virtual void startCoalescing() {
38+
// no-op for backends that have not implemented startCoalescing
39+
}
40+
41+
virtual void endCoalescing(
42+
std::vector<c10::intrusive_ptr<Work>>& /* reqs */) {
43+
// no-op for backends that have not implemented endCoalescing
44+
}
45+
4746
// Subclasses must override this method to return the backend name
4847
virtual const std::string getBackendName() const {
4948
TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
@@ -255,14 +254,6 @@ class TORCH_API Backend : public torch::CustomClassHolder {
255254
c10::str("Backend ", getBackendName(), "does not support barrier"));
256255
}
257256

258-
int getRank() const {
259-
return rank_;
260-
}
261-
262-
int getSize() const {
263-
return size_;
264-
}
265-
266257
protected:
267258
// Implementations of this interface need to call this to setup
268259
// appropriate logging etc.
@@ -272,6 +263,9 @@ class TORCH_API Backend : public torch::CustomClassHolder {
272263
c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
273264
const int rank_;
274265
const int size_;
266+
// Debug level setting. It is parsed once when ProcessGroup is constructed and
267+
// remains the same across use of this process group.
268+
DebugLevel dist_debug_level_;
275269
};
276270

277271
} // namespace c10d

torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ c10::intrusive_ptr<Work> reduce_(
6060
std::tuple<std::vector<std::vector<at::Tensor>>, c10::intrusive_ptr<Work>>
6161
allgather_(
6262
const std::vector<std::vector<at::Tensor>>& output_tensors,
63-
const std::vector<at::Tensor>& input_tensors,
63+
at::TensorList input_tensors,
6464
const c10::intrusive_ptr<ProcessGroup>& process_group,
6565
int64_t timeout) {
66+
auto input_tensors_vec = input_tensors.vec();
6667
auto work = process_group->allgather(
6768
const_cast<std::vector<std::vector<at::Tensor>>&>(output_tensors),
68-
const_cast<std::vector<at::Tensor>&>(input_tensors),
69+
input_tensors_vec,
6970
AllgatherOptions{std::chrono::milliseconds(timeout)});
7071

7172
// Copy output tensors (not storage) so that this can be used in a functional
@@ -132,6 +133,7 @@ c10::intrusive_ptr<Work> alltoall_(
132133
}
133134

134135
c10::intrusive_ptr<Work> barrier(
136+
at::Tensor /* unused */,
135137
const c10::intrusive_ptr<ProcessGroup>& process_group,
136138
const std::vector<int64_t>& device_ids,
137139
int64_t timeout) {
@@ -252,17 +254,18 @@ c10::intrusive_ptr<Work> allreduce(
252254
c10::intrusive_ptr<Work> allgather(
253255
const c10::intrusive_ptr<ProcessGroup>& process_group,
254256
const std::vector<std::vector<at::Tensor>>& output_tensors,
255-
const std::vector<at::Tensor>& input_tensors,
257+
at::TensorList input_tensors,
256258
const AllgatherOptions& opts) {
257259
static auto op = c10::Dispatcher::singleton()
258260
.findSchemaOrThrow("c10d::allgather_", "")
259261
.typed<std::tuple<
260262
std::vector<std::vector<at::Tensor>>,
261263
c10::intrusive_ptr<Work>>(
262264
const std::vector<std::vector<at::Tensor>>&,
263-
const std::vector<at::Tensor>&,
265+
at::TensorList,
264266
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
265267
int64_t)>();
268+
266269
return std::get<1>(op.call(
267270
output_tensors, input_tensors, process_group, opts.timeout.count()));
268271
}
@@ -376,10 +379,22 @@ c10::intrusive_ptr<Work> barrier(
376379
static auto op = c10::Dispatcher::singleton()
377380
.findSchemaOrThrow("c10d::barrier", "")
378381
.typed<c10::intrusive_ptr<::c10d::Work>(
382+
at::Tensor,
379383
const c10::intrusive_ptr<::c10d::ProcessGroup>&,
380384
const std::vector<int64_t>&,
381385
int64_t)>();
382-
return op.call(process_group, opts.device_ids, opts.timeout.count());
386+
387+
// Default to using cpu implementation
388+
at::Tensor tensor = at::empty({0}, at::TensorOptions().device(at::kCPU));
389+
// if opts.device_ids or backend is nccl are specified then use cuda
390+
// implementation
391+
// TODO: getBackendName() is always "NOT DEFINED"
392+
if (opts.device_ids.size() > 0 || process_group->getBackendName() == "nccl") {
393+
// set cuda tensor
394+
tensor = at::empty(
395+
{0}, at::TensorOptions().device(at::kCUDA, opts.device_ids[0]));
396+
}
397+
return op.call(tensor, process_group, opts.device_ids, opts.timeout.count());
383398
}
384399

385400
c10::intrusive_ptr<Work> send(

torch/csrc/distributed/c10d/Ops.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ TORCH_API c10::intrusive_ptr<Work> allreduce(
2424
TORCH_API c10::intrusive_ptr<Work> allgather(
2525
const c10::intrusive_ptr<ProcessGroup>& process_group,
2626
const std::vector<std::vector<at::Tensor>>& output_tensors,
27-
const std::vector<at::Tensor>& input_tensors,
27+
at::TensorList input_tensors,
2828
const AllgatherOptions& opts = {});
2929

3030
TORCH_API c10::intrusive_ptr<Work> reduce_scatter(

0 commit comments

Comments
 (0)