20 changes: 20 additions & 0 deletions test/distributed/test_device_mesh.py
@@ -1051,6 +1051,26 @@ def test_unflatten_mesh_3d(self):
)
w.wait()

@with_comms
def test_unflatten_mesh_3d_with_pg_cache(self):
# Turn on the gate that stops saving PG names for the device mesh during torch.save.
# This also turns on the PG cache.
DeviceMesh.decouple_backend_at_save = True
# Test unflattening from a dummy world mesh, which is the case we need for Expert Parallelism (EP).
global_mesh = init_device_mesh(
self.device_type,
(8,),
mesh_dim_names=("world",),
)
non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
Comment on lines +1065 to +1066

Collaborator:
What's the difference between these two lines vs a user giving multiple names to a dimension? Or a complex name such as "cp_or_ep"?

Contributor Author:
I think this is to generate a use case where we want to test the PG cache. The cache is per layout, so when multiple names are assigned to one dimension, the PG will be shared.

self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
# Test PG caching when unflattening into the same layout.
self.assertEqual(non_ep_mesh["dp"].get_group(), ep_mesh["dp"].get_group())
self.assertEqual(non_ep_mesh["tp"].get_group(), ep_mesh["ep_tp"].get_group())
DeviceMesh.decouple_backend_at_save = False

@with_comms
def test_concatenate_2d(self):
mesh_shape = (2, 4)
28 changes: 28 additions & 0 deletions torch/csrc/distributed/c10d/Backend.hpp
@@ -1,5 +1,6 @@
#pragma once

#include <functional>
#include <memory>
#include <utility>
#include <vector>
@@ -48,6 +49,12 @@ class TORCH_API Backend : public torch::CustomClassHolder {
const std::string backend;
std::string group_name;
std::vector<uint64_t> global_ranks_in_group;

bool operator==(const Options& other) const noexcept {
return timeout == other.timeout && backend == other.backend &&
group_name == other.group_name &&
global_ranks_in_group == other.global_ranks_in_group;
}
};

explicit Backend(int rank, int size);
@@ -511,3 +518,24 @@ class TORCH_API Backend : public torch::CustomClassHolder {
};

} // namespace c10d

// Small helper: fold a hash value into a running seed (boost-style hash_combine).
inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

namespace std {

template <>
struct hash<c10d::Backend::Options> {
std::size_t operator()(const c10d::Backend::Options& o) const noexcept {
std::size_t h = 0;
hash_combine(h, std::hash<long long>{}(o.timeout.count()));
hash_combine(h, std::hash<std::string>{}(o.backend));
hash_combine(h, std::hash<std::string>{}(o.group_name));
for (auto x : o.global_ranks_in_group)
hash_combine(h, std::hash<uint64_t>{}(x));
return h;
}
};
} // namespace std
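
Note (editor): the equality operator and std::hash specialization above exist so that an Options value can act as a lookup key, e.g. for a cache that maps an option set to an already-created process group. The following standalone sketch illustrates that pattern with a hypothetical FakeOptions struct and an unordered_map; it is not PyTorch code, only the same technique in isolation.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Same boost-style combiner as in Backend.hpp.
inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

// Hypothetical stand-in for an options struct (not c10d::Backend::Options).
struct FakeOptions {
  std::string backend;
  std::string group_name;
  std::vector<uint64_t> ranks;

  bool operator==(const FakeOptions& other) const noexcept {
    return backend == other.backend && group_name == other.group_name &&
        ranks == other.ranks;
  }
};

namespace std {
template <>
struct hash<FakeOptions> {
  std::size_t operator()(const FakeOptions& o) const noexcept {
    std::size_t h = 0;
    hash_combine(h, std::hash<std::string>{}(o.backend));
    hash_combine(h, std::hash<std::string>{}(o.group_name));
    for (auto r : o.ranks)
      hash_combine(h, std::hash<uint64_t>{}(r));
    return h;
  }
};
} // namespace std

int main() {
  // Cache keyed by the option set: options that compare equal reuse the entry.
  std::unordered_map<FakeOptions, std::string> pg_cache;
  FakeOptions a{"nccl", "dp", {0, 1, 2, 3}};
  FakeOptions b{"nccl", "dp", {0, 1, 2, 3}};
  pg_cache.emplace(a, "cached process group");
  std::cout << (pg_cache.count(b) ? "cache hit" : "cache miss") << "\n";
  return 0;
}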
37 changes: 37 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupGloo.hpp
@@ -260,6 +260,23 @@ class TORCH_API ProcessGroupGloo : public Backend {

std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
int threads;

bool operator==(const Options& other) const noexcept {
// 1) compare base first
if (!static_cast<const Backend::Options&>(*this).operator==(other))
return false;

// 2) compare devices by identity
if (devices.size() != other.devices.size())
return false;
for (size_t i = 0; i < devices.size(); ++i) {
if (devices[i].get() != other.devices[i].get()) // pointer identity
return false;
}

// 3) compare added scalar fields
return threads == other.threads;
}
};

const std::string getBackendName() const override {
@@ -494,4 +511,24 @@ class TORCH_API ProcessGroupGloo : public Backend {

} // namespace c10d

namespace std {
template <>
struct hash<c10d::ProcessGroupGloo::Options> {
std::size_t operator()(
const c10d::ProcessGroupGloo::Options& o) const noexcept {
std::size_t h = 0;
// reuse base hash
hash_combine(
h,
std::hash<c10d::Backend::Options>{}(
static_cast<const c10d::Backend::Options&>(o)));
// add derived fields
for (auto const& dev : o.devices)
hash_combine(h, std::hash<const void*>{}(dev.get()));
hash_combine(h, std::hash<int>{}(o.threads));
return h;
}
};
} // namespace std

#endif // USE_C10D_GLOO
69 changes: 69 additions & 0 deletions torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp
@@ -550,6 +550,33 @@ class TORCH_API ProcessGroupNCCL : public Backend {
// the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead.
int split_color{-2};
#endif

bool operator==(const Options& other) const noexcept {
// 1) compare base first
if (!static_cast<const Backend::Options&>(*this).operator==(other))
return false;

// 2) simple fields
if (is_high_priority_stream != other.is_high_priority_stream) {
return false;
}
if (split_color != other.split_color) {
return false;
}

// 3) split_from: compare by identity
if (split_from.get() != other.split_from.get()) {
return false;
}

#ifdef NCCL_HAS_CONFIG
// 4) config
if (std::memcmp(&config, &other.config, sizeof(ncclConfig_t)) != 0) {
return false;
}
#endif
return true;
}
};

// Helper class related to TORCH_NCCL_DESYNC_DEBUG
@@ -1504,4 +1531,46 @@ typedef bool (*gil_checker_t)();
TORCH_API gil_checker_t& get_gil_checker();
} // namespace c10d

#ifdef NCCL_HAS_CONFIG
inline std::size_t hash_nccl_config(const ncclConfig_t& cfg) noexcept {
const unsigned char* p = reinterpret_cast<const unsigned char*>(&cfg);
std::size_t h = 0;
for (std::size_t i = 0; i < sizeof(cfg); ++i) {
hash_combine(h, static_cast<std::size_t>(p[i]));
}
return h;
}
#endif

namespace std {

template <>
struct hash<c10d::ProcessGroupNCCL::Options> {
std::size_t operator()(
const c10d::ProcessGroupNCCL::Options& o) const noexcept {
std::size_t h = 0;

// 1) base
hash_combine(
h,
std::hash<c10d::Backend::Options>{}(
static_cast<const c10d::Backend::Options&>(o)));

// 2) trivial extras
hash_combine(h, std::hash<bool>{}(o.is_high_priority_stream));
hash_combine(h, std::hash<int>{}(o.split_color));

// 3) pointer identity for split_from
hash_combine(h, std::hash<const void*>{}(o.split_from.get()));

#ifdef NCCL_HAS_CONFIG
// 4) config: hash its raw bytes
hash_combine(h, hash_nccl_config(o.config));
#endif
return h;
}
};

} // namespace std

#endif // USE_C10D_NCCL
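
Note (editor): for the NCCL options, operator== compares the config field with memcmp over its raw bytes and the hash folds those same bytes, so equal configs hash equally as long as the compared bytes (including any padding) are deterministic. A minimal standalone sketch of that invariant, using a hypothetical FakeNcclConfig rather than ncclConfig_t:

#include <cassert>
#include <cstddef>
#include <cstring>

// Same combiner as in Backend.hpp.
inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

// Hypothetical stand-in for ncclConfig_t.
struct FakeNcclConfig {
  int min_ctas{0};
  int max_ctas{0};
};

// Byte-wise hash, mirroring hash_nccl_config above.
inline std::size_t hash_config_bytes(const FakeNcclConfig& cfg) noexcept {
  const unsigned char* p = reinterpret_cast<const unsigned char*>(&cfg);
  std::size_t h = 0;
  for (std::size_t i = 0; i < sizeof(cfg); ++i)
    hash_combine(h, static_cast<std::size_t>(p[i]));
  return h;
}

int main() {
  FakeNcclConfig a{4, 16};
  FakeNcclConfig b{4, 16};
  // memcmp equality over the same bytes implies byte-hash equality.
  if (std::memcmp(&a, &b, sizeof(a)) == 0) {
    assert(hash_config_bytes(a) == hash_config_bytes(b));
  }
  return 0;
}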
27 changes: 25 additions & 2 deletions torch/csrc/distributed/c10d/init.cpp
@@ -3107,7 +3107,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
.def_readwrite(
"global_ranks_in_group",
&::c10d::Backend::Options::global_ranks_in_group)
.def_readwrite("group_name", &::c10d::Backend::Options::group_name);
.def_readwrite("group_name", &::c10d::Backend::Options::group_name)
.def(
"__eq__",
[](const ::c10d::Backend::Options& a,
const ::c10d::Backend::Options& b) { return a == b; })
.def("__hash__", [](const ::c10d::Backend::Options& a) {
return std::hash<::c10d::Backend::Options>{}(a);
});

#ifdef USE_C10D_GLOO
auto processGroupGloo =
@@ -3121,7 +3128,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
processGroupGloo, "_Options", backendOptions)
.def(py::init<>())
.def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
.def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads)
.def(
"__eq__",
[](const ::c10d::ProcessGroupGloo::Options& a,
const ::c10d::ProcessGroupGloo::Options& b) { return a == b; })
.def("__hash__", [](const ::c10d::ProcessGroupGloo::Options& a) {
return std::hash<::c10d::ProcessGroupGloo::Options>{}(a);
});

processGroupGloo
.def_static(
@@ -3481,6 +3495,15 @@ Example::
"split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
.def_readwrite(
"split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
.def(
"__eq__",
[](const ::c10d::ProcessGroupNCCL::Options& a,
const ::c10d::ProcessGroupNCCL::Options& b) { return a == b; })
.def(
"__hash__",
[](const ::c10d::ProcessGroupNCCL::Options& a) {
return std::hash<::c10d::ProcessGroupNCCL::Options>{}(a);
})
.def(
"__copy__",
[](const ::c10d::ProcessGroupNCCL::Options& self) {
4 changes: 3 additions & 1 deletion torch/distributed/_local_tensor/__init__.py
@@ -951,7 +951,9 @@ def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:

coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
for r in lm.ranks:
rank_tensor = self._layout.remap_to_tensor(self._rank_map)
rank_tensor = self._layout.remap_to_tensor(
self._shared_state.get_rank_map()
)
rank_coords = (rank_tensor == r).nonzero().tolist()
assert len(rank_coords) == 1
for d, c in enumerate(rank_coords[0][1:]):