[FSDP] Clean up FlatParamHandle dtypes, post-backward hook
#90660
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90660. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit e9aa3e0. This comment was automatically generated by Dr. CI and updates every 15 minutes.
zhaojuanmao left a comment:
Thanks!!
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Closes #90838. To make mixed precision precise internally, #90660 changed the implementation to save `_orig_param_dtype`, `_low_prec_param_dtype`, and `_reduce_dtype` explicitly. However, these are computed at FSDP construction time, so it does not allow the user to change the model dtype after FSDP construction time but before lazy initialization. This PR recomputes those dtype attributes as needed if the model dtype changes in that window. Note that any mixed precision settings specified by the user take precedence over the model dtype.

Pull Request resolved: #91192
Approved by: https://github.com/zhaojuanmao
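As a rough usage sketch of the window this follow-up addresses (assuming a distributed process group has already been initialized and a CUDA device is available; the module and shapes are illustrative):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes torch.distributed.init_process_group(...) has already been called.
model = FSDP(torch.nn.Linear(8, 8).cuda())

# Change the model dtype after FSDP construction but before the first forward
# pass (i.e. before lazy initialization). The follow-up recomputes the dtype
# attributes in this window instead of keeping stale float32 values computed
# at construction time.
model = model.half()

out = model(torch.randn(4, 8, device="cuda", dtype=torch.float16))
```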
Stack from ghstack:
- #90660 [FSDP] Clean up `FlatParamHandle` dtypes, post-backward hook (this PR)
- #90615 [FSDP] Tighten post-bwd cast to `reduce_dtype`
- #90622 [FSDP][Easy] Move to `_storage()` in test file
- #90611 [FSDP] Save `_stream_to_name` for debugging
- #90562 [Reland][FSDP] Another fix for `DTensor`, `use_orig_params=True`

This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and makes some minor changes to the communication hooks.
Overview
This PR addresses everything in #90657 except renaming `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`, since that is BC breaking. I recommend reading the issue before proceeding.

For `MixedPrecision(param_dtype, reduce_dtype, ...)`, the rule for parameter and gradient reduction mixed precision is: parameters are cast to `param_dtype` for forward/backward when it is specified, and gradients are reduced in `reduce_dtype` when it is specified, falling back to `param_dtype` if only that is specified, and otherwise to the original parameter dtype.

This PR enforces that, at the `FlatParamHandle` level, `handle._config.fwd_bwd_param_dtype` and `handle._config.reduce_dtype` are never `None`. The way to check whether mixed precision is enabled is to compare against the original parameter dtype, which is now stored in `handle._orig_param_dtype`; the check is no longer against `None`.

This avoids ambiguous cases such as when the user passes `MixedPrecision(param_dtype=torch.float32)`. In that case, our existing implementation mistakenly thinks that parameter mixed precision is enabled and either silently relies on no-ops or errors (such as one case reported by MosaicML).
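A minimal sketch of that resolution rule (`resolve_dtypes` is a hypothetical standalone helper for illustration, not FSDP's internal code):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

def resolve_dtypes(mp: MixedPrecision, orig_param_dtype: torch.dtype):
    # Hypothetical helper mirroring the rule above: the handle-level
    # dtypes are never None.
    low_prec_param_dtype = (
        mp.param_dtype if mp.param_dtype is not None else orig_param_dtype
    )
    # Gradient reduction falls back to the (possibly low-precision) parameter
    # dtype, and otherwise to the original parameter dtype.
    reduce_dtype = (
        mp.reduce_dtype if mp.reduce_dtype is not None else low_prec_param_dtype
    )
    return low_prec_param_dtype, reduce_dtype

# Ambiguous case from above: param_dtype=torch.float32 on a float32 model
# should mean "parameter mixed precision is NOT enabled".
low, red = resolve_dtypes(MixedPrecision(param_dtype=torch.float32), torch.float32)
param_mp_enabled = low != torch.float32  # compare against the original dtype, not None
print(param_mp_enabled)  # False
```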
Additional Details

- Removes `FullyShardedDataParallel._mixed_precision_enabled_for_params`, `FullyShardedDataParallel._mixed_precision_enabled_for_reduce`, and `FullyShardedDataParallel._mixed_precision_keep_low_precision_grads`, since they are not used.
- `test_meta_device_with_mixed_precision()` exercises a tricky edge case with meta device initialization, `apply()` (calling into `summon_full_params()`), and `param_dtype=torch.float32` for a nested wrapping case, where each nested instance has parameters.

Follow-Ups
- Remove `HandleConfig` and store its fields as attributes on `FlatParamHandle` directly.
- Rename `keep_low_precision_grads` to `keep_grads_in_reduce_dtype`.
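As a rough before/after sketch of the first follow-up (the `HandleConfig` fields and class names shown are illustrative, not the exact definitions in the codebase):

```python
from dataclasses import dataclass
import torch

# Before: a small config object stored on the handle.
@dataclass
class HandleConfig:  # illustrative subset of fields
    fwd_bwd_param_dtype: torch.dtype
    reduce_dtype: torch.dtype
    keep_low_precision_grads: bool

class FlatParamHandleBefore:
    def __init__(self, config: HandleConfig):
        self._config = config  # accessed as handle._config.reduce_dtype, etc.

# After (proposed follow-up): store the same fields directly on the handle.
class FlatParamHandleAfter:
    def __init__(
        self,
        fwd_bwd_param_dtype: torch.dtype,
        reduce_dtype: torch.dtype,
        keep_low_precision_grads: bool,
    ):
        self._fwd_bwd_param_dtype = fwd_bwd_param_dtype  # handle._fwd_bwd_param_dtype
        self._reduce_dtype = reduce_dtype                # handle._reduce_dtype
        self._keep_low_precision_grads = keep_low_precision_grads
```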