Commit 29b1ab2

[WIP][DeviceMesh] Use a shared_state to cache pg per layout, root_mesh and rank_map
ghstack-source-id: 49b401d
Pull Request resolved: #166010
1 parent 35dce00 commit 29b1ab2


7 files changed (+430, -207 lines)


test/distributed/test_device_mesh.py

Lines changed: 3 additions & 0 deletions
@@ -1000,6 +1000,9 @@ def test_unflatten_mesh_3d(self):
         )
         non_ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "cp", "tp"))
         ep_mesh = global_mesh._unflatten(0, (2, 2, 2), ("dp", "ep", "ep_tp"))
+        # test pg caching when unflatten into same layout.
+        self.assertEqual(non_ep_mesh["dp"].get_group(), ep_mesh["dp"].get_group())
+        self.assertEqual(non_ep_mesh["tp"].get_group(), ep_mesh["ep_tp"].get_group())
         self.assertEqual(non_ep_mesh["cp"].mesh, ep_mesh["ep"].mesh)
         self.assertEqual(non_ep_mesh["tp"].mesh, ep_mesh["ep_tp"].mesh)
         mesh_3d = global_mesh._unflatten(0, (4, 2, 1), ("dp", "cp", "tp"))

torch/csrc/distributed/c10d/Backend.hpp

Lines changed: 28 additions & 0 deletions
@@ -1,5 +1,6 @@
 #pragma once

+#include <functional>
 #include <memory>
 #include <utility>
 #include <vector>
@@ -48,6 +49,12 @@ class TORCH_API Backend : public torch::CustomClassHolder {
     const std::string backend;
     std::string group_name;
     std::vector<uint64_t> global_ranks_in_group;
+
+    bool operator==(const Options& other) const noexcept {
+      return timeout == other.timeout && backend == other.backend &&
+          group_name == other.group_name &&
+          global_ranks_in_group == other.global_ranks_in_group;
+    }
   };

   explicit Backend(int rank, int size);
@@ -511,3 +518,24 @@ class TORCH_API Backend : public torch::CustomClassHolder {
 };

 } // namespace c10d
+
+// small helper
+inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
+  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
+}
+
+namespace std {
+
+template <>
+struct hash<c10d::Backend::Options> {
+  std::size_t operator()(const c10d::Backend::Options& o) const noexcept {
+    std::size_t h = 0;
+    hash_combine(h, std::hash<long long>{}(o.timeout.count()));
+    hash_combine(h, std::hash<std::string>{}(o.backend));
+    hash_combine(h, std::hash<std::string>{}(o.group_name));
+    for (auto x : o.global_ranks_in_group)
+      hash_combine(h, std::hash<uint64_t>{}(x));
+    return h;
+  }
+};
+} // namespace std
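
Together, the added operator== and the std::hash<c10d::Backend::Options> specialization make an Options value usable as a key in a hashed container, which is what lets the shared state look up an already-created process group instead of building a new one for an identical configuration. The snippet below is a minimal, self-contained sketch of that pattern; the simplified Options struct and the string-valued cache are illustrative stand-ins, not the actual c10d types.

#include <chrono>
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-in for c10d::Backend::Options (illustrative only).
struct Options {
  std::chrono::milliseconds timeout{0};
  std::string backend;
  std::string group_name;
  std::vector<uint64_t> global_ranks_in_group;

  bool operator==(const Options& other) const noexcept {
    return timeout == other.timeout && backend == other.backend &&
        group_name == other.group_name &&
        global_ranks_in_group == other.global_ranks_in_group;
  }
};

// Boost-style combiner, mirroring the helper added to Backend.hpp.
inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

namespace std {
template <>
struct hash<Options> {
  std::size_t operator()(const Options& o) const noexcept {
    std::size_t h = 0;
    hash_combine(h, std::hash<long long>{}(o.timeout.count()));
    hash_combine(h, std::hash<std::string>{}(o.backend));
    hash_combine(h, std::hash<std::string>{}(o.group_name));
    for (auto x : o.global_ranks_in_group)
      hash_combine(h, std::hash<uint64_t>{}(x));
    return h;
  }
};
} // namespace std

int main() {
  // Cache keyed by Options: two equal configurations hit the same entry.
  std::unordered_map<Options, std::string> pg_cache;
  Options a;
  a.timeout = std::chrono::milliseconds(1000);
  a.backend = "nccl";
  a.group_name = "dp";
  a.global_ranks_in_group = {0, 1, 2, 3};
  Options b = a;
  pg_cache.emplace(a, "cached process group");
  std::cout << (pg_cache.count(b) == 1) << '\n'; // prints 1: cache hit
  return 0;
}

The 0x9e3779b97f4a7c15 value is the 64-bit golden-ratio constant used by the usual boost-style hash_combine recipe; any reasonable mixer would serve the same purpose.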

torch/csrc/distributed/c10d/ProcessGroupGloo.hpp

Lines changed: 37 additions & 0 deletions
@@ -260,6 +260,23 @@ class TORCH_API ProcessGroupGloo : public Backend {

     std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
     int threads;
+
+    bool operator==(const Options& other) const noexcept {
+      // 1) compare base first
+      if (!static_cast<const Backend::Options&>(*this).operator==(other))
+        return false;
+
+      // 2) compare devices by identity
+      if (devices.size() != other.devices.size())
+        return false;
+      for (size_t i = 0; i < devices.size(); ++i) {
+        if (devices[i].get() != other.devices[i].get()) // pointer identity
+          return false;
+      }
+
+      // 3) compare added scalar fields
+      return threads == other.threads;
+    }
   };

   const std::string getBackendName() const override {
@@ -494,4 +511,24 @@ class TORCH_API ProcessGroupGloo : public Backend {

 } // namespace c10d

+namespace std {
+template <>
+struct hash<c10d::ProcessGroupGloo::Options> {
+  std::size_t operator()(
+      const c10d::ProcessGroupGloo::Options& o) const noexcept {
+    std::size_t h = 0;
+    // reuse base hash
+    hash_combine(
+        h,
+        std::hash<c10d::Backend::Options>{}(
+            static_cast<const c10d::Backend::Options&>(o)));
+    // add derived fields
+    for (auto const& dev : o.devices)
+      hash_combine(h, std::hash<const void*>{}(dev.get()));
+    hash_combine(h, std::hash<int>{}(o.threads));
+    return h;
+  }
+};
+} // namespace std
+
 #endif // USE_C10D_GLOO
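
The Gloo options add two wrinkles to the same pattern: the derived operator== first delegates to the Backend::Options comparison via a static_cast, and the std::shared_ptr<::gloo::transport::Device> entries are compared and hashed by pointer identity rather than by value. A compact sketch of that base-plus-derived idiom, using hypothetical BaseOptions/DerivedOptions types in place of the c10d classes:

#include <functional>
#include <memory>
#include <string>

inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

// Hypothetical base, mirroring the role of c10d::Backend::Options.
struct BaseOptions {
  std::string group_name;
  bool operator==(const BaseOptions& other) const noexcept {
    return group_name == other.group_name;
  }
};

namespace std {
template <>
struct hash<BaseOptions> {
  std::size_t operator()(const BaseOptions& o) const noexcept {
    return std::hash<std::string>{}(o.group_name);
  }
};
} // namespace std

struct Device {}; // stand-in for ::gloo::transport::Device

// Hypothetical derived options, mirroring ProcessGroupGloo::Options.
struct DerivedOptions : BaseOptions {
  std::shared_ptr<Device> device;
  int threads{2};

  bool operator==(const DerivedOptions& other) const noexcept {
    // Base fields first; the device is compared by identity, not by value.
    return static_cast<const BaseOptions&>(*this) == other &&
        device.get() == other.device.get() && threads == other.threads;
  }
};

namespace std {
template <>
struct hash<DerivedOptions> {
  std::size_t operator()(const DerivedOptions& o) const noexcept {
    // Reuse the base hash, then fold in the derived fields.
    std::size_t h =
        std::hash<BaseOptions>{}(static_cast<const BaseOptions&>(o));
    hash_combine(h, std::hash<const void*>{}(o.device.get()));
    hash_combine(h, std::hash<int>{}(o.threads));
    return h;
  }
};
} // namespace std

Pointer identity is a conservative choice for a cache key: two option objects that reach the same device through different shared_ptr instances will not compare equal, so the cache may create a second group in that case, but it can never conflate two genuinely different device sets.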

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 69 additions & 0 deletions
@@ -550,6 +550,33 @@ class TORCH_API ProcessGroupNCCL : public Backend {
     // the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead.
     int split_color{-2};
 #endif
+
+    bool operator==(const Options& other) const noexcept {
+      // 1) compare base first
+      if (!static_cast<const Backend::Options&>(*this).operator==(other))
+        return false;
+
+      // 2) simple fields
+      if (is_high_priority_stream != other.is_high_priority_stream) {
+        return false;
+      }
+      if (split_color != other.split_color) {
+        return false;
+      }
+
+      // 3) split_from: compare by identity
+      if (split_from.get() != other.split_from.get()) {
+        return false;
+      }
+
+#ifdef NCCL_HAS_CONFIG
+      // 4) config
+      if (std::memcmp(&config, &other.config, sizeof(ncclConfig_t)) != 0) {
+        return false;
+      }
+#endif
+      return true;
+    }
   };

   // Helper class related to TORCH_NCCL_DESYNC_DEBUG
@@ -1504,4 +1531,46 @@ typedef bool (*gil_checker_t)();
 TORCH_API gil_checker_t& get_gil_checker();
 } // namespace c10d

+#ifdef NCCL_HAS_CONFIG
+inline std::size_t hash_nccl_config(const ncclConfig_t& cfg) noexcept {
+  const unsigned char* p = reinterpret_cast<const unsigned char*>(&cfg);
+  std::size_t h = 0;
+  for (std::size_t i = 0; i < sizeof(cfg); ++i) {
+    hash_combine(h, static_cast<std::size_t>(p[i]));
+  }
+  return h;
+}
+#endif
+
+namespace std {
+
+template <>
+struct hash<c10d::ProcessGroupNCCL::Options> {
+  std::size_t operator()(
+      const c10d::ProcessGroupNCCL::Options& o) const noexcept {
+    std::size_t h = 0;
+
+    // 1) base
+    hash_combine(
+        h,
+        std::hash<c10d::Backend::Options>{}(
+            static_cast<const c10d::Backend::Options&>(o)));
+
+    // 2) trivial extras
+    hash_combine(h, std::hash<bool>{}(o.is_high_priority_stream));
+    hash_combine(h, std::hash<int>{}(o.split_color));
+
+    // 3) pointer identity for split_from
+    hash_combine(h, std::hash<const void*>{}(o.split_from.get()));
+
+#ifdef NCCL_HAS_CONFIG
+    // 4) config — option A: hash bytes
+    hash_combine(h, hash_nccl_config(o.config));
+#endif
+    return h;
+  }
+};
+
+} // namespace std
+
 #endif // USE_C10D_NCCL
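
For ncclConfig_t the diff treats the struct as an opaque blob: memcmp for equality and a byte-wise hash via hash_nccl_config. The sketch below shows that byte-hashing approach on a hypothetical Config struct, not the real ncclConfig_t; one caveat is that padding bytes participate in both the comparison and the hash, so the structs should be initialized consistently (e.g. zero-initialized) before fields are set.

#include <cstddef>
#include <cstring>
#include <functional>

inline void hash_combine(std::size_t& seed, std::size_t value) noexcept {
  seed ^= value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2);
}

// Hypothetical trivially copyable config, standing in for ncclConfig_t.
struct Config {
  int blocking;
  int min_ctas;
  int max_ctas;
};

// Equality as raw memory comparison, mirroring the std::memcmp check in
// ProcessGroupNCCL::Options::operator==.
inline bool config_equal(const Config& a, const Config& b) noexcept {
  return std::memcmp(&a, &b, sizeof(Config)) == 0;
}

// Hash every byte of the struct, mirroring hash_nccl_config.
inline std::size_t config_hash(const Config& cfg) noexcept {
  const unsigned char* p = reinterpret_cast<const unsigned char*>(&cfg);
  std::size_t h = 0;
  for (std::size_t i = 0; i < sizeof(cfg); ++i) {
    hash_combine(h, static_cast<std::size_t>(p[i]));
  }
  return h;
}

The same identity-versus-value trade-off noted for the Gloo devices applies to split_from here: options split from different parent groups hash differently even if the resulting layouts happen to match.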

torch/csrc/distributed/c10d/init.cpp

Lines changed: 25 additions & 2 deletions
@@ -3107,7 +3107,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
       .def_readwrite(
           "global_ranks_in_group",
           &::c10d::Backend::Options::global_ranks_in_group)
-      .def_readwrite("group_name", &::c10d::Backend::Options::group_name);
+      .def_readwrite("group_name", &::c10d::Backend::Options::group_name)
+      .def(
+          "__eq__",
+          [](const ::c10d::Backend::Options& a,
+             const ::c10d::Backend::Options& b) { return a == b; })
+      .def("__hash__", [](const ::c10d::Backend::Options& a) {
+        return std::hash<::c10d::Backend::Options>{}(a);
+      });

 #ifdef USE_C10D_GLOO
   auto processGroupGloo =
@@ -3121,7 +3128,14 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
           processGroupGloo, "_Options", backendOptions)
       .def(py::init<>())
       .def_readwrite("_devices", &::c10d::ProcessGroupGloo::Options::devices)
-      .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads);
+      .def_readwrite("_threads", &::c10d::ProcessGroupGloo::Options::threads)
+      .def(
+          "__eq__",
+          [](const ::c10d::ProcessGroupGloo::Options& a,
+             const ::c10d::ProcessGroupGloo::Options& b) { return a == b; })
+      .def("__hash__", [](const ::c10d::ProcessGroupGloo::Options& a) {
+        return std::hash<::c10d::ProcessGroupGloo::Options>{}(a);
+      });

   processGroupGloo
       .def_static(
@@ -3481,6 +3495,15 @@ Example::
           "split_from", &::c10d::ProcessGroupNCCL::Options::split_from)
       .def_readwrite(
           "split_color", &::c10d::ProcessGroupNCCL::Options::split_color)
+      .def(
+          "__eq__",
+          [](const ::c10d::ProcessGroupNCCL::Options& a,
+             const ::c10d::ProcessGroupNCCL::Options& b) { return a == b; })
+      .def(
+          "__hash__",
+          [](const ::c10d::ProcessGroupNCCL::Options& a) {
+            return std::hash<::c10d::ProcessGroupNCCL::Options>{}(a);
+          })
       .def(
           "__copy__",
           [](const ::c10d::ProcessGroupNCCL::Options& self) {
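
Binding __eq__ and __hash__ is what propagates the C++ equality and hash semantics to Python, so option objects can key a Python dict such as the per-layout process-group cache. A minimal pybind11 sketch of the same idea, using a hypothetical ToyOptions type rather than the actual c10d bindings:

#include <pybind11/pybind11.h>

#include <cstddef>
#include <functional>
#include <string>

namespace py = pybind11;

// Hypothetical options type; the c10d bindings follow the same shape.
struct ToyOptions {
  std::string group_name;
  bool operator==(const ToyOptions& other) const noexcept {
    return group_name == other.group_name;
  }
};

namespace std {
template <>
struct hash<ToyOptions> {
  std::size_t operator()(const ToyOptions& o) const noexcept {
    return std::hash<std::string>{}(o.group_name);
  }
};
} // namespace std

PYBIND11_MODULE(toy_options, m) {
  py::class_<ToyOptions>(m, "ToyOptions")
      .def(py::init<>())
      .def_readwrite("group_name", &ToyOptions::group_name)
      // With __eq__ and __hash__ defined consistently, instances can be
      // used as dict keys or set members on the Python side.
      .def("__eq__",
           [](const ToyOptions& a, const ToyOptions& b) { return a == b; })
      .def("__hash__", [](const ToyOptions& a) {
        return std::hash<ToyOptions>{}(a);
      });
}

From Python, opts1 == opts2 and hash(opts1) == hash(opts2) then hold whenever the underlying fields match, which is exactly what a dict-based cache needs.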

torch/distributed/_local_tensor/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -951,7 +951,9 @@ def get_coordinate(self: DeviceMesh) -> Optional[list[int] | None]:

     coords: list[dict[int, int]] = [{} for _ in range(self.ndim)]
     for r in lm.ranks:
-        rank_tensor = self._layout.remap_to_tensor(self._rank_map)
+        rank_tensor = self._layout.remap_to_tensor(
+            self._shared_state.get_rank_map()
+        )
         rank_coords = (rank_tensor == r).nonzero().tolist()
         assert len(rank_coords) == 1
         for d, c in enumerate(rank_coords[0][1:]):
