@@ -440,6 +440,7 @@ def test_device_mesh_parent_child_hash(self):
         ep_mesh = ep_mesh_1 if self.rank < self.world_size // 2 else ep_mesh_2
         # ep_mesh is considered different from mesh_2d["TP"]
         self.assertEqual(mesh_2d["TP"]._flatten_mesh_list, ep_mesh._flatten_mesh_list)
+        self.assertEqual(mesh_2d["TP"]._layout, ep_mesh._layout)
         self.assertEqual(mesh_2d["TP"].mesh.shape, ep_mesh.mesh.shape)
         self.assertEqual(mesh_2d["TP"].device_type, ep_mesh.device_type)
         self.assertNotEqual(mesh_2d["TP"].mesh_dim_names, ep_mesh.mesh_dim_names)
@@ -454,6 +455,7 @@ def test_device_mesh_parent_child_hash(self):
         )
         # another_mesh is considered the same as ep_mesh
         self.assertEqual(ep_mesh._flatten_mesh_list, another_mesh._flatten_mesh_list)
+        self.assertEqual(ep_mesh._layout, another_mesh._layout)
         self.assertEqual(ep_mesh.mesh.shape, another_mesh.mesh.shape)
         self.assertEqual(ep_mesh.device_type, another_mesh.device_type)
         self.assertEqual(ep_mesh.mesh_dim_names, another_mesh.mesh_dim_names)
@@ -539,7 +541,6 @@ def test_from_group_with_mesh_shape_2d(self):
             mesh_dim_names=("dp_replicate", "dp_shard"),
         )

-        # self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
         for mesh_dim_group, ref_mesh_dim_group in zip(
             dp_mesh.get_all_groups(), ref_mesh.get_all_groups()
         ):
@@ -800,6 +801,10 @@ def test_get_item_3d(self):
         # Test slicing out 1D mesh from a sub-2D mesh.
         shard_mesh = hsdp_mesh_2["Shard"]
         self.assertEqual(shard_mesh.mesh.tolist(), shard_group[shard_group_idx])
+        replicate_mesh = hsdp_mesh_2["Replicate"]
+        self.assertEqual(
+            replicate_mesh.mesh.tolist(), replicate_group[replicate_group_idx]
+        )

     @with_comms
     def test_cache_and_reuse_submesh_slice_result(self):
@@ -838,11 +843,14 @@ def test_get_item_3d_noncontiguous_slicing(self):
         # Check on the current dp_local_rank, whether the cp mesh tensor is the same.
         self.assertEqual(dp_cp_mesh.mesh[dp_local_rank], cp_mesh.mesh)

-        with self.assertRaisesRegex(
-            KeyError,
-            "Invalid mesh_dim_names",
-        ):
-            mesh_3d["cp", "dp"]
+        # Support transpose slicing.
+        cp_dp_mesh = mesh_3d["cp", "dp"]
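+        # Assuming the (2, 2, 2) mesh built above keeps "dp" as its outermost and
+        # "cp" as its innermost sliced dim, the transposed slice puts "cp" on rows
+        # and "dp" on columns; ranks 0, 1, 4, 5 share index 0 of the remaining
+        # middle dim and ranks 2, 3, 6, 7 share index 1, hence the two expected
+        # tensors below.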
+        expected_mesh_tensor = (
+            torch.tensor([[0, 4], [1, 5]], dtype=torch.int)
+            if self.rank in (0, 1, 4, 5)
+            else torch.tensor([[2, 6], [3, 7]], dtype=torch.int)
+        )
+        self.assertEqual(cp_dp_mesh.mesh, expected_mesh_tensor)

     @with_comms
     def test_flatten_mesh_1d(self):
@@ -875,10 +883,14 @@ def test_flatten_mesh_3d(self):
         self.assertEqual(flattened_dp_cp_mesh.mesh_dim_names[0], "dp_cp")
         root_mesh = _mesh_resources.get_root_mesh(dp_cp_mesh)
         self.assertEqual(root_mesh, mesh_3d)
-        flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
+        flatten_mesh_layout = _mesh_resources.flatten_name_to_root_layout[root_mesh][
             "dp_cp"
         ]
-        self.assertEqual(flatten_mesh_root_dims, (0, 1))
+        self.assertEqual(flatten_mesh_layout, flattened_dp_cp_mesh._layout)
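+        # Sanity-check the flattened layout's rank grouping. Assuming mesh_3d is the
+        # 8-rank (2, 2, 2) ("dp", "cp", "tp") mesh from above, flattening "dp" and
+        # "cp" groups the ranks that share a "tp" coordinate:
+        # tp=0 -> {0, 2, 4, 6}, tp=1 -> {1, 3, 5, 7}.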
+        self.assertEqual(
+            flattened_dp_cp_mesh._layout.global_ranks(8),
+            [[0, 2, 4, 6], [1, 3, 5, 7]],
+        )

         ref_pg_count = _world.group_count
         # Calling flatten again should not create a new pg.
@@ -893,10 +905,14 @@ def test_flatten_mesh_3d(self):
         self.assertEqual(flattened_dp_tp_mesh.mesh_dim_names[0], "dp_tp")
         root_mesh = _mesh_resources.get_root_mesh(dp_tp_mesh)
         self.assertEqual(root_mesh, mesh_3d)
-        flatten_mesh_root_dims = _mesh_resources.flatten_name_to_root_dims[root_mesh][
-            "dp_tp"
-        ]
-        self.assertEqual(flatten_mesh_root_dims, (0, 2))
+        flatten_mesh_root_layout = _mesh_resources.flatten_name_to_root_layout[
+            root_mesh
+        ]["dp_tp"]
+        self.assertEqual(flatten_mesh_root_layout, flattened_dp_tp_mesh._layout)
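+        # Same reasoning as for "dp_cp": flattening "dp" and "tp" groups the ranks
+        # that share a "cp" coordinate, so cp=0 -> {0, 1, 4, 5} and
+        # cp=1 -> {2, 3, 6, 7} (again assuming the (2, 2, 2) mesh above).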
+        self.assertEqual(
+            flattened_dp_tp_mesh._layout.global_ranks(8),
+            [[0, 1, 4, 5], [2, 3, 6, 7]],
+        )

         # Test flatten with a flattened mesh_dim_name
         cp_tp_mesh = mesh_3d["cp", "tp"]
@@ -1498,6 +1514,91 @@ def test_composition(self):
             right_l = _Layout((2,), (3,))
             orig_l.composition(right_l)

+    def test_check_overlap(self):
+        """Test the check_overlap method for various layout configurations."""
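+        # The expectations below follow one reading of check_overlap: sort the dims
+        # by stride; the layout is overlap-free iff every stride is at least the
+        # span (size * stride) of the next-finer dim, and check_overlap returns
+        # True exactly when there is no overlap.
+        # Worked example: sizes=(2, 3), strides=(6, 1) -> inner span 3 * 1 = 3 <= 6,
+        # so the layout does not overlap.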
+        # Test 1: Valid layout - no overlap
+        # sizes=(2,3), strides=(6,1) - stride 6 > span 3, so no overlap
+        layout1 = _Layout((2, 3), (6, 1))
+        self.assertTrue(layout1.check_overlap())
+
+        # Test 2: Invalid layout - overlap due to stride < previous span
+        # sizes=(2,3), strides=(2,1) - stride 2 < span 3, causes overlap
+        layout2 = _Layout((2, 3), (2, 1))
+        self.assertFalse(layout2.check_overlap())
+
+        # Test 3: Invalid layout - duplicate strides
+        # sizes=(2,3), strides=(1,1) - same stride, causes overlap
+        layout3 = _Layout((2, 3), (1, 1))
+        self.assertFalse(layout3.check_overlap())
+
+        # Test 4: Valid layout - single dimension
+        layout4 = _Layout((4,), (1,))
+        self.assertTrue(layout4.check_overlap())
+
+        # Test 5: Valid layout - exact boundary case
+        # sizes=(2,3), strides=(3,1) - stride 3 == span 3, valid
+        layout5 = _Layout((2, 3), (3, 1))
+        self.assertTrue(layout5.check_overlap())
+
+        # Test 6: Valid layout - multi-dimensional with proper spacing
+        layout6 = _Layout((2, 2, 2), (8, 4, 1))
+        self.assertTrue(layout6.check_overlap())
+
+        # Test 7: Valid layout - strides out of sorted order, but no overlap
+        # sizes=(2,2,2), strides=(4,1,2) - each stride still clears the finer spans
+        layout7 = _Layout((2, 2, 2), (4, 1, 2))
+        self.assertTrue(layout7.check_overlap())
+
+    def test_to_remapping_tensor(self):
+        """Test the to_remapping_tensor method for various scenarios."""
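+        # Semantics inferred from the expectations below: when original_mesh holds
+        # exactly the ranks 0..world_size-1, the layout's own global rank groups are
+        # returned directly; otherwise the layout's indices (scaled down when the
+        # mesh is smaller than world_size) select ranks from original_mesh's
+        # flattened tensor. Either way the result carries a leading group dim.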
+        # Test 1: Consecutive ranks, full world - should return logical groups directly
+        original_mesh = torch.tensor([[0, 1], [2, 3]], dtype=torch.int)
+        layout1 = _Layout((2, 2), (2, 1))  # row-major 2x2
+        result1 = layout1.to_remapping_tensor(original_mesh, world_size=4)
+        expected1 = torch.tensor([[[0, 1], [2, 3]]], dtype=torch.int)
+        self.assertEqual(result1, expected1)
+
+        # Test 2: Non-consecutive ranks - should map to actual ranks
+        original_mesh = torch.tensor([[10, 20], [30, 40]], dtype=torch.int)
+        layout2 = _Layout((2, 2), (2, 1))
+        result2 = layout2.to_remapping_tensor(original_mesh, world_size=4)
+        expected2 = torch.tensor([[[10, 20], [30, 40]]], dtype=torch.int)
+        self.assertEqual(result2, expected2)
+
+        # Test 3: Partial world (mesh smaller than world_size) - requires stride scaling
+        original_mesh = torch.tensor([1, 2], dtype=torch.int)
+        layout3 = _Layout((2,), (4,))  # stride=4 for world_size=8
+        result3 = layout3.to_remapping_tensor(original_mesh, world_size=8)
+        expected3 = torch.tensor([[1, 2]], dtype=torch.int)
+        self.assertEqual(result3, expected3)
+
+        # Test 4: 1D layout with consecutive ranks
+        original_mesh = torch.tensor([0, 1, 2, 3], dtype=torch.int)
+        layout4 = _Layout((4,), (1,))
+        result4 = layout4.to_remapping_tensor(original_mesh, world_size=4)
+        expected4 = torch.tensor([[0, 1, 2, 3]], dtype=torch.int)
+        self.assertEqual(result4, expected4)
+
+        # Test 5: Complex strided layout with non-consecutive ranks
+        original_mesh = torch.tensor([5, 10, 15, 20], dtype=torch.int)
+        layout5 = _Layout((2, 2), (2, 1))
+        result5 = layout5.to_remapping_tensor(original_mesh, world_size=4)
+        expected5 = torch.tensor([[[5, 10], [15, 20]]], dtype=torch.int)
+        self.assertEqual(result5, expected5)
+
+        # Test 6: Tensor Cute representation of a 2D mesh
+        original_mesh = torch.tensor([[0, 2], [1, 3]], dtype=torch.int)
+        layout6 = _Layout((2, 2), (1, 2))  # column-major style
+        result6 = layout6.to_remapping_tensor(original_mesh, world_size=4)
+        expected6 = torch.tensor([[[0, 2], [1, 3]]], dtype=torch.int)
+        self.assertEqual(result6, expected6)
+
+        # Test 7: Layout with different stride pattern
+        original_mesh = torch.tensor([0, 2, 1, 4], dtype=torch.int)
+        layout7 = _Layout((2, 2), (1, 2))  # column-major style
+        result7 = layout7.to_remapping_tensor(original_mesh, world_size=4)
+        expected7 = torch.tensor([[[0, 1], [2, 4]]], dtype=torch.int)
+        self.assertEqual(result7, expected7)
+

 if __name__ == "__main__":
     run_tests()