Commit b558c98

RohitRathore1 authored and pytorchmergebot committed
Add regression test for get_root_mesh with multiple independent meshes (#164731)
Fixes #163330

I tried to reproduce the bug with my 4-GPU setup (the original issue used 8 GPUs). I created several different test scenarios, trying to trigger the bug by:
- creating two different device meshes
- slicing them in various ways
- checking if `get_root_mesh()` would get confused

but the bug didn't show up! Everything worked correctly in `2.10`.

I found that there was a massive refactoring of the `DeviceMesh` code (PR #163213) that landed on October 2nd. That PR completely rewrote how `DeviceMesh` tracks relationships between parent meshes and submeshes. It seems like this refactoring fixed the bug! But I added a regression test to make sure it doesn't come back.

The test (`test_get_root_mesh_multiple_independent_meshes`) does exactly what the bug report described:
- creates two independent meshes
- slices them both
- verifies that each submesh correctly points back to its real parent
- makes sure submeshes from mesh1 don't incorrectly claim mesh2 as their parent

Pull Request resolved: #164731
Approved by: https://github.com/fduwjj
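The property the regression test checks can be modeled in plain Python without GPUs or a process group. The sketch below is only an illustration: `ToyMesh` and `get_root` are hypothetical names, not PyTorch's `DeviceMesh` or `_mesh_resources.get_root_mesh`, and it does not claim to reproduce the cause of #163330. It assumes that each slice keeps a reference to the mesh it was cut from and that meshes compare by identity, so two independent meshes with the same shape never share a root.

```python
from dataclasses import dataclass
from typing import Optional


# eq=False keeps identity-based equality and hashing, so two meshes
# with the same shape and dim names are still distinct objects.
@dataclass(eq=False)
class ToyMesh:
    """Hypothetical stand-in for a device mesh; not PyTorch's DeviceMesh."""

    shape: tuple
    dim_names: tuple
    root: Optional["ToyMesh"] = None  # None means this mesh is its own root

    def __getitem__(self, name: str) -> "ToyMesh":
        # Slice out one named dimension and remember the root it came from.
        idx = self.dim_names.index(name)
        return ToyMesh((self.shape[idx],), (name,), root=get_root(self))


def get_root(mesh: ToyMesh) -> ToyMesh:
    """Return the mesh a submesh was sliced from, or the mesh itself."""
    return mesh if mesh.root is None else mesh.root


if __name__ == "__main__":
    mesh1 = ToyMesh((2, 2), ("dp", "tp"))
    mesh2 = ToyMesh((2, 2), ("dim1", "dim2"))

    # Each submesh points back to its own parent ...
    assert get_root(mesh1["dp"]) is mesh1
    assert get_root(mesh2["dim1"]) is mesh2
    # ... and never to the other, identically shaped mesh.
    assert get_root(mesh1["tp"]) is not mesh2
```

The checks under `__main__` mirror the structure of the real test: positive assertions for each submesh, followed by negative assertions across the two meshes.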
1 parent 415e641 commit b558c98

1 file changed: +29 -0 lines changed
test/distributed/test_device_mesh.py

Lines changed: 29 additions & 0 deletions
@@ -350,6 +350,35 @@ def test_set_mesh_dim_group_options(self):
         # Fake pg only have BackendType as BackendType::CUSTOM.
         self.assertEqual(mesh.get_group(1)._get_backend_name(), "custom")
 
+    @with_comms
+    def test_get_root_mesh_multiple_independent_meshes(self):
+        # regression test for issue #163330
+        # when creating multiple independent device meshes and slicing them,
+        # get_root_mesh should return the correct parent mesh for each submesh
+        mesh1 = init_device_mesh(
+            self.device_type,
+            (2, 2),
+            mesh_dim_names=("dp", "tp"),
+        )
+        mesh1_dp = mesh1["dp"]
+        mesh1_tp = mesh1["tp"]
+
+        mesh2 = init_device_mesh(
+            self.device_type,
+            (2, 2),
+            mesh_dim_names=("dim1", "dim2"),
+        )
+        mesh2_dim1 = mesh2["dim1"]
+        mesh2_dim2 = mesh2["dim2"]
+
+        self.assertEqual(_mesh_resources.get_root_mesh(mesh1_dp), mesh1)
+        self.assertEqual(_mesh_resources.get_root_mesh(mesh1_tp), mesh1)
+        self.assertEqual(_mesh_resources.get_root_mesh(mesh2_dim1), mesh2)
+        self.assertEqual(_mesh_resources.get_root_mesh(mesh2_dim2), mesh2)
+
+        self.assertNotEqual(_mesh_resources.get_root_mesh(mesh1_dp), mesh2)
+        self.assertNotEqual(_mesh_resources.get_root_mesh(mesh1_tp), mesh2)
+
 
 class DeviceMeshTestNDim(DTensorTestBase):
     @property
