
Commit 22e219d

Revert "[DeviceMesh] Simplifying internal bookkeeping with CuTe layout (#163213)"
This reverts commit b098514. Reverted #163213 on behalf of https://github.com/yangw-dev because it caused an internal test failure ([comment](#163213 (comment))).
Parent: bdc0a42 · Commit: 22e219d

3 files changed: +141 -246 lines


test/distributed/test_device_mesh.py

Lines changed: 7 additions & 70 deletions
@@ -440,7 +440,6 @@ def test_device_mesh_parent_child_hash(self):
 ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
 # ep_mesh is considered different from mesh_2d["TP"]
 self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
-self.assertEqual(mesh_2d["TP"]._layout, ep_mesh._layout)
 self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
 self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
 self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names)
@@ -455,7 +454,6 @@ def test_device_mesh_parent_child_hash(self):
 )
 # another_mesh is considered the same as ep_mesh
 self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
-self.assertEqual(ep_mesh._layout, another_mesh._layout)
 self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
 self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
 self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names)
@@ -541,6 +539,7 @@ def test_from_group_with_mesh_shape_2d(self):
 mesh_dim_names=("dp_replicate", "dp_shard"),
 )

+# self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
 for mesh_dim_group, ref_mesh_dim_group in zip(
 dp_mesh.get_all_groups(), ref_mesh.get_all_groups()
 ):
@@ -801,10 +800,6 @@ def test_get_item_3d(self):
 # Test slicing out 1D mesh from a sub-2D mesh.
 shard_mesh = hsdp_mesh_2["Shard"]
 self.assertEqual(shard_mesh.mesh.tolist(), shard_group[shard_group_idx])
-replicate_mesh = hsdp_mesh_2["Replicate"]
-self.assertEqual(
-replicate_mesh.mesh.tolist(), replicate_group[replicate_group_idx]
-)

 @with_comms
 def test_cache_and_reuse_submesh_slice_result(self):
@@ -878,17 +873,12 @@ def test_flatten_mesh_3d(self):
 flattened_dp_cp_mesh = dp_cp_mesh._flatten()
 self.assertEqual(dp_cp_mesh.mesh.flatten(), flattened_dp_cp_mesh.mesh)
 self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp")
-self.assertEqual(flattened_dp_cp_mesh.get_group().group_desc, "mesh_dp_cp")
 root_mesh = _mesh_resources.get_root_mesh(dp_cp_mesh)
 self.assertEqual(root_mesh, mesh_3d)
-flatten_mesh_layout = _mesh_resources.root_to_flatten_mapping[root_mesh][
+flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
 "dp_cp"
-]._layout
-self.assertEqual(flatten_mesh_layout, flattened_dp_cp_mesh._layout)
-self.assertEqual(
-flattened_dp_cp_mesh._layout.global_ranks(8),
-[[0, 2, 4, 6], [1, 3, 5, 7]],
-)
+]
+self.assertEqual(flatten_mesh_root_dims, (0, 1))

 ref_pg_count = _world.group_count
 # Calling flatten again should not create a new pg.
@@ -903,19 +893,10 @@ def test_flatten_mesh_3d(self):
 self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp")
 root_mesh = _mesh_resources.get_root_mesh(dp_tp_mesh)
 self.assertEqual(root_mesh, mesh_3d)
-flatten_mesh_root_layout = _mesh_resources.root_to_flatten_mapping[root_mesh][
+flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
 "dp_tp"
-]._layout
-self.assertEqual(flatten_mesh_root_layout, flattened_dp_tp_mesh._layout)
-self.assertEqual(
-flattened_dp_tp_mesh._layout.global_ranks(8),
-[[0, 1, 4, 5], [2, 3, 6, 7]],
-)
-with self.assertRaisesRegex(
-NotImplementedError,
-"Currently, this only allows slicing out a contiguous flattened dim",
-):
-mesh_3d["dp_tp", "cp"]
+]
+self.assertEqual(flatten_mesh_root_dims, (0, 2))

 # Test flatten with a flattened mesh_dim_name
 cp_tp_mesh = mesh_3d["cp", "tp"]
@@ -1556,50 +1537,6 @@ def test_check_non_overlap(self):
 layout8 = _Layout((3, 2), (2, 3))
 self.assertTrue(layout8.check_non_overlap())

-def test_remap_to_tensor(self):
-"""Test the remap_to_tensor method for various scenarios."""
-# Test 1: Consecutive ranks, full world - should return logical groups directly
-original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
-layout1 = _Layout((2, 2), (2, 1)) # row-major 2x2
-result1 = layout1.remap_to_tensor(original_mesh)
-expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
-self.assertEqual(result1, expected1)
-
-# Test 2: Non-consecutive ranks - should map to actual ranks
-original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
-layout2 = _Layout((2, 2), (2, 1))
-result2 = layout2.remap_to_tensor(original_mesh)
-expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
-self.assertEqual(result2, expected2)
-
-# Test 4: 1D layout with consecutive ranks
-original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
-layout4 = _Layout((4,), (1,))
-result4 = layout4.remap_to_tensor(original_mesh)
-expected4 = torch.tensor([[0, 1, 2, 3]], dtype=torch.int)
-self.assertEqual(result4, expected4)
-
-# Test 5: Complex strided layout with non-consecutive ranks
-original_mesh = torch.tensor([5, 10, 15, 20], dtype=torch.int)
-layout5 = _Layout((2, 2), (2, 1))
-result5 = layout5.remap_to_tensor(original_mesh)
-expected5 = torch.tensor([[[5, 10], [15, 20]]], dtype=torch.int)
-self.assertEqual(result5, expected5)
-
-# Test 6: Tensor Cute representation of a 2D mesh
-original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
-layout6 = _Layout((2, 2), (1, 2)) # column-major style
-result6 = layout6.remap_to_tensor(original_mesh)
-expected6 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
-self.assertEqual(result6, expected6)
-
-# Test 7: Layout with different stride pattern
-original_mesh = torch.tensor([0, 2, 1, 4], dtype=torch.int)
-layout7 = _Layout((2, 2), (1, 2)) # column-major style
-result7 = layout7.remap_to_tensor(original_mesh)
-expected7 = torch.tensor([[[0, 1], [2, 4]]], dtype=torch.int)
-self.assertEqual(result7, expected7)
-

 if __name__ == "__main__":
 run_tests()
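Note on the deleted `global_ranks(8)` assertions above: they encode how a CuTe-style (sizes, strides) layout enumerates the process-group ranks for a flattened dim. The sketch below is not part of the change; it is a minimal standalone illustration that reproduces both expected groupings, assuming a 2x2x2 ("dp", "cp", "tp") mesh over ranks 0-7 with row-major strides (4, 2, 1). The mesh shape, the strides, and the helper name layout_offsets are inferred or invented here, not taken from the reverted code.

    from itertools import product

    def layout_offsets(sizes, strides):
        # All offsets reachable by a (sizes, strides) layout, enumerated row-major.
        return [sum(i * s for i, s in zip(idx, strides))
                for idx in product(*(range(n) for n in sizes))]

    # Flattened dp_cp dim: sizes (2, 2), strides (4, 2); the leftover tp dim
    # (size 2, stride 1) supplies one base offset per group.
    dp_cp_groups = [[base + off for off in layout_offsets((2, 2), (4, 2))]
                    for base in layout_offsets((2,), (1,))]
    print(dp_cp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]

    # Flattened dp_tp dim: sizes (2, 2), strides (4, 1); leftover cp dim has stride 2.
    dp_tp_groups = [[base + off for off in layout_offsets((2, 2), (4, 1))]
                    for base in layout_offsets((2,), (2,))]
    print(dp_tp_groups)  # [[0, 1, 4, 5], [2, 3, 6, 7]]

These groupings match the lists in the deleted assertions; the replacement assertions track only the flattened dim's root dims, (0, 1) and (0, 2), via flatten_name_to_root_dims.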

torch/distributed/_mesh_layout.py

Lines changed: 0 additions & 52 deletions
@@ -7,7 +7,6 @@
 from dataclasses import dataclass
 from itertools import product

-import torch
 from torch.distributed._pycute import (
 coalesce,
 complement,
@@ -244,54 +243,3 @@ def check_non_overlap(self) -> bool:
 """
 ranks = self.all_ranks_from_zero()
 return len(ranks) == len(set(ranks))
-
-def remap_to_tensor(
-self,
-mesh_tensor: torch.Tensor,
-) -> torch.Tensor:
-"""
-Leverage layout as an index for mesh tensor that re-maps the indexes after layout
-transformation to actual device ranks.
-
-With this method, the cute layout serves as the backend of indices bookkeeping for the
-mesh tensor when it comes to flatten, unflatten and slicing operations. The actual mesh
-tensor still represents the actual device assignment and ranks. We need this function
-to specify device allocation and create backend for a mesh. Although any transform of mesh tensors
-can be treated as a view or subset of mesh tensor, we do need to use the actual view or
-sub-tensor for DeviceMesh and its backend creation.
-
-The shape of the `mesh_tensor` can be any size because users can define a device mesh with any
-shapes. But we can further refactor the code so that internally we can only support 1D mesh tensor
-and reconstruct the mesh tensor with the shape of the layout when accessed by users.
-#TODO: Only support 1D mesh tensor stored internally and reconstruct the mesh tensor via layout.
-
-Examples:
-
-Case 1 - Consecutive ranks, full world:
-original_mesh_tensor = [[0,1],[2,3]] # 2x2 mesh, ranks 0-3
-world_size = 4
-layout = Layout(2:2)
-Return: [[0,2],[1,3]]
-
-Case 2 - Non-consecutive ranks:
-original_mesh_tensor = [[10,20],[30,40]] # custom rank assignment
-world_size = 4
-layout = Layout(2:2)
-Return: [[[10,30],[20,40]]]
-
-Args:
-mesh_tensor: The concrete mesh tensor with actual device ranks
-
-Returns:
-torch.Tensor: A tensor representing the actual device allocation from mesh_tensor
-"""
-complement_layout = self.complement(mesh_tensor.numel())
-
-return (
-mesh_tensor.flatten()
-.as_strided(
-flatten(complement_layout.sizes) + flatten(self.sizes),
-flatten(complement_layout.strides) + flatten(self.strides),
-)
-.reshape(-1, *(self[i].numel() for i in range(len(self))))
-)
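For reference, the removed remap_to_tensor re-indexes a concrete mesh tensor through the layout with as_strided. Below is a minimal sketch of that remapping for the deleted "Test 6" case (layout sizes (2, 2), strides (1, 2), mesh [[0, 2], [1, 3]]), written with plain torch rather than the reverted _Layout class; the explicit size/stride tuples are hard-coded from that test, and the size-1 leading dim stands in for the layout's complement over a 4-rank world, so its stride value is immaterial.

    import torch

    mesh_tensor = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
    # Sizes/strides are (complement sizes + layout sizes, complement strides + layout strides),
    # mirroring the removed as_strided call.
    remapped = (
        mesh_tensor.flatten()
        .as_strided((1, 2, 2), (4, 1, 2))
        .reshape(-1, 2, 2)
    )
    print(remapped)  # tensor([[[0, 1], [2, 3]]], dtype=torch.int32), i.e. expected6 in the deleted test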

0 commit comments
