🐛 Bug
PR #3658, in fixing an issue related to loading optimizer state on a different device, goes too far: it also casts floating-point state to the same dtype as the corresponding parameters. This can result in a loss of precision when the state is stored at a higher precision than the parameters. In fairseq, we are working around this with intricate hacks that rely on internal implementation details.
The current behavior also makes it difficult to write unit tests for optimizer state save/load when the state dtype differs from the parameter dtype, because the loaded optimizer state differs from the state before saving.
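For illustration, here is a hedged sketch of the kind of workaround user code currently needs (this is not fairseq's actual code; it assumes the state dict's `state` entries are keyed by the integer position of each parameter across the param groups, which matches recent `Optimizer.state_dict()` behavior):

```python
import torch


def load_state_dict_keep_dtype(optimizer, state_dict):
    # Load as usual; torch.optim.Optimizer.load_state_dict casts
    # floating-point state tensors to the dtype of their parameters.
    optimizer.load_state_dict(state_dict)

    # Recover the live parameter for each packed integer key
    # (assumes integer keys in param-group order, see note above).
    params = [p for group in optimizer.param_groups for p in group['params']]

    # Overwrite the down-cast state with copies of the saved tensors so the
    # original dtype (and precision) is preserved; only the device is
    # adjusted to match the parameter.
    for idx, saved in state_dict['state'].items():
        live = optimizer.state[params[idx]]
        for key, value in saved.items():
            if torch.is_tensor(value) and value.is_floating_point():
                live[key] = value.detach().clone().to(device=params[idx].device)
```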
To Reproduce
Steps to reproduce the behavior:
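A minimal sketch (the hand-built Adam state is illustrative; any setup where the state dtype differs from the parameter dtype shows the same behavior):

```python
import torch

# fp16 parameter, but keep the Adam state in fp32 (as memory-efficient
# fp16 training setups do).
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
opt = torch.optim.Adam([param], lr=1e-3)

# Populate the optimizer state by hand in fp32 to keep the example
# device/dtype agnostic (a real run would call opt.step()).
opt.state[param] = {
    'step': 0,
    'exp_avg': torch.zeros(4, dtype=torch.float32),
    'exp_avg_sq': torch.zeros(4, dtype=torch.float32),
}

# Round-trip through state_dict / load_state_dict.
opt.load_state_dict(opt.state_dict())

# Expected: torch.float32. Observed: torch.float16, because
# load_state_dict casts floating-point state to the parameter's dtype.
print(opt.state[param]['exp_avg'].dtype)
```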
Expected behavior
The dtype of the optimizer state should not change across a state_dict() / load_state_dict() round trip.
Environment
Please copy and paste the output from our environment collection script (or fill out the checklist below manually).
You can get the script and run it with:
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
- PyTorch Version (e.g., 1.0):
- OS (e.g., Linux):
- How you installed PyTorch (conda, pip, source):
- Build command you used (if compiling from source):
- Python version:
- CUDA/cuDNN version:
- GPU models and configuration:
- Any other relevant information:
Additional context
cc @vincentqb