[DeviceMesh][ez] Extract the pg creation as a util function #163930
fduwjj wants to merge 2 commits into gh/fduwjj/210/base
Dr. CI: ✅ No failures as of commit bccf9e9 with merge base 5fcde74. Artifacts and rendered test results: hud.pytorch.org/pr/163930
This PR extracts the common process group (pg) creation logic into a util function, because the following stack of DeviceMesh refactoring PRs will use it many times.
cc H-Huang awgu wanchaol fegin wz337 wconstab d4l3k pragupta msaroufim dcci
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
While refactoring the DeviceMesh bookkeeping to leverage CuTe layout, we found that we need two more util functions. One checks whether a layout overlaps with itself. For example, (2,2):(2,1) has no overlap while (2,2):(2,2) has overlap.
Pull Request resolved: #163367
Approved by: https://github.com/fegin
ghstack dependencies: #163212, #163288, #163928, #163930
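The overlap check described in this commit message can be illustrated with a minimal sketch. It is not the util function added in #163367; the name `layout_has_overlap` is hypothetical, and it assumes a flat (non-nested) layout given as a shape tuple and a stride tuple. It enumerates every coordinate and reports whether two coordinates map to the same offset.

```python
from itertools import product
from typing import Sequence


def layout_has_overlap(shape: Sequence[int], stride: Sequence[int]) -> bool:
    """Return True if two distinct coordinates map to the same offset.

    Hypothetical helper for illustration only; assumes a flat layout
    written as shape:stride, e.g. (2,2):(2,1).
    """
    if len(shape) != len(stride):
        raise ValueError("shape and stride must have the same rank")
    seen = set()
    # Visit every coordinate in the layout's domain.
    for coord in product(*(range(extent) for extent in shape)):
        # Offset is the dot product of the coordinate and the stride.
        offset = sum(c * s for c, s in zip(coord, stride))
        if offset in seen:
            return True
        seen.add(offset)
    return False


print(layout_has_overlap((2, 2), (2, 1)))  # False: offsets are 0, 1, 2, 3
print(layout_has_overlap((2, 2), (2, 2)))  # True: offset 2 is hit twice
```

Brute-force enumeration is only practical for small layouts, but it makes the definition concrete: (2,2):(2,2) overlaps because coordinates (0,1) and (1,0) both land on offset 2.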
This is just to extract common logic into a util function because we will use it many times for the following stack of Device Mesh refactoring.
Pull Request resolved: #163930
Approved by: https://github.com/fegin
ghstack dependencies: #163212, #163288, #163928
    mesh = DeviceMesh(
        device_type,
        mesh_nd,
        mesh_dim_names=mesh_dim_names,
        backend_override=backend_override,
        _init_backend=_init_backend,
    )
    if cur_rank in mesh_nd:
        res_mesh = mesh
Thanks for splitting this out as I had asked! I didn't get to review this before it was landed, but there's still something I don't understand. I get that we need to call the "PG creation API" multiple times on each rank, even for the ranks that don't participate, but I don't get why we need to call the DeviceMesh constructor multiple times!
Could we instead call the PG creation API directly and just invoke the DeviceMesh constructor once with the right PG?
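The control flow suggested in this comment could look roughly like the sketch below. It is plain Python with stand-ins and is not PyTorch code: `create_pg` and `make_mesh` are hypothetical callables passed in by the caller, standing for the PG creation API and the DeviceMesh constructor. Every rank calls `create_pg` for every group, while the mesh is constructed only once, for the group that contains the current rank.

```python
from typing import Callable, List, Optional, Sequence, Tuple


def build_mesh_once(
    cur_rank: int,
    rank_groups: Sequence[Sequence[int]],
    create_pg: Callable[[Sequence[int]], object],  # hypothetical stand-in for PG creation
    make_mesh: Callable[[Sequence[int], object], object],  # hypothetical stand-in for the mesh constructor
) -> Optional[object]:
    """Create a PG for every group, but build a mesh only for this rank's group."""
    mine: Optional[Tuple[Sequence[int], object]] = None
    for ranks in rank_groups:
        # All ranks take part in creating every group, even groups
        # they do not belong to.
        pg = create_pg(ranks)
        if cur_rank in ranks:
            mine = (ranks, pg)
    if mine is None:
        return None
    # The constructor stand-in runs exactly once, with the right PG.
    return make_mesh(*mine)


# Usage with stubs that only record what was called.
pg_calls: List[Tuple[int, ...]] = []
mesh_calls: List[Tuple[int, ...]] = []


def fake_create_pg(ranks: Sequence[int]) -> str:
    pg_calls.append(tuple(ranks))
    return "pg-" + "-".join(str(r) for r in ranks)


def fake_make_mesh(ranks: Sequence[int], pg: object) -> Tuple[Tuple[int, ...], object]:
    mesh_calls.append(tuple(ranks))
    return (tuple(ranks), pg)


result = build_mesh_once(2, [[0, 1], [2, 3]], fake_create_pg, fake_make_mesh)
print(result)      # ((2, 3), 'pg-2-3')
print(pg_calls)    # [(0, 1), (2, 3)]
print(mesh_calls)  # [(2, 3)]
```

Whether the real DeviceMesh constructor can accept an already-created PG in this way is exactly the open question in the comment; the sketch only shows the intended call pattern.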