[FSDP][2/N] `_summon_full_params` -> `_unshard_params` (#92297)
Conversation
✅ No failures as of commit c8420cc (Dr. CI). Artifacts and rendered test results: hud.pytorch.org/pr/92297
ghstack-source-id: 998f0f5 Pull Request resolved: pytorch#92297
**Overview**
This PR stack will add support for unsharding FSDP's sharded parameters for `fully_shard`. This PR takes the first step by doing some internal refactoring.
- The existing API for wrapper FSDP is the static method `summon_full_params()`, which calls into the helper `_summon_full_params()`.
- This PR refactors:
- `summon_full_params()` core logic to `_unshard_params()`
- `_summon_full_params()` to `_unshard_params_recurse()`, which has a `recurse: bool` argument
- Previous `_unshard_params()` to `_unshard_fsdp_state_params()`, which applies to a single FSDP state
**Details**
- This PR introduces `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`, which additionally return the modules along with the FSDP states. The modules are needed for handling `FlatParameter` registration.
- We may be able to remove this if we clean up the `use_orig_params=True` vs. `False` code paths because for `True`, the `FlatParameter` is not registered, meaning that it does not need to be de-registered.
- Since `fully_shard` requires `use_orig_params=True`, we may not need `_get_fsdp_states_with_modules()` and `_get_root_fsdp_states_with_modules()`; however, for now I prefer to keep the separation of FSDP state and module explicit for clarity.
**Follow-Ups**
- Passing `writeback=True` together with `rank0_only=True` raises an error. The previous explanation was:
> is not supported, as model parameter shapes will be different across ranks, and writing to them can lead to inconsistencies across ranks when the context is exited.
I am not exactly sure what the differing model parameter shapes refer to. However, I believe that we can support `writeback=True` with `rank0_only=True` by broadcasting the `FlatParameter` from rank 0 in the `finally`, writing back, and freeing. This should not increase peak memory: rank 0 already holds the unsharded `FlatParameter` in GPU memory before writing back, and nonzero ranks do not hold any other unsharded `FlatParameter`s in GPU memory.
rohan-varma left a comment:
Shall we add unit tests for the `summon_full_params` composable path?
```python
    "to them can lead to inconsistencies across ranks when the "
    "context is exited."
)
# TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to
```
Could we file an issue for this? Would it work for `use_orig_params=True` as well?
I think it should work for both `use_orig_params=True` and `False`. I will file an issue.
```python
if recurse:
    with contextlib.ExitStack() as stack:
        # TODO (awgu): The traversal function does not traverse through
        # incompatible composable APIs. Verify if this is the desired
```
Could you elaborate? What's an example of this?
```python
fully_shard(
    Module(
        replicate(
            Submodule(
                fully_shard(Subsubmodule),
                Subsubmodule,
            ),
            Submodule,
        ),
    ),
)
```
Because the traversal utils do not go through incompatible composable APIs (here, `replicate`), calling `_unshard_params` on the root `Module` will not unshard the parameters of the fully sharded `Subsubmodule`.
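This skip behavior can be illustrated with a toy tree mirroring that nesting. The dict-based modules, the `"api"` tags, and the function name are made up for the example; this is not FSDP's actual traversal utility:

```python
def collect_fsdp_states(module, states):
    # Do not descend below a module wrapped with an incompatible
    # composable API (here, "replicate"): its subtree is skipped.
    api = module.get("api")
    if api == "replicate":
        return states
    if api == "fully_shard":
        states.append(module["name"])
    for child in module.get("children", []):
        collect_fsdp_states(child, states)
    return states
```

Running this on the tree above finds only the root's FSDP state; the fully sharded `Subsubmodule` under `replicate` is never reached.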
Yes, this has not been added yet. (I have a local [4/N] commit that adds a frontend for that path, but I did not open a PR for it since we have not finalized what the API should look like.) I will add tests when we include that.
Stack from ghstack:
- #92298 [FSDP][3/N] Refactor `summon_full_params` unit tests
- #92297 [FSDP][2/N] `_summon_full_params` -> `_unshard_params` (this PR)