[Device Mesh] Add an option to decouple PGs when it comes to device mesh save #167590
base: gh/fduwjj/236/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167590
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 61f5996 with merge base a5f3035.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
fegin left a comment
Let's add a unit test to demonstrate torch.save/torch.load when decouple_backend_at_save is True.
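A minimal sketch of such a test, assuming the usual multi-process DeviceMesh test harness (initialized process group, world size 4, self.device_type available) and that the flag is read off the mesh instance as decouple_backend_at_save, matching the diff quoted below; the exact spelling and placement may differ in the final PR:

    import io

    import torch
    from torch.distributed.device_mesh import init_device_mesh

    def test_torch_save_load_decoupled(self):
        mesh = init_device_mesh(self.device_type, (2, 2), mesh_dim_names=("dp", "tp"))
        mesh.decouple_backend_at_save = True  # illustrative flag placement

        buf = io.BytesIO()
        torch.save(mesh, buf)  # should pickle the mesh without PG names
        buf.seek(0)
        loaded = torch.load(buf, weights_only=False)

        # The mesh layout survives the round trip ...
        self.assertEqual(loaded.mesh_dim_names, ("dp", "tp"))
        # ... but no PG names come back, so users must recreate/attach PGs
        # before running collectives on the loaded mesh.
        self.assertFalse(hasattr(loaded, "_dim_group_names"))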
    if state.get("dim_group_names"):
        self._dim_group_names = state["dim_group_names"]
This is the key step. We should add a comment here. How do users attach the PGs after torch.load?
I'm also unclear on this. If we landed this PR as-is, how is someone supposed to use it?
Maybe we should separate this into 2 PRs:
- just warn in getstate to discourage people from using it, explain the risk, and point to the DCP docs
- the rest of this PR, but with a better usage example and an explanation of how to 'bind' the loaded DTensor to a new mesh.
I am not sure how high priority (2) is, but if we do it, we should do it right and have good docs + UX for it.
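For reference, a hedged sketch of what that 'bind' flow might look like under the PR as written; _dim_group_names comes from the quoted diff above, while the file name, backend, and mesh layout are illustrative, and none of this is a supported public API:

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    dist.init_process_group("nccl")

    # Recreate the PGs in exactly the same order as the saving run, so the
    # auto-assigned group names match the names pickled with the mesh.
    fresh_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

    loaded_mesh = torch.load("mesh.pt", weights_only=False)

    # Hypothetical manual bind: reuse the freshly created groups' names,
    # mirroring what __setstate__ restores from the checkpoint today.
    loaded_mesh._dim_group_names = fresh_mesh._dim_group_names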
    if not self.decouple_backend_at_save and hasattr(self, "_dim_group_names"):
        logger.warning(
            "Save device mesh via torch.save with pg names and will be deprecated in PT 2.11. "
            "Users are welcome to use Distributed checkpoint (DCP) or re-create pgs in the same order"
        )
This warning message is probably not detailed enough to be helpful.
- "welcome to use DCP": to be clear, this suggestion asks the user to rewrite their flow, so we should at least point to a doc or tutorial
- "re-create pgs in the same order": this suggestion is not detailed enough to be actionable, IMO.
How about this?
"Starting in PyTorch 2.11, torch.save will save a DeviceMesh without including ProcessGroup information, and loading a saved DeviceMesh will require manually recreating the same configuration of ProcessGroups and binding them to the loaded DeviceMesh. See <link to example> for more information on how to do this. Alternatively, use DCP <link to tutorial> to save and load DTensors in a format that supports resharding and can be loaded on a different mesh configuration."
Stack from ghstack (oldest at bottom):
The rationale behind this PR is that we want a module-level flag which decouples PG info (names) from torch.save and torch.load for DeviceMesh (and DTensor). We want users to explicitly create PGs (or a DeviceMesh) when doing torch.load instead of reusing the saved PG names, because if users don't create the PGs, or create them in the wrong order, the loaded device mesh will not work.
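As a hedged illustration of that failure mode (the layout, backend, and file name here are made up):

    import torch
    from torch.distributed.device_mesh import init_device_mesh

    # Run A: creating the mesh registers its PGs with auto-assigned names in
    # creation order, and torch.save currently pickles those names with the mesh.
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    torch.save(mesh, "mesh.pt")

    # Run B (a separate job): if the PGs are not recreated, or are created in a
    # different order, the saved names resolve to the wrong groups (or to none),
    # and collectives on the loaded mesh misbehave.
    loaded = torch.load("mesh.pt", weights_only=False)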
Also, we know that directly changing this behavior is BC-breaking, so we add a flag and warning messages for it; we will clean it up later on.
cc @H-Huang @awgu @wanchaol @fegin @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci