[FSDP()][2/N] Refactor training state #87916
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87916
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 7decb7b.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
from enum import auto, Enum


class TrainingState(Enum):
Since this is in a private file, we do not need to make this private. I should do this more consistently (e.g. some later PRs in this stack will violate this), but I will leave that as BE follow-ups.
# Import the entire FSDP file to avoid circular imports
- import torch.distributed.fsdp.fully_sharded_data_parallel as FSDP
+ import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file
Same as the previous PR: Rename FSDP to fsdp_file to avoid confusion since we sometimes import FullyShardedDataParallel as FSDP.
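For illustration, a minimal sketch of the two import styles being contrasted: the module import mirrors the diff above, while the commented-out class import is the conventional user-facing `FSDP` alias mentioned in the comment.

```python
# Importing the module (rather than the class) defers attribute lookups to the
# point of use, which is what avoids the circular import noted in the diff.
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file

# The class import is conventionally aliased as `FSDP`, which is why reusing
# `FSDP` as the module alias was confusing:
# from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
```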
This PR actually has meaningful changes. We stratify `TrainingState` into two levels: one is per FSDP instance and one is per `FlatParamHandle`/`FlatParameter`.
- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), or `SUMMON_FULL_PARAMS`. These dynamically modify behavior (e.g. `summon_full_params()` forces full precision).
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.
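As a concrete reference for the split described above, here is a minimal sketch of the two enums; the member lists follow the description, but the exact class layout and file placement in the PR may differ.

```python
from enum import Enum, auto


class TrainingState(Enum):
    """Per-FSDP-instance state, used to dynamically modify behavior
    (e.g. summon_full_params() forces full precision)."""
    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    """Per-FlatParamHandle/FlatParameter state, kept for invariant
    checks and debugging."""
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()
```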
mrshenli left a comment:
naming-only changes. LGTM
@pytorchbot merge -g
Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
ghstack-source-id: e01809e Pull Request resolved: pytorch#87916
@@ -0,0 +1,23 @@
from enum import auto, Enum
I am maintaining this _common_utils.py as I refactor. Eventually, we will merge _utils.py into _common_utils.py or other files.
The merge job was canceled. If you believe this is a mistake, then you can re-trigger it through pytorch-bot.
@pytorchbot merge -g
Merge started. Your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hey @awgu. |
- self._assert_state(
-     [TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST]
+ self._assert_state([TrainingState.FORWARD_BACKWARD])
+ self.training_state = TrainingState.FORWARD_BACKWARD
nit: why is it assigned the same state right after asserting that the state is already TrainingState.FORWARD_BACKWARD?
Good point. I refactored too fast and overlooked this redundancy :)
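For reference, a small self-contained sketch of the redundancy under discussion; `_assert_state` is written out here as a simplified stand-in for FSDP's internal helper, and the class and method names are hypothetical.

```python
from enum import Enum, auto
from typing import List


class TrainingState(Enum):
    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class _Sketch:
    def __init__(self) -> None:
        self.training_state = TrainingState.FORWARD_BACKWARD

    def _assert_state(self, states: List[TrainingState]) -> None:
        # Simplified stand-in: raise if the current state is unexpected.
        assert self.training_state in states, self.training_state

    def _during_backward(self) -> None:
        self._assert_state([TrainingState.FORWARD_BACKWARD])
        # Redundant: the assertion above already guarantees this value.
        self.training_state = TrainingState.FORWARD_BACKWARD


_Sketch()._during_backward()
```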
self._assert_state([TrainingState.FORWARD_BACKWARD])
self.training_state = TrainingState.FORWARD_BACKWARD
p_assert(
    handle._training_state == HandleTrainingState.BACKWARD_PRE,
Nice! Since it is per-handle state, there is no need to check BACKWARD_POST anymore, which is much cleaner.
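A hedged sketch of the per-handle check: `p_assert` is written out as a simple assert-with-message stand-in for FSDP's internal helper, and the mock handle and hook function are hypothetical.

```python
from enum import Enum, auto


class HandleTrainingState(Enum):
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


def p_assert(cond: bool, msg: str) -> None:
    # Stand-in for FSDP's internal assertion helper.
    if not cond:
        raise AssertionError(msg)


class MockHandle:
    def __init__(self) -> None:
        self._training_state = HandleTrainingState.BACKWARD_PRE


def pre_backward_check(handle: MockHandle) -> None:
    # Because the state is tracked per handle, only BACKWARD_PRE needs to be
    # accepted here; the pre-refactor code also had to allow BACKWARD_POST.
    p_assert(
        handle._training_state == HandleTrainingState.BACKWARD_PRE,
        f"Expected BACKWARD_PRE, got {handle._training_state}",
    )


pre_backward_check(MockHandle())
```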
This PR actually has meaningful changes. We stratify `TrainingState` into two levels: one is per FSDP instance and one is per `FlatParamHandle`/`FlatParameter`.
- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), or `SUMMON_FULL_PARAMS`. These dynamically modify behavior (e.g. `summon_full_params()` forces full precision).
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.
Pull Request resolved: pytorch#87916
Approved by: https://github.com/mrshenli
Stack from ghstack:
- #87941 [FSDP()][26/N] Move `_lazy_init()` into `_fsdp_root_pre_forward()`
- #87940 [FSDP()][25/N] Add `_post_forward_reshard()`
- #87939 [FSDP()][24/N] Refactor `_lazy_init()`
- #87937 [FSDP] Simplify `_reset_lazy_init()`
- #87936 [FSDP()][22/N] Refactor `_cast_buffers()` in `_lazy_init()`
- #87935 [FSDP()][21/N] Refactor `_buffer_name_to_orig_dtype` computation
- #87934 [FSDP] Rename `dtype` to `buffer_name_to_dtype`
- #87933 [FSDP] Remove `device` arg from `_cast_buffers()`
- #87931 [FSDP()][18/N] Refactor `pre_forward_unshard()`
- #87930 [FSDP()][17/N] Refactor `_fsdp_root_pre_forward()`
- #87928 [FSDP()][15/N] Refactor `_init_streams()`
- #87922 [FSDP()][8/N] Refactor limiter's `_FreeEventQueue`
- #87920 [FSDP()][6/N] Refactor `CPUOffload` dataclass
- #87919 [FSDP()][5/N] Refactor `MixedPrecision` dataclass
- #87918 [FSDP()][4/N] Refactor `ShardingStrategy` enum
- Refactor `BackwardPrefetch` enum

This PR actually has meaningful changes. We stratify `TrainingState` into two levels: one is per FSDP instance and one is per `FlatParamHandle`/`FlatParameter`.
- At the FSDP instance level, we only care about `IDLE`, FSDP computation (i.e. `FORWARD_BACKWARD`), or `SUMMON_FULL_PARAMS`. These dynamically modify behavior (e.g. `summon_full_params()` forces full precision).
- At the `FlatParamHandle` level, we care about the training state for invariants and debugging. Hence, we keep `IDLE`, `FORWARD`, `BACKWARD_PRE`, `BACKWARD_POST`, and `SUMMON_FULL_PARAMS`.
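To make the two-level split concrete, here is a hedged, self-contained sketch (mock classes, not FSDP's real ones) of how the coarse instance-level state and the fine-grained handle-level states could evolve over one iteration; the enum members repeat the earlier sketch so the snippet runs on its own.

```python
from enum import Enum, auto
from typing import List


class TrainingState(Enum):
    IDLE = auto()
    FORWARD_BACKWARD = auto()
    SUMMON_FULL_PARAMS = auto()


class HandleTrainingState(Enum):
    IDLE = auto()
    FORWARD = auto()
    BACKWARD_PRE = auto()
    BACKWARD_POST = auto()
    SUMMON_FULL_PARAMS = auto()


class MockHandle:
    def __init__(self) -> None:
        self._training_state = HandleTrainingState.IDLE


class MockFSDPInstance:
    def __init__(self, handles: List[MockHandle]) -> None:
        self.training_state = TrainingState.IDLE
        self._handles = handles

    def run_iteration(self) -> None:
        # One coarse state covers the whole FSDP computation...
        self.training_state = TrainingState.FORWARD_BACKWARD
        for phase in (
            HandleTrainingState.FORWARD,
            HandleTrainingState.BACKWARD_PRE,
            HandleTrainingState.BACKWARD_POST,
            HandleTrainingState.IDLE,
        ):
            # ...while each handle records the fine-grained phase that
            # invariant checks and debugging rely on.
            for handle in self._handles:
                handle._training_state = phase
        self.training_state = TrainingState.IDLE


# Usage: two handles under one mock FSDP instance.
instance = MockFSDPInstance([MockHandle(), MockHandle()])
instance.run_iteration()
assert instance.training_state is TrainingState.IDLE
```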