[1/N] [Dispatchable Collectives] Create Backend class #83679
Closed
Commits (4):
3bf089d  [1/N] [Dispatchable Collectives] Create Backend class (H-Huang)
e26a7d4  Update on "[1/N] [Dispatchable Collectives] Create Backend class" (H-Huang)
34e6900  Update on "[1/N] [Dispatchable Collectives] Create Backend class" (H-Huang)
386bb80  Update on "[1/N] [Dispatchable Collectives] Create Backend class" (H-Huang)
Backend.cpp (new file):

@@ -0,0 +1,17 @@
#include <c10/util/Logging.h>
#include <c10d/Backend.hpp>
#include <fmt/format.h>

namespace c10d {

Backend::Backend(int rank, int size) : rank_(rank), size_(size) {
  C10_LOG_API_USAGE_ONCE("c10d.backend");
}

Backend::~Backend() {}

void Backend::init() {
  C10_LOG_API_USAGE_ONCE(fmt::format("c10d.backend_{}", getBackendName()));
}

} // namespace c10d
Backend.hpp (new file):

@@ -0,0 +1,276 @@
#pragma once

#include <condition_variable>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <c10d/ProcessGroup.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
#include <c10d/debug.h>
#include <c10d/sequence_num.hpp>

constexpr auto kDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

// Options is a base struct that defines the basic options
// when constructing a Backend. Each Backend subclass should
// extend this struct and define its own options if it wants to provide more
// config options (beyond the basic ones defined here) to the end user.
struct TORCH_API Options : torch::CustomClassHolder {
  explicit Options(
      std::string backend,
      std::chrono::milliseconds timeout = kDefaultTimeout)
      : timeout(timeout), backend(backend) {}
  virtual ~Options() = default;

  std::chrono::milliseconds timeout;

  // backend name
  const std::string backend;
};

class TORCH_API Backend : public torch::CustomClassHolder {
 public:
  explicit Backend(int rank, int size);
  virtual ~Backend() = 0;

  // Subclasses must override this method to return the backend name
  virtual const std::string getBackendName() const {
    TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> broadcast(
      std::vector<at::Tensor>& /* tensors */,
      const BroadcastOptions& /* opts */ = BroadcastOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support broadcast"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allreduce"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceCoalescedOptions& /* opts */ =
          AllreduceCoalescedOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce_coalesced"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> reduce(
      std::vector<at::Tensor>& /* tensors */,
      const ReduceOptions& /* opts */ = ReduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support reduce"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> allgather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allgather"));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of the ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
  virtual c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support _allgather_base"));
  }

  // This function is deprecated and will be moved out of Backend to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your Backend, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<ProcessGroup::Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_coalesced"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> gather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const GatherOptions& /* opts */ = GatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support gather"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ScatterOptions& /* opts */ = ScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support scatter"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support reduce_scatter"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> _reduce_scatter_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support _reduce_scatter_base"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      std::vector<int64_t>& /* outputSplitSizes */,
      std::vector<int64_t>& /* inputSplitSizes */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support alltoall_base"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> alltoall(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support alltoall"));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */,
      bool /* unused */ = false) {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not support monitoredBarrier, only GLOO supports monitored barrier."));
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to the other ranks using the store. Only
  // implemented for the GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> send(
      std::vector<at::Tensor>& /* tensors */,
      int /* dstRank */,
      int /* tag */) {
    TORCH_CHECK(
        false, c10::str("Backend ", getBackendName(), " does not support send"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> recv(
      std::vector<at::Tensor>& /* tensors */,
      int /* srcRank */,
      int /* tag */) {
    TORCH_CHECK(
        false, c10::str("Backend ", getBackendName(), " does not support recv"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
      std::vector<at::Tensor>& /* tensors */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support recvAnysource"));
  }

  virtual c10::intrusive_ptr<ProcessGroup::Work> barrier(
      const BarrierOptions& /* opts */ = BarrierOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support barrier"));
  }

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

 protected:
  // Implementations of this interface need to call this to set up
  // appropriate logging etc.
  void init();

  // Optional sequence number structure for matching collectives.
  c10::optional<c10d::SequenceNum> sequenceNum_ = c10::nullopt;
  const int rank_;
  const int size_;
};

} // namespace c10d
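For reference, here is a minimal sketch of how a concrete backend is expected to plug into this interface, assuming only the declarations shown in the header above (the Backend constructor taking rank and size, the protected init() hook, the overridable getBackendName(), and the extensible Options struct). The names DummyBackend, DummyBackendOptions, the "dummy" backend string, and the verboseLogging field are hypothetical illustrations and are not part of this PR.

// Hypothetical example (not part of this PR): a minimal backend that
// derives from c10d::Backend and extends c10d::Options.
#include <c10d/Backend.hpp>

namespace c10d {

// Backend-specific options extend the base Options struct.
struct DummyBackendOptions : Options {
  explicit DummyBackendOptions(
      std::chrono::milliseconds timeout = kDefaultTimeout)
      : Options("dummy", timeout) {}
  bool verboseLogging = false; // hypothetical backend-specific setting
};

class DummyBackend : public Backend {
 public:
  DummyBackend(int rank, int size) : Backend(rank, size) {
    // Call init() so the per-backend API-usage event
    // ("c10d.backend_dummy") is logged once.
    init();
  }

  const std::string getBackendName() const override {
    return "dummy";
  }

  // Collectives that are not overridden fall through to the base-class
  // defaults, which raise "Backend dummy does not support <op>".
};

} // namespace c10d

Any collective the subclass does not override hits the base-class default, which fails with a "Backend dummy does not support <op>" error via TORCH_CHECK; a real backend (e.g. the Gloo or NCCL implementations) would additionally override the collectives it supports and return ProcessGroup::Work handles, which is outside the scope of this diff.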
Review discussion:

Just for the record: we'd need to think about whether the SequenceNumber-related methods should stay here or be moved to ProcessGroup. For the purpose of a smooth refactor, they can stay here for now.

Makes sense, agree on this!