
Commit d48cbd6

ssnl authored and facebook-github-bot committed
Fix spectral_norm load_state_dict with strict=False (#22545)
Summary: Fixes #21251; also fixes some missing hook removals.

Pull Request resolved: #22545
Differential Revision: D16139506
Pulled By: soumith
fbshipit-source-id: 552a9f9f91be328a47ee8f1e1d29c1f59b0ebca3
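
To illustrate the behavior this change targets, here is a minimal sketch (not part of the commit; the module shape and key names are illustrative) of loading a partial state_dict into a spectral-norm module with strict=False, the scenario from #21251:

import torch
import torch.nn as nn

snm = nn.utils.spectral_norm(nn.Linear(5, 7))

# A partial state_dict (e.g. from a checkpoint that simply omits this module)
# carries no 'spectral_norm' metadata and none of the weight_orig / weight_u /
# weight_v entries the load-state-dict pre-hook expects.
partial_state_dict = {'bias': torch.zeros(7)}

# Before this fix the SpectralNormLoadStateDictPreHook indexed the missing
# 'weight_orig' key and raised; with the missing-key check it returns early,
# so the non-strict load succeeds.
snm.load_state_dict(partial_state_dict, strict=False)
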
1 parent 94bd5dd commit d48cbd6

File tree

2 files changed (+119, -48 lines)


test/test_nn.py

Lines changed: 84 additions & 40 deletions
@@ -2141,47 +2141,91 @@ def test_spectral_norm_load_state_dict(self):
             for _ in range(activate_times):
                 snm(inp)
 
+            version_latest_ref_state_dict = deepcopy(snm.state_dict())
+            self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_latest_ref_state_dict.keys()))
+
+            # test that non-strict loading works
+            non_strict_state_dict = deepcopy(version_latest_ref_state_dict)
+            non_strict_state_dict['nonsense'] = 'nonsense'
+            with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'):
+                snm.load_state_dict(non_strict_state_dict, strict=True)
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['weight_orig']
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['weight_u']
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['weight_v']
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            non_strict_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict._metadata['']['spectral_norm']  # remove metadata info
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['weight']  # remove W buffer
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+            del non_strict_state_dict['bias']
+            snm.load_state_dict(non_strict_state_dict, strict=False)
+
             # craft a version None state_dict
-            version_none_state_dict = deepcopy(snm.state_dict())
-            self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_none_state_dict.keys()))
+            version_none_state_dict = deepcopy(version_latest_ref_state_dict)
             self.assertIn('spectral_norm', version_none_state_dict._metadata[''])
             del version_none_state_dict._metadata['']['spectral_norm']  # remove metadata info
             del version_none_state_dict['weight_v']  # remove v vector
             version_none_state_dict['weight'] = snm.weight.detach().clone()  # set W as a buffer
 
             # normal state_dict
-            version_latest_state_dict = deepcopy(snm.state_dict())
+            for version_latest_with_metadata in [True, False]:
+                version_latest_state_dict = deepcopy(version_latest_ref_state_dict)
 
-            snm.eval()
-            out0_eval = snm(inp)
-            snm.train()
-            out1_train = snm(inp)
-            out2_train = snm(inp)
-            snm.eval()
-            out3_eval = snm(inp)
-
-            snm.load_state_dict(version_none_state_dict)
-            if activate_times > 0:
-                # since in loading version None state dict, we assume that the
-                # values in the state dict have gone through at lease one
-                # forward, we only test for equivalence when activate_times > 0.
-                snm.eval()
-                self.assertEqual(out0_eval, snm(inp))
-                snm.train()
-                self.assertEqual(out1_train, snm(inp))
-                self.assertEqual(out2_train, snm(inp))
-                snm.eval()
-                self.assertEqual(out3_eval, snm(inp))
-
-            # Test normal loading
-            snm.load_state_dict(version_latest_state_dict)
-            snm.eval()
-            self.assertEqual(out0_eval, snm(inp))
-            snm.train()
-            self.assertEqual(out1_train, snm(inp))
-            self.assertEqual(out2_train, snm(inp))
-            snm.eval()
-            self.assertEqual(out3_eval, snm(inp))
+                if not version_latest_with_metadata:
+                    # We want to still load a user-crafted state_dict, one without metadata
+                    del version_latest_state_dict._metadata['']['spectral_norm']
+
+                # test that re-wrapping does not matter
+                m = torch.nn.utils.remove_spectral_norm(snm)
+                snm = torch.nn.utils.spectral_norm(m)
+
+                snm.load_state_dict(version_latest_ref_state_dict)
+                with torch.no_grad():
+                    snm.eval()
+                    out0_eval = snm(inp)
+                    snm.train()
+                    out1_train = snm(inp)
+                    out2_train = snm(inp)
+                    snm.eval()
+                    out3_eval = snm(inp)
+
+                # test that re-wrapping does not matter
+                m = torch.nn.utils.remove_spectral_norm(snm)
+                snm = torch.nn.utils.spectral_norm(m)
+
+                snm.load_state_dict(version_none_state_dict)
+                if activate_times > 0:
+                    # since in loading version None state dict, we assume that the
+                    # values in the state dict have gone through at lease one
+                    # forward, we only test for equivalence when activate_times > 0.
+                    with torch.no_grad():
+                        snm.eval()
+                        self.assertEqual(out0_eval, snm(inp))
+                        snm.train()
+                        self.assertEqual(out1_train, snm(inp))
+                        self.assertEqual(out2_train, snm(inp))
+                        snm.eval()
+                        self.assertEqual(out3_eval, snm(inp))
+
+                # test that re-wrapping does not matter
+                m = torch.nn.utils.remove_spectral_norm(snm)
+                snm = torch.nn.utils.spectral_norm(m)
+
+                # Test normal loading
+                snm.load_state_dict(version_latest_state_dict)
+                with torch.no_grad():
+                    snm.eval()
+                    self.assertEqual(out0_eval, snm(inp))
+                    snm.train()
+                    self.assertEqual(out1_train, snm(inp))
+                    self.assertEqual(out2_train, snm(inp))
+                    snm.eval()
+                    self.assertEqual(out3_eval, snm(inp))
 
     def test_spectral_norm_dim(self):
         inp = torch.randn(2, 3, 10, 12)
@@ -3602,7 +3646,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                     multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
                     multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
                     multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
-                    static_k=saved_k_tensor, static_v=saved_v_tensor)
+                    static_k=saved_k_tensor, static_v=saved_v_tensor)
             else:
                 result, result_weight = torch.nn.functional.multi_head_attention_forward(
                     _Q, _K, _V,
@@ -3612,9 +3656,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                     multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
                     multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
                     multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
-                    True, multihead_attn_module.q_proj_weight,
+                    True, multihead_attn_module.q_proj_weight,
                     multihead_attn_module.k_proj_weight, multihead_attn_module.v_proj_weight,
-                    static_k=saved_k_tensor, static_v=saved_v_tensor)
+                    static_k=saved_k_tensor, static_v=saved_v_tensor)
 
             result = result.squeeze(0).detach().numpy()
 
@@ -3644,12 +3688,12 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
 
             if saved_k is not None:
                 K_split = np.reshape(saved_k, [dims[0], nheads, dims[1], d_head])
-            else:
+            else:
                 K_split = _split_heads_ref(K_fc, dims, nheads, d_head)
 
             if saved_k is not None:
                 V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head])
-            else:
+            else:
                 V_split = _split_heads_ref(V_fc, dims, nheads, d_head)
 
             if add_zero_attn:
@@ -3668,7 +3712,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
                 V=V_split,
                 dims=Q_split.shape,
                 unseen_mask=attn_mask,
-                key_padding_mask=key_padding_mask
+                key_padding_mask=key_padding_mask
             )
             combined_attn_heads = _combine_heads_ref(
                 X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head
@@ -3713,7 +3757,7 @@ def test_multihead_attn_all_arguments2():
                                         add_zero_attn=True, saved_kv=True)
 
         def test_multihead_attn_all_arguments3():
-            _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
+            _multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
                                         saved_kv=True, same_embed_dim=True)
 
         test_multihead_attn_add_zero_attn()  # Test MultiheadAttention with add_zero_attn
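
The spectral_norm test above also crafts a legacy "version None" state_dict by hand. A standalone sketch of that same scenario (the module shape and the single warm-up forward are illustrative assumptions, not from the commit), showing the entries the pre-hook converts when such a dict is loaded:

import copy
import torch
import torch.nn as nn

snm = nn.utils.spectral_norm(nn.Linear(5, 7))
snm(torch.randn(2, 5))  # one forward pass, mirroring activate_times > 0 in the test

legacy_state_dict = copy.deepcopy(snm.state_dict())
del legacy_state_dict._metadata['']['spectral_norm']       # old dicts carry no version metadata
del legacy_state_dict['weight_v']                          # and no v vector
legacy_state_dict['weight'] = snm.weight.detach().clone()  # the composed W was stored as a buffer

# The SpectralNormLoadStateDictPreHook recomputes sigma and the v vector from
# weight_orig, weight and weight_u, so loading the legacy dict works directly.
snm.load_state_dict(legacy_state_dict)
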

torch/nn/utils/spectral_norm.py

Lines changed: 35 additions & 8 deletions
@@ -135,7 +135,6 @@ def apply(module, name, n_power_iterations, dim, eps):
         module.register_buffer(fn.name + "_v", v)
 
         module.register_forward_pre_hook(fn)
-
         module._register_state_dict_hook(SpectralNormStateDictHook(fn))
         module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
         return fn
@@ -161,14 +160,30 @@ def __call__(self, state_dict, prefix, local_metadata, strict,
         fn = self.fn
         version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
         if version is None or version < 1:
+            weight_key = prefix + fn.name
+            if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \
+                    weight_key not in state_dict:
+                # Detect if it is the updated state dict and just missing metadata.
+                # This could happen if the users are crafting a state dict themselves,
+                # so we just pretend that this is the newest.
+                return
+            has_missing_keys = False
+            for suffix in ('_orig', '', '_u'):
+                key = weight_key + suffix
+                if key not in state_dict:
+                    has_missing_keys = True
+                    if strict:
+                        missing_keys.append(key)
+            if has_missing_keys:
+                return
             with torch.no_grad():
-                weight_orig = state_dict[prefix + fn.name + '_orig']
-                weight = state_dict.pop(prefix + fn.name)
+                weight_orig = state_dict[weight_key + '_orig']
+                weight = state_dict.pop(weight_key)
                 sigma = (weight_orig / weight).mean()
                 weight_mat = fn.reshape_weight_to_matrix(weight_orig)
-                u = state_dict[prefix + fn.name + '_u']
+                u = state_dict[weight_key + '_u']
                 v = fn._solve_v_and_rescale(weight_mat, u, sigma)
-                state_dict[prefix + fn.name + '_v'] = v
+                state_dict[weight_key + '_v'] = v
 
 
 # This is a top level class because Py2 pickle doesn't like inner class nor an
@@ -255,7 +270,19 @@ def remove_spectral_norm(module, name='weight'):
         if isinstance(hook, SpectralNorm) and hook.name == name:
             hook.remove(module)
             del module._forward_pre_hooks[k]
-            return module
+            break
+    else:
+        raise ValueError("spectral_norm of '{}' not found in {}".format(
+            name, module))
+
+    for k, hook in module._state_dict_hooks.items():
+        if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
+            del module._state_dict_hooks[k]
+            break
 
-    raise ValueError("spectral_norm of '{}' not found in {}".format(
-        name, module))
+    for k, hook in module._load_state_dict_pre_hooks.items():
+        if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
+            del module._load_state_dict_pre_hooks[k]
+            break
+
+    return module
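
A short sketch of the "missing hook removals" part of the fix (illustrative, not taken from the commit): with the state-dict hooks now removed as well, a module returned by remove_spectral_norm serializes like a plain layer and can be wrapped and round-tripped again without a stale hook interfering:

import torch
import torch.nn as nn

snm = nn.utils.spectral_norm(nn.Linear(5, 7))
m = nn.utils.remove_spectral_norm(snm)

# The forward pre-hook, SpectralNormStateDictHook and
# SpectralNormLoadStateDictPreHook are all gone, so the state_dict is that of a
# plain Linear and carries no leftover spectral_norm metadata.
assert set(m.state_dict().keys()) == {'weight', 'bias'}

# Re-wrapping, as exercised by the "re-wrapping does not matter" steps in the test.
snm = nn.utils.spectral_norm(m)
snm.load_state_dict(snm.state_dict())
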

0 commit comments
