Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions test/distributed/test_c10d_gloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1415,11 +1415,10 @@ def test_barrier_implies_wait(self):
def test_round_robin(self):
num_process_groups = 2
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(backend="gloo", store=store, rank=self.rank, world_size=self.world_size)
pg = c10d._round_robin_process_groups(
[
self._create_process_group_gloo(
c10d.PrefixStore(str(i), store), self.rank, self.world_size, self.opts()
)
c10d.new_group(pg_options=self.opts())
for i in range(num_process_groups)
]
)
Expand All @@ -1434,13 +1433,12 @@ def test_round_robin(self):
@requires_gloo()
def test_round_robin_create_destroy(self):
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(backend="gloo", store=store, rank=self.rank, world_size=self.world_size)

def create(num, prefix):
return c10d._round_robin_process_groups(
[
self._create_process_group_gloo(
c10d.PrefixStore("%s/%d" % (prefix, i), store), self.rank, self.world_size, self.opts()
)
c10d.new_group(pg_options=self.opts())
for i in range(num)
]
)
Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/distributed/c10d/ProcessGroupRoundRobin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace c10d {
ProcessGroupRoundRobin::ProcessGroupRoundRobin(
int rank,
int size,
std::vector<c10::intrusive_ptr<Backend>> processGroups)
std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups)
: ProcessGroup(rank, size), processGroups_(std::move(processGroups)) {
TORCH_WARN(
"ProcessGroupRoundRobin is deprecated and scheduled to be removed after this current release (1.13). ",
Expand Down Expand Up @@ -114,7 +114,7 @@ c10::intrusive_ptr<Work> ProcessGroupRoundRobin::barrier(
TORCH_CHECK(false, "ProcessGroupRoundRobin does not support barrier");
};

const c10::intrusive_ptr<Backend>& ProcessGroupRoundRobin::next() {
const c10::intrusive_ptr<ProcessGroup>& ProcessGroupRoundRobin::next() {
auto& processGroup = *iterator_;
iterator_++;
if (iterator_ == processGroups_.end()) {
Expand Down
8 changes: 4 additions & 4 deletions torch/csrc/distributed/c10d/ProcessGroupRoundRobin.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup {
explicit ProcessGroupRoundRobin(
int rank,
int size,
std::vector<c10::intrusive_ptr<Backend>> processGroups);
std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups);

~ProcessGroupRoundRobin() override;

Expand Down Expand Up @@ -103,11 +103,11 @@ class TORCH_API ProcessGroupRoundRobin final : public ProcessGroup {
const BarrierOptions& opts = BarrierOptions()) override;

private:
std::vector<c10::intrusive_ptr<Backend>> processGroups_;
std::vector<c10::intrusive_ptr<Backend>>::const_iterator iterator_;
std::vector<c10::intrusive_ptr<ProcessGroup>> processGroups_;
std::vector<c10::intrusive_ptr<ProcessGroup>>::const_iterator iterator_;

// Returns the next ProcessGroup to use.
const c10::intrusive_ptr<Backend>& next();
const c10::intrusive_ptr<ProcessGroup>& next();
};

} // namespace c10d
2 changes: 1 addition & 1 deletion torch/csrc/distributed/c10d/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1627,7 +1627,7 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
#ifndef _WIN32
module.def(
"_round_robin_process_groups",
[](std::vector<c10::intrusive_ptr<::c10d::Backend>> processGroups)
[](std::vector<c10::intrusive_ptr<::c10d::ProcessGroup>> processGroups)
-> c10::intrusive_ptr<::c10d::ProcessGroup> {
if (processGroups.size() == 0) {
throw std::invalid_argument("Specify at least 1 process group");
Expand Down