@@ -4686,19 +4686,20 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
46864686 # skip some of the error handling
46874687 self .param .data .copy_ (state_dict [prefix + "serialized" ] - 1 )
46884688
4689- m = CustomState ()
4690- m .param [0 ] = 10
4691- m .sub .weight [0 , 0 ] = 555
4689+ # use sequential to verify nesting
4690+ m = nn .Sequential (CustomState ())
4691+ m [0 ].param [0 ] = 10
4692+ m [0 ].sub .weight [0 , 0 ] = 555
46924693 state_dict = m .state_dict ()
4693- self .assertEqual (state_dict ["serialized" ].item (), 11 )
4694- self .assertIn ("sub.weight" , state_dict )
4695- self .assertNotIn ("param" , state_dict )
4694+ self .assertEqual (state_dict ["0. serialized" ].item (), 11 )
4695+ self .assertIn ("0. sub.weight" , state_dict )
4696+ self .assertNotIn ("0. param" , state_dict )
46964697 del m
4697- mm = CustomState ()
4698- self .assertEqual (mm .param [0 ].item (), 1 )
4698+ mm = nn . Sequential ( CustomState () )
4699+ self .assertEqual (mm [ 0 ] .param [0 ].item (), 1 )
46994700 mm .load_state_dict (state_dict )
4700- self .assertEqual (mm .param [0 ].item (), 10 )
4701- self .assertEqual (mm .sub .weight [0 , 0 ].item (), 555 )
4701+ self .assertEqual (mm [ 0 ] .param [0 ].item (), 10 )
4702+ self .assertEqual (mm [ 0 ] .sub .weight [0 , 0 ].item (), 555 )
47024703
47034704 def test_parameter_assignment (self ):
47044705 l = nn .Linear (5 , 5 )
0 commit comments