124 changes: 84 additions & 40 deletions test/test_nn.py
@@ -2141,47 +2141,91 @@ def test_spectral_norm_load_state_dict(self):
for _ in range(activate_times):
snm(inp)

version_latest_ref_state_dict = deepcopy(snm.state_dict())
self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_latest_ref_state_dict.keys()))

# test that non-strict loading works
non_strict_state_dict = deepcopy(version_latest_ref_state_dict)
non_strict_state_dict['nonsense'] = 'nonsense'
with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'):
snm.load_state_dict(non_strict_state_dict, strict=True)
snm.load_state_dict(non_strict_state_dict, strict=False)
del non_strict_state_dict['weight_orig']
snm.load_state_dict(non_strict_state_dict, strict=False)
del non_strict_state_dict['weight_u']
snm.load_state_dict(non_strict_state_dict, strict=False)
del non_strict_state_dict['weight_v']
snm.load_state_dict(non_strict_state_dict, strict=False)
non_strict_state_dict['weight'] = snm.weight.detach().clone() # set W as a buffer
snm.load_state_dict(non_strict_state_dict, strict=False)
del non_strict_state_dict._metadata['']['spectral_norm'] # remove metadata info
snm.load_state_dict(non_strict_state_dict, strict=False)
del non_strict_state_dict['weight'] # remove W buffer
snm.load_state_dict(non_strict_state_dict, strict=False)
del non_strict_state_dict['bias']
snm.load_state_dict(non_strict_state_dict, strict=False)

# craft a version None state_dict
version_none_state_dict = deepcopy(snm.state_dict())
self.assertEqual({'weight_orig', 'bias', 'weight_u', 'weight_v'}, set(version_none_state_dict.keys()))
version_none_state_dict = deepcopy(version_latest_ref_state_dict)
self.assertIn('spectral_norm', version_none_state_dict._metadata[''])
del version_none_state_dict._metadata['']['spectral_norm'] # remove metadata info
del version_none_state_dict['weight_v'] # remove v vector
version_none_state_dict['weight'] = snm.weight.detach().clone() # set W as a buffer
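# Taken together, the three edits above reproduce the layout written by the
# pre-versioned spectral_norm implementation: the normalized weight is stored
# as a plain 'weight' buffer, there is no 'weight_v' entry, and the
# 'spectral_norm' metadata (and hence the version number) is absent.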

# normal state_dict
version_latest_state_dict = deepcopy(snm.state_dict())
for version_latest_with_metadata in [True, False]:
version_latest_state_dict = deepcopy(version_latest_ref_state_dict)

snm.eval()
out0_eval = snm(inp)
snm.train()
out1_train = snm(inp)
out2_train = snm(inp)
snm.eval()
out3_eval = snm(inp)

snm.load_state_dict(version_none_state_dict)
if activate_times > 0:
# since when loading a version None state dict we assume that the
# values in the state dict have already gone through at least one
# forward pass, we only test for equivalence when activate_times > 0.
snm.eval()
self.assertEqual(out0_eval, snm(inp))
snm.train()
self.assertEqual(out1_train, snm(inp))
self.assertEqual(out2_train, snm(inp))
snm.eval()
self.assertEqual(out3_eval, snm(inp))

# Test normal loading
snm.load_state_dict(version_latest_state_dict)
snm.eval()
self.assertEqual(out0_eval, snm(inp))
snm.train()
self.assertEqual(out1_train, snm(inp))
self.assertEqual(out2_train, snm(inp))
snm.eval()
self.assertEqual(out3_eval, snm(inp))
if not version_latest_with_metadata:
# We still want to be able to load a user-crafted state_dict, one without metadata
del version_latest_state_dict._metadata['']['spectral_norm']

# test that re-wrapping does not matter
m = torch.nn.utils.remove_spectral_norm(snm)
snm = torch.nn.utils.spectral_norm(m)

snm.load_state_dict(version_latest_ref_state_dict)
with torch.no_grad():
snm.eval()
out0_eval = snm(inp)
snm.train()
out1_train = snm(inp)
out2_train = snm(inp)
snm.eval()
out3_eval = snm(inp)

# test that re-wrapping does not matter
m = torch.nn.utils.remove_spectral_norm(snm)
snm = torch.nn.utils.spectral_norm(m)

snm.load_state_dict(version_none_state_dict)
if activate_times > 0:
# since when loading a version None state dict we assume that the
# values in the state dict have already gone through at least one
# forward pass, we only test for equivalence when activate_times > 0.
with torch.no_grad():
snm.eval()
self.assertEqual(out0_eval, snm(inp))
snm.train()
self.assertEqual(out1_train, snm(inp))
self.assertEqual(out2_train, snm(inp))
snm.eval()
self.assertEqual(out3_eval, snm(inp))

# test that re-wrapping does not matter
m = torch.nn.utils.remove_spectral_norm(snm)
snm = torch.nn.utils.spectral_norm(m)

# Test normal loading
snm.load_state_dict(version_latest_state_dict)
with torch.no_grad():
snm.eval()
self.assertEqual(out0_eval, snm(inp))
snm.train()
self.assertEqual(out1_train, snm(inp))
self.assertEqual(out2_train, snm(inp))
snm.eval()
self.assertEqual(out3_eval, snm(inp))

def test_spectral_norm_dim(self):
inp = torch.randn(2, 3, 10, 12)
@@ -3602,7 +3646,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
static_k=saved_k_tensor, static_v=saved_v_tensor)
else:
result, result_weight = torch.nn.functional.multi_head_attention_forward(
_Q, _K, _V,
@@ -3612,9 +3656,9 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
multihead_attn_module.training, key_padding_mask_tensor, True, attn_mask_tensor,
True, multihead_attn_module.q_proj_weight,
multihead_attn_module.k_proj_weight, multihead_attn_module.v_proj_weight,
static_k=saved_k_tensor, static_v=saved_v_tensor)

result = result.squeeze(0).detach().numpy()

@@ -3644,12 +3688,12 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a

if saved_k is not None:
K_split = np.reshape(saved_k, [dims[0], nheads, dims[1], d_head])
else:
K_split = _split_heads_ref(K_fc, dims, nheads, d_head)

if saved_v is not None:
V_split = np.reshape(saved_v, [dims[0], nheads, dims[1], d_head])
else:
V_split = _split_heads_ref(V_fc, dims, nheads, d_head)

if add_zero_attn:
@@ -3668,7 +3712,7 @@ def _multihead_attn_test_helper(add_key_padding_mask=False, add_bias_kv=False, a
V=V_split,
dims=Q_split.shape,
unseen_mask=attn_mask,
key_padding_mask=key_padding_mask
)
combined_attn_heads = _combine_heads_ref(
X=attn_heads, dims=[batch_sz, 1], nheads=nheads, d_head=d_head
@@ -3713,7 +3757,7 @@ def test_multihead_attn_all_arguments2():
add_zero_attn=True, saved_kv=True)

def test_multihead_attn_all_arguments3():
_multihead_attn_test_helper(add_key_padding_mask=True, add_zero_attn=True,
saved_kv=True, same_embed_dim=True)

test_multihead_attn_add_zero_attn() # Test MultiheadAttention with add_zero_attn
43 changes: 35 additions & 8 deletions torch/nn/utils/spectral_norm.py
@@ -135,7 +135,6 @@ def apply(module, name, n_power_iterations, dim, eps):
module.register_buffer(fn.name + "_v", v)

module.register_forward_pre_hook(fn)

module._register_state_dict_hook(SpectralNormStateDictHook(fn))
module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
return fn
@@ -161,14 +160,30 @@ def __call__(self, state_dict, prefix, local_metadata, strict,
fn = self.fn
version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
if version is None or version < 1:
weight_key = prefix + fn.name
if version is None and all(weight_key + s in state_dict for s in ('_orig', '_u', '_v')) and \
weight_key not in state_dict:
# Detect whether this is already the updated state dict format, just missing the metadata.
# This can happen when users craft a state dict themselves,
# so we simply treat it as the newest version.
return
has_missing_keys = False
for suffix in ('_orig', '', '_u'):
key = weight_key + suffix
if key not in state_dict:
has_missing_keys = True
if strict:
missing_keys.append(key)
if has_missing_keys:
return
with torch.no_grad():
weight_orig = state_dict[prefix + fn.name + '_orig']
weight = state_dict.pop(prefix + fn.name)
weight_orig = state_dict[weight_key + '_orig']
weight = state_dict.pop(weight_key)
sigma = (weight_orig / weight).mean()
weight_mat = fn.reshape_weight_to_matrix(weight_orig)
u = state_dict[prefix + fn.name + '_u']
u = state_dict[weight_key + '_u']
v = fn._solve_v_and_rescale(weight_mat, u, sigma)
state_dict[prefix + fn.name + '_v'] = v
state_dict[weight_key + '_v'] = v
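# Note on the conversion above: in the legacy layout 'weight' already holds
# weight_orig / sigma, so the elementwise ratio weight_orig / weight is (up to
# numerical noise) the constant sigma, which mean() recovers. _solve_v_and_rescale
# then recomputes a 'weight_v' vector consistent with the stored 'weight_u' and
# that sigma, so the forward pass after loading reproduces the legacy normalized
# weight; this is the equivalence the version-None branch of the test exercises.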


# This is a top level class because Py2 pickle doesn't like inner class nor an
@@ -255,7 +270,19 @@ def remove_spectral_norm(module, name='weight'):
if isinstance(hook, SpectralNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
return module
break
else:
raise ValueError("spectral_norm of '{}' not found in {}".format(
name, module))

for k, hook in module._state_dict_hooks.items():
if isinstance(hook, SpectralNormStateDictHook) and hook.fn.name == name:
del module._state_dict_hooks[k]
break

raise ValueError("spectral_norm of '{}' not found in {}".format(
name, module))
for k, hook in module._load_state_dict_pre_hooks.items():
if isinstance(hook, SpectralNormLoadStateDictPreHook) and hook.fn.name == name:
del module._load_state_dict_pre_hooks[k]
break

return module
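
A minimal usage sketch of the behaviour this patch enables (not part of the diff; the Linear layer and tensor shapes below are illustrative assumptions): remove_spectral_norm now deletes the state_dict hook and the load_state_dict pre-hook along with the forward pre-hook, so a module can be unwrapped, re-wrapped, and still load a previously saved state dict.

import torch

m = torch.nn.utils.spectral_norm(torch.nn.Linear(5, 3))
m(torch.randn(2, 5))                         # one forward pass updates the u/v buffers
sd = m.state_dict()                          # keys: weight_orig, bias, weight_u, weight_v
m = torch.nn.utils.remove_spectral_norm(m)   # all three hooks are removed
assert set(m.state_dict().keys()) == {'weight', 'bias'}
m = torch.nn.utils.spectral_norm(m)          # re-wrapping leaves no stale hooks behind
m.load_state_dict(sd)                        # and the saved state dict still loads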