@@ -2141,47 +2141,91 @@ def test_spectral_norm_load_state_dict(self):
             for _ in range(activate_times):
                 snm(inp)

+            version_latest_ref_state_dict = deepcopy(snm.state_dict())
+            self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_latest_ref_state_dict.keys()))
+
+            # test that non-strict loading works
+            non_strict_state_dict = deepcopy(version_latest_ref_state_dict)
+            non_strict_state_dict['nonsense'] = 'nonsense'
+            with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'):
+                snm.load_state_dict(non_strict_state_dict, strict=True)
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['weight_orig']
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['weight_u']
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['weight_v']
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            non_strict_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict._metadata['']['spectral_norm']  # remove metadata info
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['weight']  # remove W buffer
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['bias']
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+
             # craft a version None state_dict
-            version_none_state_dict = deepcopy(snm.state_dict())
-            self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_none_state_dict.keys()))
+            version_none_state_dict = deepcopy(version_latest_ref_state_dict)
             self.assertIn('spectral_norm', version_none_state_dict._metadata[''])
             del version_none_state_dict._metadata['']['spectral_norm']  # remove metadata info
             del version_none_state_dict['weight_v']  # remove v vector
             version_none_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer

             # normal state_dict
-            version_latest_state_dict = deepcopy(snm.state_dict())
+            for version_latest_with_metadata in [True, False]:
+                version_latest_state_dict = deepcopy(version_latest_ref_state_dict)

-            snm.eval()
-            out0_eval = snm(inp)
-            snm.train()
-            out1_train = snm(inp)
-            out2_train = snm(inp)
-            snm.eval()
-            out3_eval = snm(inp)
-
-            snm.load_state_dict(version_none_state_dict)
-            if activate_times > 0:
-                # since in loading a version None state dict, we assume that the
-                # values in the state dict have gone through at least one
-                # forward, so we only test for equivalence when activate_times > 0.
-                snm.eval()
-                self.assertEqual(out0_eval, snm(inp))
-                snm.train()
-                self.assertEqual(out1_train, snm(inp))
-                self.assertEqual(out2_train, snm(inp))
-                snm.eval()
-                self.assertEqual(out3_eval, snm(inp))
-
-            # Test normal loading
-            snm.load_state_dict(version_latest_state_dict)
-            snm.eval()
-            self.assertEqual(out0_eval, snm(inp))
-            snm.train()
-            self.assertEqual(out1_train, snm(inp))
-            self.assertEqual(out2_train, snm(inp))
-            snm.eval()
-            self.assertEqual(out3_eval, snm(inp))
+                if not version_latest_with_metadata:
+                    # We want to still load a user-crafted state_dict, one without metadata
+                    del version_latest_state_dict._metadata['']['spectral_norm']
+
+                # test that re-wrapping does not matter
+                m = torch.nn.utils.remove_spectral_norm(snm)
+                snm = torch.nn.utils.spectral_norm(m)
+
+                snm.load_state_dict(version_latest_ref_state_dict)
+                with torch.no_grad():
+                    snm.eval()
+                    out0_eval = snm(inp)
+                    snm.train()
+                    out1_train = snm(inp)
+                    out2_train = snm(inp)
+                    snm.eval()
+                    out3_eval = snm(inp)
+
+                # test that re-wrapping does not matter
+                m = torch.nn.utils.remove_spectral_norm(snm)
+                snm = torch.nn.utils.spectral_norm(m)
+
+                snm.load_state_dict(version_none_state_dict)
+                if activate_times > 0:
+                    # since in loading a version None state dict, we assume that the
+                    # values in the state dict have gone through at least one
+                    # forward, so we only test for equivalence when activate_times > 0.
+                    with torch.no_grad():
+                        snm.eval()
+                        self.assertEqual(out0_eval, snm(inp))
+                        snm.train()
+                        self.assertEqual(out1_train, snm(inp))
+                        self.assertEqual(out2_train, snm(inp))
+                        snm.eval()
+                        self.assertEqual(out3_eval, snm(inp))
+
+                # test that re-wrapping does not matter
+                m = torch.nn.utils.remove_spectral_norm(snm)
+                snm = torch.nn.utils.spectral_norm(m)
+
+                # Test normal loading
+                snm.load_state_dict(version_latest_state_dict)
+                with torch.no_grad():
+                    snm.eval()
+                    self.assertEqual(out0_eval, snm(inp))
+                    snm.train()
+                    self.assertEqual(out1_train, snm(inp))
+                    self.assertEqual(out2_train, snm(inp))
+                    snm.eval()
+                    self.assertEqual(out3_eval, snm(inp))

     def test_spectral_norm_dim(self):
         inp = torch.randn(2, 3, 10, 12)
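The assertions added above hinge on two behaviors: `spectral_norm` stores `weight_orig` plus the power-iteration buffers `weight_u`/`weight_v` in the state dict (the plain `weight` is recomputed on the fly), and `load_state_dict(strict=False)` must tolerate extra or missing entries. A minimal standalone sketch of those two behaviors (not part of the patch; the layer size and the dummy key are arbitrary):

```python
# Sketch only: what spectral_norm leaves in a module's state_dict and how
# strict vs. non-strict loading reacts to unexpected keys.
import torch
import torch.nn as nn

snm = torch.nn.utils.spectral_norm(nn.Linear(3, 5))
print(sorted(snm.state_dict().keys()))
# ['bias', 'weight_orig', 'weight_u', 'weight_v'] -- 'weight' itself is
# recomputed from weight_orig and the power-iteration vectors u and v.

sd = snm.state_dict()
sd['nonsense'] = torch.zeros(1)            # a key the module does not know about
try:
    snm.load_state_dict(sd, strict=True)   # strict loading rejects unexpected keys
except RuntimeError as exc:
    print('strict load failed:', exc)
snm.load_state_dict(sd, strict=False)      # non-strict loading simply skips them
```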
@@ -3602,7 +3646,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                         multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
                         multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
                         multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
-                        static_k=saved_k_tensor, static_v=saved_v_tensor)
+                        static_k=saved_k_tensor, static_v=saved_v_tensor)
                 else:
                     result, result_weight = torch.nn.functional.multi_head_attention_forward(
                         _Q, _K, _V,
@@ -3612,9 +3656,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                         multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
                         multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
                         multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
-                        True, multihead_attn_module.q_proj_weight,
+                        True, multihead_attn_module.q_proj_weight,
                         multihead_attn_module.k_proj_weight, multihead_attn_module.v_proj_weight,
-                        static_k=saved_k_tensor, static_v=saved_v_tensor)
+                        static_k=saved_k_tensor, static_v=saved_v_tensor)

                 result = result.squeeze(0).detach().numpy()

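Both branches above drive `torch.nn.functional.multi_head_attention_forward` with the parameters pulled off an `nn.MultiheadAttention` module, then compare against a NumPy reference. A stripped-down sketch of that functional call, assuming dropout is disabled and eval mode so the module and functional paths are deterministic (sizes are arbitrary):

```python
# Sketch only: calling the functional API with a module's own parameters,
# mirroring the call in the hunk above.  Shapes are (seq, batch, embed),
# as used throughout this test.
import torch
import torch.nn as nn
import torch.nn.functional as F

embed_dim, nheads, seq_len, batch = 8, 2, 5, 3
mha = nn.MultiheadAttention(embed_dim, nheads, dropout=0.0)
mha.eval()

q = torch.randn(seq_len, batch, embed_dim)
k = v = torch.randn(seq_len, batch, embed_dim)

with torch.no_grad():
    out_module, _ = mha(q, k, v)
    out_func, _ = F.multi_head_attention_forward(
        q, k, v, embed_dim, nheads,
        mha.in_proj_weight, mha.in_proj_bias,
        mha.bias_k, mha.bias_v,
        mha.add_zero_attn, mha.dropout,
        mha.out_proj.weight, mha.out_proj.bias,
        training=mha.training)

print(torch.allclose(out_module, out_func, atol=1e-6))  # expected: True
```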
@@ -3644,12 +3688,12 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a

                 if saved_k is not None:
                     K_split = np.reshape(saved_k, [dims[0], nheads, dims[1], d_head])
-                else:
+                else:
                     K_split = _split_heads_ref(K_fc, dims, nheads, d_head)

                 if saved_k is not None:
                     V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head])
-                else:
+                else:
                     V_split = _split_heads_ref(V_fc, dims, nheads, d_head)

                 if add_zero_attn:
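`_split_heads_ref` and `_combine_heads_ref` are test-local NumPy helpers defined outside these hunks; the sketch below shows the reshape/transpose they are assumed to perform, the usual `(batch, seq, embed) -> (batch, nheads, seq, d_head)` rearrangement. `saved_k`/`saved_v` arrive already laid out per head, which is why the hunk only needs a plain `np.reshape` for them.

```python
# Sketch only: the per-head split assumed for _split_heads_ref.  Dimensions
# are arbitrary; the point is the reshape/transpose bookkeeping.
import numpy as np

batch, seq, nheads, d_head = 2, 5, 4, 3
X = np.random.randn(batch, seq, nheads * d_head)      # e.g. K_fc or V_fc

X_split = np.reshape(X, [batch, seq, nheads, d_head])
X_split = np.transpose(X_split, [0, 2, 1, 3])          # -> (batch, nheads, seq, d_head)
assert X_split.shape == (batch, nheads, seq, d_head)

# saved_k / saved_v are already per head, so the test reshapes them directly
# to [dims[0], nheads, dims[1], d_head] without the transpose.
```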
@@ -3668,7 +3712,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                     V=V_split,
                     dims=Q_split.shape,
                     unseen_mask=attn_mask,
-                    key_padding_mask=key_padding_mask
+                    key_padding_mask=key_padding_mask
                 )
                 combined_attn_heads = _combine_heads_ref(
                     X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head
@@ -3713,7 +3757,7 @@ def test_multihead_attn_all_arguments2():
                                         add_zero_attn=True, saved_kv=True)

         def test_multihead_attn_all_arguments3():
-            _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
+            _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
                                         saved_kv=True, same_embed_dim=True)

         test_multihead_attn_add_zero_attn()  # Test MultiheadAttention with add_zero_attn