
Conversation

@awgu (Collaborator) commented Dec 11, 2022

Stack from ghstack:

This PR reworks the internal handling of parameter and gradient reduction mixed precision, cleans up the post-backward hook logic, and makes some minor changes to the communication hooks.

Overview
This PR addresses everything in #90657 except renaming keep_low_precision_grads to keep_grads_in_reduce_dtype, since that rename is BC-breaking. I recommend reading the issue before proceeding.

For MixedPrecision(param_dtype, reduce_dtype, ...), the exact rule for parameter and gradient reduction mixed precision that we are following is:

If param_dtype is not None and reduce_dtype is None, then we infer reduce_dtype = param_dtype. Otherwise, we take param_dtype and reduce_dtype as is.

This PR enforces that, at the FlatParamHandle level, handle._config.fwd_bwd_param_dtype and handle._config.reduce_dtype are never None. To check whether mixed precision is enabled, compare against the original parameter dtype, which is now stored in handle._orig_param_dtype; checking against None is no longer valid.

This avoids ambiguous cases such as when the user passes MixedPrecision(param_dtype=torch.float32). In that case, our existing implementation mistakenly treats parameter mixed precision as enabled and either silently relies on no-ops or errors out (such as in one case reported by MosaicML).
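
To make the rule concrete, here is a minimal sketch of the dtype resolution under this scheme (the helper `_resolve_mixed_precision_dtypes` is hypothetical, not the actual FSDP code):

```python
import torch
from typing import Optional, Tuple

def _resolve_mixed_precision_dtypes(
    orig_param_dtype: torch.dtype,
    param_dtype: Optional[torch.dtype],
    reduce_dtype: Optional[torch.dtype],
) -> Tuple[torch.dtype, torch.dtype]:
    """Hypothetical helper mirroring the rule above: the returned dtypes are
    never None, and "mixed precision enabled" is determined by comparing
    against orig_param_dtype rather than against None."""
    resolved_param_dtype = param_dtype if param_dtype is not None else orig_param_dtype
    if reduce_dtype is not None:
        resolved_reduce_dtype = reduce_dtype
    elif param_dtype is not None:
        # reduce_dtype is unspecified, so infer reduce_dtype = param_dtype
        resolved_reduce_dtype = param_dtype
    else:
        resolved_reduce_dtype = orig_param_dtype
    return resolved_param_dtype, resolved_reduce_dtype

# The previously ambiguous case: float32 parameters with
# MixedPrecision(param_dtype=torch.float32) resolves to (float32, float32),
# which matches the original dtype, so no mixed precision is actually enabled.
assert _resolve_mixed_precision_dtypes(torch.float32, torch.float32, None) == (
    torch.float32,
    torch.float32,
)
```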

Additional Details

  • We remove FullyShardedDataParallel._mixed_precision_enabled_for_params, FullyShardedDataParallel._mixed_precision_enabled_for_reduce, and FullyShardedDataParallel._mixed_precision_keep_low_precision_grads since they are not used.
  • The unit test test_meta_device_with_mixed_precision() exercises a tricky edge case combining meta-device initialization, apply() (which calls into summon_full_params()), and param_dtype=torch.float32 in a nested-wrapping case where each nested instance has parameters (see the sketch after this list).
  • We include some minor fixes/improvements to the communication hook implementation.
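
For reference, a minimal sketch of the meta-device initialization pattern that the test exercises, using plain nn modules only; the actual test additionally wraps each nested module in FSDP with MixedPrecision(param_dtype=torch.float32) and relies on apply() calling summon_full_params(), which requires an initialized process group and is omitted here:

```python
import torch
import torch.nn as nn

class Nested(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # In the actual test, each of these would be wrapped as its own FSDP
        # instance so that every nested instance owns parameters.
        self.inner = nn.Linear(8, 8, device="meta")
        self.proj = nn.Linear(8, 8, device="meta")

model = Nested()
assert all(p.is_meta for p in model.parameters())

# Materialize storage on a real device, then initialize via apply(), mirroring
# how the test initializes parameters after construction.
model.to_empty(device="cpu")

def init_weights(module: nn.Module) -> None:
    if isinstance(module, nn.Linear):
        module.reset_parameters()

model.apply(init_weights)
assert all(not p.is_meta and p.dtype == torch.float32 for p in model.parameters())
```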

Follow-Ups

  • We should get rid of HandleConfig and store its fields as attributes on FlatParamHandle directly.
  • Rename keep_low_precision_grads to keep_grads_in_reduce_dtype.

@pytorch-bot (bot) commented Dec 11, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/90660

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e9aa3e0:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Dec 11, 2022
awgu pushed a commit that referenced this pull request Dec 11, 2022
awgu pushed a commit to awgu/pytorch that referenced this pull request Dec 12, 2022
@awgu awgu added the topic: not user facing topic category label Dec 12, 2022
@zhaojuanmao (Contributor) left a comment

Thanks!!

@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 13, 2022
@awgu (Collaborator, Author) commented Dec 13, 2022

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

awgu pushed a commit to awgu/pytorch that referenced this pull request Dec 13, 2022
awgu pushed a commit that referenced this pull request Dec 20, 2022
To make mixed precision precise internally, #90660 changed the implementation to save `_orig_param_dtype`, `_low_prec_param_dtype`, and `_reduce_dtype` explicitly. However, these are computed at FSDP construction time, so the user cannot change the model dtype after FSDP construction but before lazy initialization. This PR recomputes those dtype attributes as needed if the model dtype changes in that window.

Note that any mixed precision settings specified by the user take precedence over the model dtype.
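
For illustration, a rough sketch of the recompute-at-lazy-init pattern this describes (hypothetical names, not the actual FSDP code):

```python
import torch
import torch.nn as nn
from typing import Optional

class DtypeConfig:
    """Hypothetical sketch of the pattern described above: dtypes derived from
    the model are captured at construction time but refreshed at lazy
    initialization if the model dtype changed in between. User-specified
    mixed precision dtypes always take precedence over the model dtype."""

    def __init__(
        self,
        module: nn.Module,
        param_dtype: Optional[torch.dtype] = None,
        reduce_dtype: Optional[torch.dtype] = None,
    ) -> None:
        self._module = module
        self._user_param_dtype = param_dtype
        self._user_reduce_dtype = reduce_dtype
        self._recompute()  # "construction time"

    def _model_dtype(self) -> torch.dtype:
        return next(self._module.parameters()).dtype

    def _recompute(self) -> None:
        orig = self._model_dtype()
        self.orig_param_dtype = orig
        self.low_prec_param_dtype = (
            self._user_param_dtype if self._user_param_dtype is not None else orig
        )
        if self._user_reduce_dtype is not None:
            self.reduce_dtype = self._user_reduce_dtype
        elif self._user_param_dtype is not None:
            self.reduce_dtype = self._user_param_dtype
        else:
            self.reduce_dtype = orig

    def lazy_init(self) -> None:
        # If the model dtype changed after construction but before lazy
        # initialization (e.g. via module.half()), refresh the cached dtypes.
        if self._model_dtype() != self.orig_param_dtype:
            self._recompute()

module = nn.Linear(4, 4)      # float32 at construction time
config = DtypeConfig(module)  # caches float32-derived dtypes
module.half()                 # model dtype changes before lazy init
config.lazy_init()            # cached dtypes are recomputed to float16
assert config.reduce_dtype == torch.float16
```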

pytorchmergebot pushed a commit that referenced this pull request Jan 12, 2023
Closes #90838.

Pull Request resolved: #91192
Approved by: https://github.com/zhaojuanmao