Conversation

@fduwjj
Contributor

@fduwjj fduwjj commented Sep 19, 2025

Stack from ghstack (oldest at bottom):

Today FSDP needs to slice the SPMD mesh out of the root mesh here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L301. But essentially, what users want is to concatenate several submeshes into one bigger mesh and use it as the SPMD mesh. This PR tentatively implements such an API for users.

One thing to note is that all submeshes need to be sliced, flattened, or unflattened from the same root mesh; otherwise the indices make no sense when it comes to mesh indexing and device allocation.
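As a rough illustration, here is a minimal usage sketch. The `concatenate` call is the API this PR proposes, so its name and signature below are assumptions rather than the final interface; the mesh setup uses the existing `init_device_mesh` and mesh-slicing APIs and assumes an initialized default process group over 8 ranks.

```python
from torch.distributed.device_mesh import init_device_mesh

# A 3D root mesh over 8 ranks: (pp=2, dp=2, tp=2).
root_mesh = init_device_mesh("cuda", (2, 2, 2), mesh_dim_names=("pp", "dp", "tp"))

# Today: FSDP slices the SPMD portion back out of the root mesh.
spmd_mesh = root_mesh["dp", "tp"]

# Proposed: concatenate submeshes derived from the SAME root mesh into
# one bigger SPMD mesh, without going through the root again.
# (`concatenate` is hypothetical here; treat this call as pseudocode.)
# spmd_mesh = concatenate([root_mesh["dp"], root_mesh["tp"]])
```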

cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci

Differential Revision: D85409698

@pytorch-bot

pytorch-bot bot commented Sep 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163358

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit cc835fc with merge base 000f495:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

fduwjj added a commit that referenced this pull request Sep 20, 2025
get_world_size(),
)

for mesh_nd in pg_ranks_by_dim:
Contributor

?!?! Why do you need to do it for every mesh_nd? Is this because you're triggering comms to initialize PGs?

Contributor Author

So, long story short: we need all ranks to call new_group (which is hidden very deep in the stack) to initialize the PGs; otherwise the code will hang.
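For context on the constraint described above, here is a small illustrative sketch (the helper name is made up, not code from this PR): `dist.new_group` is collective over the default group, so every rank must execute the same sequence of calls, including calls for groups it will not be a member of.

```python
import torch.distributed as dist

def make_subgroups(world_size: int, group_size: int):
    """Create world_size // group_size disjoint process groups."""
    groups = []
    for start in range(0, world_size, group_size):
        ranks = list(range(start, start + group_size))
        # Every rank executes this call, even ranks not in `ranks`.
        # If some ranks skip it, the others block inside new_group
        # and the job hangs.
        groups.append(dist.new_group(ranks=ranks))
    return groups
```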

fduwjj added a commit that referenced this pull request Sep 24, 2025
fduwjj added a commit that referenced this pull request Sep 25, 2025
fduwjj added a commit that referenced this pull request Sep 27, 2025
fduwjj added a commit that referenced this pull request Sep 30, 2025
fduwjj added a commit that referenced this pull request Oct 1, 2025
fduwjj added a commit that referenced this pull request Oct 23, 2025
@fduwjj
Contributor Author

fduwjj commented Oct 23, 2025

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 2 checks: trunk / linux-jammy-py3-clang12-executorch / test (executorch, 1, 1, lf.linux.2xlarge, unstable), trunk / linux-jammy-rocm-py3.10 / test (default, 2, 2, linux.rocm.gpu.gfx942.1)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@clee2000
Contributor

@pytorchbot revert -m "probably need to revert this one too, its stacked with #166003 (comment)" -c ghfirst

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Oct 24, 2025
…esh and SPMD use case (#163358)"

This reverts commit 5a4997d.

Reverted #163358 on behalf of https://github.com/clee2000 due to "probably need to revert this one too, its stacked with #166003 (comment)"
@pytorchmergebot
Collaborator

@fduwjj your PR has been successfully reverted.

@pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Oct 24, 2025
fduwjj added a commit that referenced this pull request Oct 27, 2025
@fduwjj
Contributor Author

fduwjj commented Oct 27, 2025

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: Check the merge workflow status here.

@fduwjj
Contributor Author

fduwjj commented Oct 27, 2025

@fduwjj has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

1 similar comment

pytorchmergebot pushed a commit that referenced this pull request Oct 27, 2025
…from root mesh (#165492)

With the concatenate API, we can directly combine two meshes rather than getting the SPMD mesh from the root.

Differential Revision: [D85409698](https://our.internmc.facebook.com/intern/diff/D85409698)
Pull Request resolved: #165492
Approved by: https://github.com/fegin
ghstack dependencies: #163358
tianrengao pushed a commit that referenced this pull request Oct 30, 2025
…SPMD use case (#163358)

Pull Request resolved: #163358
Approved by: https://github.com/fegin
tianrengao pushed a commit that referenced this pull request Oct 30, 2025
@github-actions github-actions bot deleted the gh/fduwjj/206/head branch November 28, 2025 02:18

Labels

ci-no-td (Do not run TD on this PR)
ciflow/trunk (Trigger trunk jobs on your pull request)
Merged
oncall: distributed (Add this issue/PR to distributed oncall triage queue)
release notes: DeviceMesh
Reverted
