[FSDP][2/N] Add util for computing shared param LCA #94197
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94197
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit a0a4cb8.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: d63f1ef Pull Request resolved: pytorch#94197
**Overview**

- This PR implements a utility function `get_shared_param_info_to_lca()` that returns a `Dict[SharedParamInfo, nn.Module]` mapping `SharedParamInfo` (representing a shared parameter) to its lowest common ancestor (LCA) module.
- This function can be used as a subroutine for assigning shared parameters to their LCA modules during FSDP initialization (for the composable code path in the short term).

**Details**

The implementation follows a simple version of [Tarjan's offline LCA algorithm](https://en.wikipedia.org/wiki/Tarjan%27s_off-line_lowest_common_ancestors_algorithm) that is based on a union-find data structure. We can use this algorithm because the set of LCA queries is fixed a priori (i.e. this is offline).

Each module represents a vertex in the module tree, where there is a directed edge from parent module to child module (i.e. `p` is a parent of `c` if `c` is returned from `p.children()`). The LCA module `lca` of two modules `a` and `b` is the lowest (i.e. greatest-depth) module that includes both `a` and `b` in its subtree.

For the unit test, here is a visualization of the module tree:

*(figure: module tree used in the unit test)*

[ghstack-poisoned]
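To make the union-find-based procedure concrete, here is a minimal sketch of Tarjan's offline LCA over an `nn.Module` tree. The function name `offline_module_lca` and its signature are illustrative only; this is not the PR's `get_shared_param_info_to_lca()`, and it assumes the module hierarchy forms a tree (no submodule registered under two parents).

```python
from typing import Dict, List, Optional, Tuple

import torch.nn as nn


def offline_module_lca(
    root: nn.Module, queries: List[Tuple[nn.Module, nn.Module]]
) -> List[Optional[nn.Module]]:
    """Answer a fixed batch of (module_a, module_b) -> LCA-module queries with
    Tarjan's offline LCA algorithm over the module tree rooted at ``root``.

    Illustrative sketch only (not the PR's implementation). ``answers[i]`` is the
    LCA module for ``queries[i]``, or ``None`` if an endpoint is not under ``root``.
    """
    parent: Dict[nn.Module, nn.Module] = {}    # union-find parent pointers
    ancestor: Dict[nn.Module, nn.Module] = {}  # set root -> subtree root merged so far
    visited = set()

    def find(m: nn.Module) -> nn.Module:
        # Union-find "find" with path compression.
        rep = m
        while parent[rep] is not rep:
            rep = parent[rep]
        while parent[m] is not rep:
            parent[m], m = rep, parent[m]
        return rep

    # Index queries by endpoint so each DFS visit can resolve the pairs whose
    # other endpoint has already been fully processed.
    by_endpoint: Dict[nn.Module, List[Tuple[int, nn.Module]]] = {}
    for i, (a, b) in enumerate(queries):
        by_endpoint.setdefault(a, []).append((i, b))
        by_endpoint.setdefault(b, []).append((i, a))
    answers: List[Optional[nn.Module]] = [None] * len(queries)

    def dfs(module: nn.Module) -> None:
        parent[module] = module
        ancestor[module] = module
        for child in module.children():
            dfs(child)
            parent[find(child)] = module      # union the child's set into module's set
            ancestor[find(module)] = module   # module is the subtree root of the merged set
        visited.add(module)
        for i, other in by_endpoint.get(module, []):
            if other in visited:
                # Classic Tarjan step: the LCA is the ancestor label of other's set.
                answers[i] = ancestor[find(other)]

    dfs(root)
    return answers
```

Because every query pair is known before the traversal starts, a single depth-first pass over the module tree answers all of them in near-linear time, which is why the offline variant suffices here.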
ghstack-source-id: a3e61ab Pull Request resolved: pytorch#94197
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as `Stale`.
Stack from ghstack:
- #94198 [FSDP][3/N] Add LCA logic to `fully_shard`

**Overview**
- This PR implements a utility function `get_shared_param_info_to_lca()` that returns a `Dict[SharedParamInfo, nn.Module]` mapping `SharedParamInfo` (representing a shared parameter) to its lowest common ancestor (LCA) module.

**Details**
The implementation follows a simple version of Tarjan's offline LCA algorithm that is based on a union-find data structure. We can use this algorithm because the set of LCA queries is fixed a priori (i.e. this is offline).
Each module represents a vertex in the module tree, where there is a directed edge from parent module to child module (i.e. `p` is a parent of `c` if `c` is returned from `p.children()`). The LCA module `lca` of two modules `a` and `b` is the lowest (i.e. greatest-depth) module that includes both `a` and `b` in its subtree.

For the unit test, here is a visualization of the module tree:
