Skip to content

Commit 9df9c46

Browse files
Ailing Zhangfacebook-github-bot
authored andcommitted
fix loading 1dim tensor from 0.3.* to 0dim tensor (#9781)
Summary: This PR fixes #9743 . Adding backward support when loading a checkpoint from 0.3.* with 1dim tensor, they are now 0 dim tensor in 0.4+. Pull Request resolved: #9781 Differential Revision: D8988196 Pulled By: ailzhang fbshipit-source-id: a7a1bc771d597394208430575d5a4d23b9653fef
1 parent d65c667 commit 9df9c46

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torch/nn/modules/module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,10 @@ def _load_from_state_dict(self, state_dict, prefix, metadata, strict, missing_ke
642642
if key in state_dict:
643643
input_param = state_dict[key]
644644

645+
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
646+
if len(param.shape) == 0 and len(input_param.shape) == 1:
647+
input_param = input_param[0]
648+
645649
if input_param.shape != param.shape:
646650
# local shape should match the one in checkpoint
647651
error_msgs.append('size mismatch for {}: copying a param of {} from checkpoint, '

0 commit comments

Comments
 (0)