1 change: 1 addition & 0 deletions build_variables.bzl
@@ -459,6 +459,7 @@ libtorch_core_sources = sorted(

# These files are the only ones that are supported on Windows.
libtorch_distributed_base_sources = [
"torch/csrc/distributed/c10d/Backend.cpp",
"torch/csrc/distributed/c10d/FileStore.cpp",
"torch/csrc/distributed/c10d/GlooDeviceFactory.cpp",
"torch/csrc/distributed/c10d/Ops.cpp",
17 changes: 17 additions & 0 deletions torch/csrc/distributed/c10d/Backend.cpp
@@ -0,0 +1,17 @@
#include <c10/util/Logging.h>
#include <c10d/Backend.hpp>
#include <fmt/format.h>

namespace c10d {

Backend::Backend(int rank, int size) : rank_(rank), size_(size) {
C10_LOG_API_USAGE_ONCE("c10d.backend");
}

Backend::~Backend() {}

void Backend::init() {
C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
}

} // namespace c10d
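
For reference, the intent is that a concrete backend calls the protected init() from its own constructor so that the per-backend C10_LOG_API_USAGE_ONCE entry above gets recorded. A minimal sketch under that assumption; the DummyBackend class and its "dummy" name are hypothetical and not part of this PR:

#include <c10d/Backend.hpp>

// Hypothetical backend, used only to illustrate the logging hooks above.
class DummyBackend : public c10d::Backend {
 public:
  DummyBackend(int rank, int size) : Backend(rank, size) {
    // By this point the base constructor has already logged "c10d.backend";
    // init() additionally logs "c10d.backend_dummy" once per process.
    init();
  }

  const std::string getBackendName() const override {
    return "dummy";
  }
};

Any collective left unoverridden falls back to the TORCH_CHECK(false, ...) stubs declared in Backend.hpp below, so a backend like this fails loudly rather than silently.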
276 changes: 276 additions & 0 deletions torch/csrc/distributed/c10d/Backend.hpp
@@ -0,0 +1,276 @@
#pragma once

#include <condition_variable>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <c10d/ProcessGroup.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#include <c10d/debug.h>
#include <c10d/sequence_num.hpp>

constexpr auto kDefaultTimeout =
std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

// Options is a base struct that defines the basic options
// when constructing a Backend. Each Backend subclass should
// extend this struct and define its options if it wants to provide more
// config options (beyond basic ones defined here) to end user.
struct TORCH_API Options : torch::CustomClassHolder {
explicit Options(
std::string backend,
std::chrono::milliseconds timeout = kDefaultTimeout)
: timeout(timeout), backend(backend) {}
virtual ~Options() = default;

std::chrono::milliseconds timeout;

// backend name
const std::string backend;
};

class TORCH_API Backend : public torch::CustomClassHolder {
public:
explicit Backend(int rank, int size);
virtual ~Backend() = 0;

// Subclasses must override this method to return the backend name
virtual const std::string getBackendName() const {
TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
}

virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& /* tensors */,
const BroadcastOptions& /* opts */ = BroadcastOptions()) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not support broadcast"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& /* tensors */,
const AllreduceOptions& /* opts */ = AllreduceOptions()) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not support allreduce"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& /* tensors */,
const AllreduceCoalescedOptions& /* opts */ =
AllreduceCoalescedOptions()) {
TORCH_CHECK(
false,
c10::str(
"Backend ",
getBackendName(),
"does not support allreduce_coalesced"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& /* tensors */,
const ReduceOptions& /* opts */ = ReduceOptions()) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not support reduce"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& /* outputTensors */,
std::vector<at::Tensor>& /* inputTensors */,
const AllgatherOptions& /* opts */ = AllgatherOptions()) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not support allgather"));
}

// Gathers a single tensor inputBuffer into a single buffer outputBuffer that
// is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
// For implementers of ProcessGroup API and advanced users only.
// Note: this function will be deprecated in the near future.
virtual c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
at::Tensor& /* outputBuffer */,
at::Tensor& /* inputBuffer */,
const AllgatherOptions& /* opts */ = AllgatherOptions()) {
TORCH_CHECK(
false,
c10::str(
"Backend ", getBackendName(), "does not support _allgather_base"));
}

// This function is deprecated and will be moved out of Backend to comms:
// * do not add dependencies on this function,
// * do not implement it in your Backend, implement _allgather_base
// instead.
virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
std::vector<at::Tensor>& /* inputTensors */,
const AllgatherOptions& /* opts */ = AllgatherOptions()) {
TORCH_CHECK(
false,
c10::str(
"Backend ",
getBackendName(),
"does not support allgather_coalesced"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& /* outputTensors */,
std::vector<at::Tensor>& /* inputTensors */,
const GatherOptions& /* opts */ = GatherOptions()) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not support gather"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& /* outputTensors */,
std::vector<std::vector<at::Tensor>>& /* inputTensors */,
const ScatterOptions& /* opts */ = ScatterOptions()) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not support scatter"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& /* outputTensors */,
std::vector<std::vector<at::Tensor>>& /* inputTensors */,
const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
TORCH_CHECK(
false,
c10::str(
"Backend ", getBackendName(), "does not support reduce_scatter"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base(
at::Tensor& /* outputBuffer */,
at::Tensor& /* inputBuffer */,
const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
TORCH_CHECK(
false,
c10::str(
"Backend ",
getBackendName(),
"does not support _reduce_scatter_base"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
at::Tensor& /* outputBuffer */,
at::Tensor& /* inputBuffer */,
std::vector<int64_t>& /* outputSplitSizes */,
std::vector<int64_t>& /* inputSplitSizes */,
const AllToAllOptions& /* opts */ = AllToAllOptions()) {
TORCH_CHECK(
false,
c10::str(
"Backend ", getBackendName(), "does not support alltoall_base"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall(
std::vector<at::Tensor>& /* outputTensors */,
std::vector<at::Tensor>& /* inputTensors */,
const AllToAllOptions& /* opts */ = AllToAllOptions()) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not support alltoall"));
}

virtual void monitoredBarrier(
const BarrierOptions& /* unused */,
bool /* unused */ = false) {
auto backendName = getBackendName();
TORCH_CHECK(
false,
c10::str(
"Backend ",
backendName,
" does not support monitoredBarrier, only GLOO supports monitored barrier."));
}

// Agrees on an initial sequence number for the whole group by having rank 0
// create it and broadcast it to other ranks using the store. Only implemented
// for GLOO and NCCL backends currently.
virtual void setSequenceNumberForGroup() {
auto backendName = getBackendName();
TORCH_CHECK(
false,
c10::str(
"Backend ",
backendName,
" does not yet support sequence numbers."));
}
Collaborator: Just for our record: we'd need to think about whether the SequenceNumber-related methods should stay here or be moved to ProcessGroup. For the purpose of a smooth refactor, they can stay here for now.

Member (Author): Makes sense, agree on this!

// Retrieves the current sequence number for the whole group, which should be
// in sync. If the returned number is not consistent across the group, it
// may indicate that there is some sort of collective desynchronization.
virtual uint64_t getSequenceNumberForGroup() {
auto backendName = getBackendName();
TORCH_CHECK(
false,
c10::str(
"Backend ",
backendName,
" does not yet support sequence numbers."));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& /* tensors */,
int /* dstRank */,
int /* tag */) {
TORCH_CHECK(
false, c10::str("Backend ", getBackendName(), "does not support send"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& /* tensors */,
int /* srcRank */,
int /* tag */) {
TORCH_CHECK(
false, c10::str("Backend ", getBackendName(), "does not support recv"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& /* tensors */,
int /* tag */) {
TORCH_CHECK(
false,
c10::str(
"Backend ", getBackendName(), "does not support recvAnysource"));
}

virtual c10::intrusive_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& /* opts */ = BarrierOptions()) {
TORCH_CHECK(
false,
c10::str("Backend ", getBackendName(), "does not support barrier"));
}

int getRank() const {
return rank_;
}

int getSize() const {
return size_;
}

protected:
// Implementations of this interface need to call this to set up
// appropriate logging, etc.
void init();

// Optional sequence number structure for matching collectives.
c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
const int rank_;
const int size_;
};

} // namespace c10d
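
As a closing illustration of how the Options base struct above is meant to be extended per backend, here is a hedged sketch of a backend-specific options type; the DummyOptions name and its use_fast_path field are hypothetical and not taken from this PR:

#include <c10d/Backend.hpp>

// Hypothetical per-backend options, extending the base struct defined above.
struct DummyOptions : c10d::Options {
  explicit DummyOptions(std::chrono::milliseconds timeout = kDefaultTimeout)
      : c10d::Options("dummy", timeout) {}

  // Extra knob beyond the timeout/backend fields inherited from Options.
  bool use_fast_path = true;
};

// Since Options derives from torch::CustomClassHolder, construction would
// typically go through c10::make_intrusive:
//   auto opts = c10::make_intrusive<DummyOptions>();
//   opts->timeout = std::chrono::seconds(60);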