@@ -440,7 +440,6 @@ def test_device_mesh_parent_child_hash(self):
440440 ep_mesh = ep_mesh_1 if self .rank < self .world_size // 2 else ep_mesh_2
441441 # ep_mesh is considered different from mesh_2d["TP"]
442442 self .assertEqual (mesh_2d ["TP" ]._flatten_mesh_list , ep_mesh ._flatten_mesh_list )
443- self .assertEqual (mesh_2d ["TP" ]._layout , ep_mesh ._layout )
444443 self .assertEqual (mesh_2d ["TP" ].mesh .shape , ep_mesh .mesh .shape )
445444 self .assertEqual (mesh_2d ["TP" ].device_type , ep_mesh .device_type )
446445 self .assertNotEqual (mesh_2d ["TP" ].mesh_dim_names , ep_mesh .mesh_dim_names )
@@ -455,7 +454,6 @@ def test_device_mesh_parent_child_hash(self):
455454 )
456455 # another_mesh is considered the same as ep_mesh
457456 self .assertEqual (ep_mesh ._flatten_mesh_list , another_mesh ._flatten_mesh_list )
458- self .assertEqual (ep_mesh ._layout , another_mesh ._layout )
459457 self .assertEqual (ep_mesh .mesh .shape , another_mesh .mesh .shape )
460458 self .assertEqual (ep_mesh .device_type , another_mesh .device_type )
461459 self .assertEqual (ep_mesh .mesh_dim_names , another_mesh .mesh_dim_names )
@@ -541,6 +539,7 @@ def test_from_group_with_mesh_shape_2d(self):
541539 mesh_dim_names = ("dp_replicate" , "dp_shard" ),
542540 )
543541
542+ # self.assertEqual(ref_mesh._dim_group_names, dp_mesh._dim_group_names)
544543 for mesh_dim_group , ref_mesh_dim_group in zip (
545544 dp_mesh .get_all_groups (), ref_mesh .get_all_groups ()
546545 ):
@@ -801,10 +800,6 @@ def test_get_item_3d(self):
801800 # Test slicing out 1D mesh from a sub-2D mesh.
802801 shard_mesh = hsdp_mesh_2 ["Shard" ]
803802 self .assertEqual (shard_mesh .mesh .tolist (), shard_group [shard_group_idx ])
804- replicate_mesh = hsdp_mesh_2 ["Replicate" ]
805- self .assertEqual (
806- replicate_mesh .mesh .tolist (), replicate_group [replicate_group_idx ]
807- )
808803
809804 @with_comms
810805 def test_cache_and_reuse_submesh_slice_result (self ):
@@ -878,17 +873,12 @@ def test_flatten_mesh_3d(self):
878873 flattened_dp_cp_mesh = dp_cp_mesh ._flatten ()
879874 self .assertEqual (dp_cp_mesh .mesh .flatten (), flattened_dp_cp_mesh .mesh )
880875 self .assertEqual (flattened_dp_cp_mesh .mesh_dim_names [0 ], "dp_cp" )
881- self .assertEqual (flattened_dp_cp_mesh .get_group ().group_desc , "mesh_dp_cp" )
882876 root_mesh = _mesh_resources .get_root_mesh (dp_cp_mesh )
883877 self .assertEqual (root_mesh , mesh_3d )
884- flatten_mesh_layout = _mesh_resources .root_to_flatten_mapping [root_mesh ][
878+ flatten_mesh_root_dims = _mesh_resources .flatten_name_to_root_dims [root_mesh ][
885879 "dp_cp"
886- ]._layout
887- self .assertEqual (flatten_mesh_layout , flattened_dp_cp_mesh ._layout )
888- self .assertEqual (
889- flattened_dp_cp_mesh ._layout .global_ranks (8 ),
890- [[0 , 2 , 4 , 6 ], [1 , 3 , 5 , 7 ]],
891- )
880+ ]
881+ self .assertEqual (flatten_mesh_root_dims , (0 , 1 ))
892882
893883 ref_pg_count = _world .group_count
894884 # Calling flatten again should not create a new pg.
@@ -903,19 +893,10 @@ def test_flatten_mesh_3d(self):
903893 self .assertEqual (flattened_dp_tp_mesh .mesh_dim_names [0 ], "dp_tp" )
904894 root_mesh = _mesh_resources .get_root_mesh (dp_tp_mesh )
905895 self .assertEqual (root_mesh , mesh_3d )
906- flatten_mesh_root_layout = _mesh_resources .root_to_flatten_mapping [root_mesh ][
896+ flatten_mesh_root_dims = _mesh_resources .flatten_name_to_root_dims [root_mesh ][
907897 "dp_tp"
908- ]._layout
909- self .assertEqual (flatten_mesh_root_layout , flattened_dp_tp_mesh ._layout )
910- self .assertEqual (
911- flattened_dp_tp_mesh ._layout .global_ranks (8 ),
912- [[0 , 1 , 4 , 5 ], [2 , 3 , 6 , 7 ]],
913- )
914- with self .assertRaisesRegex (
915- NotImplementedError ,
916- "Currently, this only allows slicing out a contiguous flattened dim" ,
917- ):
918- mesh_3d ["dp_tp" , "cp" ]
898+ ]
899+ self .assertEqual (flatten_mesh_root_dims , (0 , 2 ))
919900
920901 # Test flatten with a flattened mesh_dim_name
921902 cp_tp_mesh = mesh_3d ["cp" , "tp" ]
@@ -1556,50 +1537,6 @@ def test_check_non_overlap(self):
15561537 layout8 = _Layout ((3 , 2 ), (2 , 3 ))
15571538 self .assertTrue (layout8 .check_non_overlap ())
15581539
1559- def test_remap_to_tensor (self ):
1560- """Test the remap_to_tensor method for various scenarios."""
1561- # Test 1: Consecutive ranks, full world - should return logical groups directly
1562- original_mesh = torch .tensor ([[0 , 1 ], [2 , 3 ]], dtype = torch .int )
1563- layout1 = _Layout ((2 , 2 ), (2 , 1 )) # row-major 2x2
1564- result1 = layout1 .remap_to_tensor (original_mesh )
1565- expected1 = torch .tensor ([[[0 , 1 ], [2 , 3 ]]], dtype = torch .int )
1566- self .assertEqual (result1 , expected1 )
1567-
1568- # Test 2: Non-consecutive ranks - should map to actual ranks
1569- original_mesh = torch .tensor ([[10 , 20 ], [30 , 40 ]], dtype = torch .int )
1570- layout2 = _Layout ((2 , 2 ), (2 , 1 ))
1571- result2 = layout2 .remap_to_tensor (original_mesh )
1572- expected2 = torch .tensor ([[[10 , 20 ], [30 , 40 ]]], dtype = torch .int )
1573- self .assertEqual (result2 , expected2 )
1574-
1575- # Test 4: 1D layout with consecutive ranks
1576- original_mesh = torch .tensor ([0 , 1 , 2 , 3 ], dtype = torch .int )
1577- layout4 = _Layout ((4 ,), (1 ,))
1578- result4 = layout4 .remap_to_tensor (original_mesh )
1579- expected4 = torch .tensor ([[0 , 1 , 2 , 3 ]], dtype = torch .int )
1580- self .assertEqual (result4 , expected4 )
1581-
1582- # Test 5: Complex strided layout with non-consecutive ranks
1583- original_mesh = torch .tensor ([5 , 10 , 15 , 20 ], dtype = torch .int )
1584- layout5 = _Layout ((2 , 2 ), (2 , 1 ))
1585- result5 = layout5 .remap_to_tensor (original_mesh )
1586- expected5 = torch .tensor ([[[5 , 10 ], [15 , 20 ]]], dtype = torch .int )
1587- self .assertEqual (result5 , expected5 )
1588-
1589- # Test 6: Tensor Cute representation of a 2D mesh
1590- original_mesh = torch .tensor ([[0 , 2 ], [1 , 3 ]], dtype = torch .int )
1591- layout6 = _Layout ((2 , 2 ), (1 , 2 )) # column-major style
1592- result6 = layout6 .remap_to_tensor (original_mesh )
1593- expected6 = torch .tensor ([[[0 , 1 ], [2 , 3 ]]], dtype = torch .int )
1594- self .assertEqual (result6 , expected6 )
1595-
1596- # Test 7: Layout with different stride pattern
1597- original_mesh = torch .tensor ([0 , 2 , 1 , 4 ], dtype = torch .int )
1598- layout7 = _Layout ((2 , 2 ), (1 , 2 )) # column-major style
1599- result7 = layout7 .remap_to_tensor (original_mesh )
1600- expected7 = torch .tensor ([[[0 , 1 ], [2 , 4 ]]], dtype = torch .int )
1601- self .assertEqual (result7 , expected7 )
1602-
16031540
16041541if __name__ == "__main__" :
16051542 run_tests ()
0 commit comments