Cast tensors when loading optimizer state dicts #3658
Merged
Right now optimizers can load state dicts of other optimizers only if all parameters match in type and device (in contrast to nn.Modules). This is too strict for many use cases, and is addressed in this patch.

The only problem is that optimizer state isn't typed in any way, so the code in this PR tries to make reasonable guesses: only state that's bound to particular parameters is cast (with the parameter serving as the template), and we assume that floating point tensors in the state should match the type of the parameter (I can't think of a better way to handle load_state_dict across sets of parameters with different fp types). All other tensors are only moved to the matching device, not cast. A sketch of this heuristic is shown below the issue references.
Fixes #2830, #1442.
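
For illustration, here is a minimal sketch of the casting heuristic described above. The helper name `_cast_state_value` is hypothetical and not the actual code in the patch; it only shows the per-parameter casting rules.

```python
import torch

def _cast_state_value(param, value):
    # Hypothetical helper illustrating the heuristic: state bound to a
    # parameter uses that parameter as the casting template.
    if isinstance(value, torch.Tensor):
        if value.is_floating_point():
            # Floating point state (e.g. running averages) is assumed to
            # follow the dtype and device of its parameter.
            return value.to(dtype=param.dtype, device=param.device)
        # Non-floating-point tensors (e.g. integer counters) are only
        # moved to the parameter's device, never cast.
        return value.to(device=param.device)
    if isinstance(value, dict):
        # Recurse into nested state dictionaries.
        return {k: _cast_state_value(param, v) for k, v in value.items()}
    # Plain Python values are left untouched.
    return value
```

With rules like these, an optimizer state dict saved from an fp16 or CUDA run can be loaded into an fp32 or CPU copy of the same model, since each state tensor is converted to match the parameter it belongs to rather than being required to match exactly.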