Skip to content

Commit da0224d

Browse files
author
Dmytro Dzhulgakov
committed
fix breakage and add nested test
1 parent 9a92c24 commit da0224d

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

test/test_nn.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4686,19 +4686,20 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
46864686
# skip some of the error handling
46874687
self.param.data.copy_(state_dict[prefix + "serialized"] - 1)
46884688

4689-
m = CustomState()
4690-
m.param[0] = 10
4691-
m.sub.weight[0, 0] = 555
4689+
# use sequential to verify nesting
4690+
m = nn.Sequential(CustomState())
4691+
m[0].param[0] = 10
4692+
m[0].sub.weight[0, 0] = 555
46924693
state_dict = m.state_dict()
4693-
self.assertEqual(state_dict["serialized"].item(), 11)
4694-
self.assertIn("sub.weight", state_dict)
4695-
self.assertNotIn("param", state_dict)
4694+
self.assertEqual(state_dict["0.serialized"].item(), 11)
4695+
self.assertIn("0.sub.weight", state_dict)
4696+
self.assertNotIn("0.param", state_dict)
46964697
del m
4697-
mm = CustomState()
4698-
self.assertEqual(mm.param[0].item(), 1)
4698+
mm = nn.Sequential(CustomState())
4699+
self.assertEqual(mm[0].param[0].item(), 1)
46994700
mm.load_state_dict(state_dict)
4700-
self.assertEqual(mm.param[0].item(), 10)
4701-
self.assertEqual(mm.sub.weight[0, 0].item(), 555)
4701+
self.assertEqual(mm[0].param[0].item(), 10)
4702+
self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)
47024703

47034704
def test_parameter_assignment(self):
47044705
l = nn.Linear(5, 5)

torch/jit/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1808,8 +1808,8 @@ def _get_methods(cls):
18081808
_compiled_methods_whitelist = {
18091809
'forward', 'register_buffer', 'register_parameter', 'add_module',
18101810
'_apply', 'apply', 'cuda', 'cpu', 'to', 'type', 'float', 'double', 'half',
1811-
'state_dict', 'load_state_dict', '_load_from_state_dict',
1812-
'_named_members', 'parameters', 'named_parameters',
1811+
'state_dict', '_save_to_state_dict', 'load_state_dict',
1812+
'_load_from_state_dict', '_named_members', 'parameters', 'named_parameters',
18131813
'buffers', 'named_buffers', 'children', 'named_children', 'modules',
18141814
'named_modules', 'zero_grad', 'share_memory', '_get_name', 'extra_repr',
18151815
'_slow_forward', '_tracing_name', 'eval', 'train',

0 commit comments

Comments
 (0)