
checkpoint restore of optimizers changes dtype of floating-point state #43706

@msbaines


🐛 Bug

PR #3658, in fixing an issue related to loading optimizer state on a different device, goes too far and casts floating-point state to the same dtype as the parameters. This can result in a loss of precision when the state is higher precision than the parameters. In fairseq, we are working around this with intricate hacks that rely on internal details:

https://github.com/pytorch/fairseq/blob/fc27170a9e70c6485331d8c84d56142a98de8a84/fairseq/optim/fp16_optimizer.py#L276-L293
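
The hack boils down to reaching back into the checkpoint and overwriting the downcast state with the original tensors after load_state_dict has run. A minimal sketch of that pattern is below; it is illustrative only (not the actual fairseq code), the helper name is made up, and it assumes the standard optimizer state_dict layout ('state' keyed by integer ids, 'param_groups' listing those ids in order).

import torch

def load_state_dict_keeping_fp32_state(optimizer, state_dict):
    # Hypothetical helper: let load_state_dict rebuild the param/state
    # mapping (and handle device movement), then overwrite each
    # floating-point state tensor with the original, un-cast tensor from
    # the checkpoint so its precision is preserved.
    optimizer.load_state_dict(state_dict)

    # load_state_dict matches saved param ids to current params in order,
    # group by group, so the same zip reproduces its id mapping.
    id_map = {
        old_id: p
        for old_group, group in zip(state_dict['param_groups'], optimizer.param_groups)
        for old_id, p in zip(old_group['params'], group['params'])
    }

    for old_id, saved_state in state_dict['state'].items():
        param = id_map[old_id]
        for key, value in saved_state.items():
            if torch.is_tensor(value) and value.is_floating_point():
                # Keep the checkpoint's dtype (e.g. FP32) instead of the
                # parameter dtype that load_state_dict cast it to.
                optimizer.state[param][key] = value.to(param.device)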

The current behavior also makes it difficult to write unit tests for optimizer state load/save (when the state dtype differs from the parameter dtype), because the state loaded into the optimizer differs from the state that was saved.

To Reproduce

Steps to reproduce the behavior:
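
The report leaves the steps blank; the following is my own minimal sketch of how the behavior can be observed, assuming an FP16 parameter whose Adam state is manually promoted to FP32 (as an FP16 training wrapper would do):

import torch

# Illustrative reproduction (assumed, not from the original report).
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
optimizer = torch.optim.Adam([param])

param.grad = torch.ones_like(param)
optimizer.step()  # creates exp_avg / exp_avg_sq state in FP16

# Promote the floating-point state to FP32, as an FP16 optimizer wrapper might.
state = optimizer.state[param]
for key, value in list(state.items()):
    if torch.is_tensor(value) and value.is_floating_point():
        state[key] = value.float()

print(state["exp_avg"].dtype)  # torch.float32 before the round trip

checkpoint = optimizer.state_dict()
optimizer.load_state_dict(checkpoint)

# load_state_dict casts floating-point state to the parameter dtype, so the
# FP32 state silently comes back as FP16 under the reported behavior.
print(optimizer.state[param]["exp_avg"].dtype)  # torch.float16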

Expected behavior

The dtype of the optimizer state should not change across a save/load round trip.

Environment

Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py
  • PyTorch Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

cc @vincentqb

Labels: module: optimizer (Related to torch.optim), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
