
Commit da123be

[DeviceMesh] Simplifying internal bookkeeping with CuTe layout
ghstack-source-id: eb6752b
Pull Request resolved: #163213
1 parent 88a7906 commit da123be


3 files changed: +354, -133 lines changed


test/distributed/test_device_mesh.py

Lines changed: 113 additions & 12 deletions
@@ -440,6 +440,7 @@ def test_device_mesh_parent_child_hash(self):
         ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
         # ep_mesh is considered different from mesh_2d["TP"]
         self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
+        self.assertEqual(mesh_2d["TP"]._layout, ep_mesh._layout)
         self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
         self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
         self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names)
@@ -454,6 +455,7 @@ def test_device_mesh_parent_child_hash(self):
         )
         # another_mesh is considered the same as ep_mesh
         self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
+        self.assertEqual(ep_mesh._layout, another_mesh._layout)
         self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
         self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
         self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names)
@@ -539,7 +541,6 @@ def test_from_group_with_mesh_shape_2d(self):
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )

-        # self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
         for mesh_dim_group, ref_mesh_dim_group in zip(
             dp_mesh.get_all_groups(), ref_mesh.get_all_groups()
         ):
@@ -800,6 +801,10 @@ def test_get_item_3d(self):
         # Test slicing out 1D mesh from a sub-2D mesh.
         shard_mesh = hsdp_mesh_2["Shard"]
         self.assertEqual(shard_mesh.mesh.tolist(), shard_group[shard_group_idx])
+        replicate_mesh = hsdp_mesh_2["Replicate"]
+        self.assertEqual(
+            replicate_mesh.mesh.tolist(), replicate_group[replicate_group_idx]
+        )

     @with_comms
     def test_cache_and_reuse_submesh_slice_result(self):
@@ -838,11 +843,14 @@ def test_get_item_3d_noncontiguous_slicing(self):
         # Check on the current dp_local_rank, whether the cp mesh tensor is the same.
         self.assertEqual(dp_cp_mesh.mesh[dp_local_rank], cp_mesh.mesh)

-        with self.assertRaisesRegex(
-            KeyError,
-            "Invalid mesh_dim_names",
-        ):
-            mesh_3d["cp", "dp"]
+        # Support transpose slicing.
+        cp_dp_mesh = mesh_3d["cp", "dp"]
+        expected_mesh_tensor = (
+            torch.tensor([[0, 4], [1, 5]], dtype=torch.int)
+            if self.rank in (0, 1, 4, 5)
+            else torch.tensor([[2, 6], [3, 7]], dtype=torch.int)
+        )
+        self.assertEqual(cp_dp_mesh.mesh, expected_mesh_tensor)

     @with_comms
     def test_flatten_mesh_1d(self):
@@ -875,10 +883,14 @@ def test_flatten_mesh_3d(self):
         self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp")
         root_mesh = _mesh_resources.get_root_mesh(dp_cp_mesh)
         self.assertEqual(root_mesh, mesh_3d)
-        flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
+        flatten_mesh_layout = _mesh_resources.flatten_name_to_root_layout[root_mesh][
             "dp_cp"
         ]
-        self.assertEqual(flatten_mesh_root_dims, (0, 1))
+        self.assertEqual(flatten_mesh_layout, flattened_dp_cp_mesh._layout)
+        self.assertEqual(
+            flattened_dp_cp_mesh._layout.global_ranks(8),
+            [[0, 2, 4, 6], [1, 3, 5, 7]],
+        )

         ref_pg_count = _world.group_count
         # Calling flatten again should not create a new pg.
@@ -893,10 +905,14 @@ def test_flatten_mesh_3d(self):
         self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp")
         root_mesh = _mesh_resources.get_root_mesh(dp_tp_mesh)
         self.assertEqual(root_mesh, mesh_3d)
-        flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
-            "dp_tp"
-        ]
-        self.assertEqual(flatten_mesh_root_dims, (0, 2))
+        flatten_mesh_root_layout = _mesh_resources.flatten_name_to_root_layout[
+            root_mesh
+        ]["dp_tp"]
+        self.assertEqual(flatten_mesh_root_layout, flattened_dp_tp_mesh._layout)
+        self.assertEqual(
+            flattened_dp_tp_mesh._layout.global_ranks(8),
+            [[0, 1, 4, 5], [2, 3, 6, 7]],
+        )

         # Test flatten with a flattened mesh_dim_name
         cp_tp_mesh = mesh_3d["cp", "tp"]
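The two expected global_ranks results above fall out of the 3D mesh's row-major layout. A minimal standalone sketch (not part of the diff), assuming mesh_3d is a 2x2x2 mesh over ranks 0..7 with dims ("dp", "cp", "tp"):

import torch

# Assumed setup: row-major 2x2x2 mesh over ranks 0..7 with dims (dp, cp, tp).
mesh = torch.arange(8).view(2, 2, 2)

# Flattening ("dp", "cp") leaves one group per tp index: all ranks sharing that tp.
dp_cp_groups = [mesh[:, :, tp].flatten().tolist() for tp in range(2)]
assert dp_cp_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]

# Flattening ("dp", "tp") leaves one group per cp index: all ranks sharing that cp.
dp_tp_groups = [mesh[:, cp, :].flatten().tolist() for cp in range(2)]
assert dp_tp_groups == [[0, 1, 4, 5], [2, 3, 6, 7]]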
@@ -1498,6 +1514,91 @@ def test_composition(self):
         right_l = _Layout((2,), (3,))
         orig_l.composition(right_l)

+    def test_check_overlap(self):
+        """Test the check_overlap method for various layout configurations."""
+        # Test 1: Valid layout - no overlap
+        # sizes=(2,3), strides=(6,1) - stride 6 > span 3, so no overlap
+        layout1 = _Layout((2, 3), (6, 1))
+        self.assertTrue(layout1.check_overlap())
+
+        # Test 2: Invalid layout - overlap due to stride < previous span
+        # sizes=(2,3), strides=(2,1) - stride 2 < span 3, causes overlap
+        layout2 = _Layout((2, 3), (2, 1))
+        self.assertFalse(layout2.check_overlap())
+
+        # Test 3: Invalid layout - duplicate strides
+        # sizes=(2,3), strides=(1,1) - same stride, causes overlap
+        layout3 = _Layout((2, 3), (1, 1))
+        self.assertFalse(layout3.check_overlap())
+
+        # Test 4: Valid layout - single dimension
+        layout4 = _Layout((4,), (1,))
+        self.assertTrue(layout4.check_overlap())
+
+        # Test 5: Valid layout - exact boundary case
+        # sizes=(2,3), strides=(3,1) - stride 3 == span 3, valid
+        layout5 = _Layout((2, 3), (3, 1))
+        self.assertTrue(layout5.check_overlap())
+
+        # Test 6: Valid layout - multi-dimensional with proper spacing
+        layout6 = _Layout((2, 2, 2), (8, 4, 1))
+        self.assertTrue(layout6.check_overlap())
+
+        # Test 7: Valid layout - permuted strides, dims interleave without overlap
+        layout7 = _Layout((2, 2, 2), (4, 1, 2))
+        self.assertTrue(layout7.check_overlap())
+
+    def test_to_remapping_tensor(self):
+        """Test the to_remapping_tensor method for various scenarios."""
+        # Test 1: Consecutive ranks, full world - should return logical groups directly
+        original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
+        layout1 = _Layout((2, 2), (2, 1))  # row-major 2x2
+        result1 = layout1.to_remapping_tensor(original_mesh, world_size=4)
+        expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
+        self.assertEqual(result1, expected1)
+
+        # Test 2: Non-consecutive ranks - should map to actual ranks
+        original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
+        layout2 = _Layout((2, 2), (2, 1))
+        result2 = layout2.to_remapping_tensor(original_mesh, world_size=4)
+        expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
+        self.assertEqual(result2, expected2)
+
+        # Test 3: Partial world (mesh smaller than world_size) - requires stride scaling
+        original_mesh = torch.tensor([1, 2], dtype=torch.int)
+        layout3 = _Layout((2,), (4,))  # stride=4 for world_size=8
+        result3 = layout3.to_remapping_tensor(original_mesh, world_size=8)
+        expected3 = torch.tensor([[1, 2]], dtype=torch.int)
+        self.assertEqual(result3, expected3)
+
+        # Test 4: 1D layout with consecutive ranks
+        original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
+        layout4 = _Layout((4,), (1,))
+        result4 = layout4.to_remapping_tensor(original_mesh, world_size=4)
+        expected4 = torch.tensor([[0, 1, 2, 3]], dtype=torch.int)
+        self.assertEqual(result4, expected4)
+
+        # Test 5: Complex strided layout with non-consecutive ranks
+        original_mesh = torch.tensor([5, 10, 15, 20], dtype=torch.int)
+        layout5 = _Layout((2, 2), (2, 1))
+        result5 = layout5.to_remapping_tensor(original_mesh, world_size=4)
+        expected5 = torch.tensor([[[5, 10], [15, 20]]], dtype=torch.int)
+        self.assertEqual(result5, expected5)
+
+        # Test 6: CuTe representation of a 2D mesh
+        original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
+        layout6 = _Layout((2, 2), (1, 2))  # column-major style
+        result6 = layout6.to_remapping_tensor(original_mesh, world_size=4)
+        expected6 = torch.tensor([[[0, 2], [1, 3]]], dtype=torch.int)
+        self.assertEqual(result6, expected6)
+
+        # Test 7: Layout with different stride pattern
+        original_mesh = torch.tensor([0, 2, 1, 4], dtype=torch.int)
+        layout7 = _Layout((2, 2), (1, 2))  # column-major style
+        result7 = layout7.to_remapping_tensor(original_mesh, world_size=4)
+        expected7 = torch.tensor([[[0, 1], [2, 4]]], dtype=torch.int)
+        self.assertEqual(result7, expected7)
+

 if __name__ == "__main__":
     run_tests()
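The test_to_remapping_tensor cases above all reduce to one indexing idiom: when the mesh tensor does not hold consecutive ranks, the ranks produced by the layout are treated as indices into the real mesh tensor. A minimal sketch (not part of the diff), reproducing Test 5's expectation:

import torch

# Non-consecutive device ranks, as in Test 5 above.
original_mesh = torch.tensor([5, 10, 15, 20], dtype=torch.int)

# Logical groups a (2, 2):(2, 1) layout generates over 4 consecutive ranks.
logical_groups = torch.tensor([[[0, 1], [2, 3]]])

# Use the logical ranks as indices into the flattened mesh to get actual ranks.
remapped = original_mesh.flatten()[logical_groups]
assert remapped.tolist() == [[[5, 10], [15, 20]]]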

torch/distributed/_mesh_layout.py

Lines changed: 152 additions & 0 deletions
@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from itertools import product

+import torch
 from torch.distributed._pycute import (
     coalesce,
     complement,
@@ -74,6 +75,16 @@ def __getitem__(self, i: int) -> "_MeshLayout":
         layout = super().__getitem__(i)
         return _MeshLayout(layout.shape, layout.stride)

+    def __getstate__(self) -> dict[str, IntTuple]:
+        return {
+            "shape": self.shape,
+            "stride": self.stride,
+        }
+
+    def __setstate__(self, state: dict[str, IntTuple]) -> None:
+        object.__setattr__(self, "shape", state["shape"])
+        object.__setattr__(self, "stride", state["stride"])
+
     def coalesce(self) -> "_MeshLayout":
         """
         A layout is represented by (sizes):(strides), e.g. (3,2):(4,2).
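One note on the __getstate__/__setstate__ pair added above: a sketch (not part of the diff), assuming _MeshLayout behaves like a frozen dataclass, of why state is restored through object.__setattr__ rather than ordinary attribute assignment:

import pickle
from dataclasses import dataclass


# Hypothetical stand-in for _MeshLayout, assumed frozen; a plain `self.shape = ...`
# inside __setstate__ would raise FrozenInstanceError, so object.__setattr__ is used.
@dataclass(frozen=True)
class _ToyLayout:
    shape: tuple
    stride: tuple

    def __getstate__(self) -> dict:
        return {"shape": self.shape, "stride": self.stride}

    def __setstate__(self, state: dict) -> None:
        object.__setattr__(self, "shape", state["shape"])
        object.__setattr__(self, "stride", state["stride"])


layout = _ToyLayout((2, 3), (6, 1))
assert pickle.loads(pickle.dumps(layout)) == layout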
@@ -210,3 +221,144 @@ def global_ranks(self, world_size: int) -> list[list[int]]:
             [group_offset + group_rank for group_rank in self.member_ranks()]
             for group_offset in self.complement(world_size).member_ranks()
         ]
+
+    def check_non_overlap(self) -> bool:
+        """
+        Check whether the layout has any overlap between the ranks it generates.
+        If there is overlap, we return False, otherwise True.
+
+        Aside from index 0, the indices generated by each dim of the layout must be
+        non-overlapping.
+
+        Here is how it works:
+        1. Sort dimensions by stride (smallest stride first)
+        2. For each dimension, check if:
+           - It has the same stride as the previous dimension (duplicate mapping)
+           - Its stride overlaps with the previous dimension's span
+
+        A dimension's "span" is size * stride, representing the address space it covers.
+
+        Example 1 - Valid (no overlap):
+            Layout: sizes=(2,3), strides=(6,1)
+            - Dim 1: stride=1, span=3*1=3, covers addresses [0,1,2]
+            - Dim 0: stride=6, span=2*6=12, covers addresses [0,6]
+            → No overlap since 6 > 3
+
+        Example 2 - Invalid (overlap):
+            Layout: sizes=(2,3), strides=(2,1)
+            - Dim 1: stride=1, span=3*1=3, covers addresses [0,1,2]
+            - Dim 0: stride=2, span=2*2=4, covers addresses [0,2]
+            → Overlap! stride=2 < span=3, so addresses [0,2] are duplicated
+
+        Returns:
+            bool: True if no overlap exists (valid layout), False if overlap detected
+        """
+        previous_span = -1
+        previous_stride = -1
+        for size, stride in sorted(self.sizes_and_strides, key=lambda x: x[1]):
+            if size == 1:
+                continue
+            if previous_stride == stride or stride < previous_span:
+                return False
+            previous_stride = stride
+            previous_span = size * stride
+        return True
+
+    def to_remapping_tensor(
+        self,
+        original_mesh_tensor: torch.Tensor,
+        world_size: int,
+    ) -> torch.Tensor:
+        """
+        Convert this layout into a tensor representation that maps the logical mesh
+        structure to actual device ranks, handling cases where the mesh does not use
+        consecutive ranks or does not span the full world size (neither is directly
+        representable as a CuTe layout).
+
+        With this method, the CuTe layout serves as the backend for index bookkeeping
+        of the mesh tensor during flatten, unflatten and slicing operations, while the
+        mesh tensor itself still represents the actual device assignment and ranks. We
+        need this function to specify device allocation and to create backends for a mesh.
+
+        Overview:
+        1. Generate logical process groups using this layout's structure
+        2. Check if the original mesh uses consecutive ranks (0,1,2,...)
+        3. If consecutive: return the logical groups directly
+        4. If non-consecutive or partial world: map logical indices to actual ranks
+
+        Examples:
+
+        Case 1 - Consecutive ranks, full world:
+            original_mesh_tensor = [[0,1],[2,3]]  # 2x2 mesh, ranks 0-3
+            world_size = 4
+            layout = Layout(2:2)
+            → Returns logical groups directly: [[0,2],[1,3]]
+
+        Case 2 - Non-consecutive ranks:
+            original_mesh_tensor = [[10,20],[30,40]]  # custom rank assignment
+            world_size = 4
+            layout = Layout(2:2)
+            → Maps logical indices to actual ranks: [[[10,30],[20,40]]]
+
+        Case 3 - Partial world (stride scaling needed):
+            original_mesh_tensor = [[0,1]]  # 1x2 mesh in world_size=8
+            world_size = 8
+            layout = Layout((2,), (4,))  # every 4th rank
+            → Scale down stride: (4,) → (1,) to fit the mesh size
+            → Map scaled indices to actual ranks: [[0,1]]
+
+        Args:
+            original_mesh_tensor: The concrete mesh tensor with actual device ranks
+            world_size: Total number of ranks in the distributed system
+
+        Returns:
+            torch.Tensor: A tensor representing the actual device ranks from original_mesh_tensor
+        """
+
+        def scale_stride(scale: int, strides: IntTuple) -> IntTuple:
+            """
+            Recursively scale strides down by a factor to fit within a smaller mesh.
+
+            When the layout expects world_size=8 but the mesh only has 4 elements,
+            we need to scale the strides down by a factor of 2 to generate valid indices.
+
+            Example: stride=4 with scale=2 → stride=2 (or keep as-is if stride < scale)
+            """
+            if is_int(strides):
+                return strides if strides < scale else strides // scale
+            else:
+                return tuple(scale_stride(scale, stride) for stride in strides)

+        # Create tensor representation of the mesh
+        pg_ranks_by_dim = self.global_ranks(original_mesh_tensor.numel())
+        sizes = flatten(self.sizes)
+        tensor = torch.tensor(pg_ranks_by_dim, device="cpu", dtype=torch.int).view(
+            -1,
+            *sizes,  # type: ignore[arg-type]
+        )
+
+        # When the mesh tensor's values can be represented as a CuTe layout, we can use
+        # the global ranks generated by the layout directly for the mesh tensor. Otherwise,
+        # the ranks generated by the layout are used as indices to look up the actual ranks
+        # in the original mesh tensor.
+        if torch.equal(
+            original_mesh_tensor.flatten().sort().values,
+            torch.arange(
+                original_mesh_tensor.numel(),
+                device=original_mesh_tensor.device,
+                dtype=original_mesh_tensor.dtype,
+            ),
+        ):
+            return tensor
+
+        # This matters because the indices generated by the layout will exceed the range of
+        # the original mesh tensor when that tensor does not contain all ranks in the world.
+        # So we scale the layout's strides by world_size // mesh_tensor.numel() so that the
+        # generated indices stay within the range of the original mesh tensor.
+        if original_mesh_tensor.numel() != world_size:
+            scale_factor = world_size // original_mesh_tensor.numel()
+            scaled_strides = scale_stride(scale_factor, self.strides)
+            scaled_layout = _MeshLayout(self.sizes, scaled_strides)
+            pg_ranks_by_dim = scaled_layout.global_ranks(original_mesh_tensor.numel())
+            tensor = torch.tensor(pg_ranks_by_dim, device="cpu", dtype=torch.int).view(
+                -1,
+                *sizes,  # type: ignore[arg-type]
+            )
+        return original_mesh_tensor.flatten()[tensor]
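For reference, a standalone sketch (not part of the diff) of the non-overlap rule described in the docstring above, written against plain (sizes, strides) tuples and checked against the docstring's two examples:

def has_no_overlap(sizes: tuple[int, ...], strides: tuple[int, ...]) -> bool:
    previous_span = -1
    previous_stride = -1
    # Walk dims from smallest stride to largest, skipping trivial size-1 dims.
    for size, stride in sorted(zip(sizes, strides), key=lambda x: x[1]):
        if size == 1:
            continue
        # Overlap if two dims share a stride, or a stride starts inside the
        # span (size * stride) already covered by the previous dim.
        if previous_stride == stride or stride < previous_span:
            return False
        previous_stride = stride
        previous_span = size * stride
    return True


assert has_no_overlap((2, 3), (6, 1))      # Example 1: stride 6 >= span 3, valid
assert not has_no_overlap((2, 3), (2, 1))  # Example 2: stride 2 < span 3, overlap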
