[FSDP()][27/N] Add forward hook registration #88040
Conversation
Dr. CI status: ✅ No failures as of commit d930488. See artifacts and rendered test results at hud.pytorch.org/pr/88040.
rohan-varma left a comment:
couple of minor questions, thanks for adding this!
    def test_training(self):
        """Tests training (forward, backward, optimizer)."""
        device = torch.device("cuda")
        local_model = Model(device=device)
do we have support for composable FSDP + meta device? Is there a source of truth where we can find the feature set covered by composable?
> do we have support for composable FSDP + meta device?

I believe there should be, because the composable FSDP constructor includes the same module materialization logic as the normal FSDP constructor (see the sketch after this comment).

> Is there a source of truth where we can find the feature set covered by composable?

This is difficult to document right now since we are still prototyping. As I continue testing and thinking about the design, I may realize some sharp edges that prevent feature parity. I will try to stabilize soon (the same applies to use_orig_params=True).
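For concreteness, here is a minimal sketch of composable FSDP with a meta-device module. It is not from this PR: it assumes fully_shard() accepts the same param_init_fn materialization argument as the FullyShardedDataParallel constructor, that a process group is already initialized (e.g., via torchrun), and that a CUDA device is available; the Model class and _init_fn are hypothetical.

```python
# Hedged sketch: composable FSDP applied to a meta-device module.
# Assumptions (not confirmed by this PR): fully_shard() takes param_init_fn like
# the FSDP wrapper does, and torch.distributed is already initialized.
import torch
import torch.nn as nn
from torch.distributed._composable import fully_shard  # prototype API at the time of this PR


class Model(nn.Module):  # hypothetical toy model, mirroring the test's Model(device=...)
    def __init__(self, device=None):
        super().__init__()
        self.lin1 = nn.Linear(16, 16, device=device)
        self.lin2 = nn.Linear(16, 4, device=device)

    def forward(self, x):
        return self.lin2(torch.relu(self.lin1(x)))


def _init_fn(module: nn.Module) -> None:
    # Materialize meta-device storage on GPU, then reinitialize the weights.
    module.to_empty(device=torch.device("cuda"))
    for p in module.parameters():
        if p.dim() > 1:
            nn.init.kaiming_uniform_(p)
        else:
            nn.init.zeros_(p)


meta_model = Model(device=torch.device("meta"))      # no real storage allocated yet
fully_shard(meta_model, param_init_fn=_init_fn)      # materializes, then shards in place
out = meta_model(torch.randn(2, 16, device="cuda"))  # forward hooks drive unshard/reshard
```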
     elif _handles_key:
-        _assert_in_training_states(state, [TrainingState.IDLE])
+        allowed_states = [TrainingState.IDLE]
+        if _is_composable(state):
why are allowed states different in composable vs non-composable?
First, note that this is TrainingState, which is per state: _FSDPState object, not HandleTrainingState, which is per FlatParamHandle / FlatParameter. (_FSDPState is Union[_State, FullyShardedDataParallel], where _State comes from torch/distributed/_composable/contract.py.)
For composable, state represents the local FSDP root (without wrapping). Upon the first FlatParameter's pre-backward hook, the state transitions to FORWARD_BACKWARD. For any subsequent FlatParameter's pre-backward hooks, state.training_state will already be FORWARD_BACKWARD. However, each FlatParamHandle's training state transitions as you would expect (i.e., IDLE -> BACKWARD_PRE here).
This is why I had to refactor TrainingState in an earlier PR: we have to stratify the training states to accommodate the difference between state: _FSDPState and FlatParamHandle / FlatParameter.
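To make the two levels of state concrete, here is a self-contained toy sketch. It is not the FSDP source; the class names and transitions are simplified from the explanation above. It shows why a composable state object can already be in FORWARD_BACKWARD when a later handle's pre-backward hook fires, while each handle itself still transitions IDLE -> BACKWARD_PRE.

```python
# Toy illustration (not FSDP internals) of per-state vs. per-handle training states.
from enum import Enum, auto


class TrainingState(Enum):        # per _FSDPState (one per local FSDP root for composable)
    IDLE = auto()
    FORWARD_BACKWARD = auto()


class HandleTrainingState(Enum):  # per FlatParamHandle / FlatParameter
    IDLE = auto()
    BACKWARD_PRE = auto()


class ToyState:
    def __init__(self, num_handles: int):
        self.training_state = TrainingState.IDLE
        self.handle_states = [HandleTrainingState.IDLE] * num_handles


def pre_backward_hook(state: ToyState, handle_idx: int, composable: bool) -> None:
    # For composable FSDP, a later handle's hook may fire while the shared state
    # is already in FORWARD_BACKWARD, so that state must also be allowed.
    allowed = {TrainingState.IDLE}
    if composable:
        allowed.add(TrainingState.FORWARD_BACKWARD)
    assert state.training_state in allowed, state.training_state
    state.training_state = TrainingState.FORWARD_BACKWARD
    # The handle-level transition matches the usual expectation.
    assert state.handle_states[handle_idx] is HandleTrainingState.IDLE
    state.handle_states[handle_idx] = HandleTrainingState.BACKWARD_PRE


state = ToyState(num_handles=2)
pre_backward_hook(state, 0, composable=True)  # state: IDLE -> FORWARD_BACKWARD
pre_backward_hook(state, 1, composable=True)  # state already FORWARD_BACKWARD; still allowed
```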
This PR adds the forward hook registration to composable FSDP and adds a unit test for the runtime.
Pull Request resolved: pytorch#88040
Approved by: https://github.com/zhaojuanmao, https://github.com/rohan-varma
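Since the PR description is terse, here is a hedged sketch of the general idea behind "forward hook registration" (this is not the FSDP implementation, and the hook bodies are placeholders): because composable FSDP does not wrap the module, it presumably cannot override nn.Module.forward the way the FullyShardedDataParallel wrapper does, so the unshard/reshard logic is attached through the standard nn.Module hook APIs.

```python
# Sketch only: shows the nn.Module hook APIs that such a registration would use.
import torch
import torch.nn as nn


def _pre_forward(module: nn.Module, args):
    # Placeholder for: unshard parameters, wait on the all-gather stream, etc.
    print(f"pre-forward on {type(module).__name__}")
    return args  # returning the same tuple leaves the inputs unchanged


def _post_forward(module: nn.Module, args, output):
    # Placeholder for: reshard parameters, register pre-backward hooks on outputs.
    print(f"post-forward on {type(module).__name__}")
    return output


model = nn.Linear(4, 4)
pre_handle = model.register_forward_pre_hook(_pre_forward)
post_handle = model.register_forward_hook(_post_forward)
_ = model(torch.randn(2, 4))  # triggers both hooks around forward()
pre_handle.remove()
post_handle.remove()
```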
Stack from ghstack:
- #88260 [FSDP()][Easy] Make fully_shard() only FULL_SHARD
- #88235 [FSDP()] Have fully_shard() abide by @contract!
- #88234 [FSDP()][Easy] Rename _State to _FSDPState
- #88233 [FSDP()] Rename to fully_shard() and move to _composable/
- #88232 [FSDP][Easy] Remove unneeded TrainingState transition
- #88123 [FSDP] Rename unflat_param_name -> fqn for consistency
- #88122 [FSDP] Simplify _get_buffer_names()
- #88121 [FSDP] Remove unneeded torch.no_grad() context when offloading to CPU
- #87941 [FSDP()][26/N] Move _lazy_init() into _fsdp_root_pre_forward()
- #87940 [FSDP()][25/N] Add _post_forward_reshard()
- #87939 [FSDP()][24/N] Refactor _lazy_init()
- #87935 [FSDP()][21/N] Refactor and fix _cast_buffers()
- #87934 [FSDP] Rename dtype to buffer_name_to_dtype
- #87933 [FSDP] Remove device arg from _cast_buffers()
- #87931 [FSDP()][18/N] Refactor pre_forward_unshard()
- #87930 [FSDP()][17/N] Refactor _fsdp_root_pre_forward()
- #87928 [FSDP()][15/N] Refactor _init_streams()