Skip to content

Commit cc8ba3d

Browse files
crcrpar authored and Jorghi12 committed
add spectral normalization [pytorch] (pytorch#6929)
* initial commit for spectral norm
* fix comment
* edit rst
* fix doc
* remove redundant empty line
* fix nit mistakes in doc
* replace l2normalize with F.normalize
* fix chained `by`
* fix docs
  - fix typos
  - add comments related to power iteration and epsilon
  - update link to the paper
  - make some comments specific
* fix typo
1 parent 8fc135c commit cc8ba3d

File tree

4 files changed

+162
-0
lines changed

4 files changed

+162
-0
lines changed

docs/source/nn.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,16 @@ Utilities
726726

727727
.. autofunction:: torch.nn.utils.remove_weight_norm
728728

729+
:hidden:`spectral_norm`
730+
~~~~~~~~~~~~~~~~~~~~~
731+
732+
.. autofunction:: torch.nn.utils.spectral_norm
733+
734+
:hidden:`remove_spectral_norm`
735+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
736+
737+
.. autofunction:: torch.nn.utils.remove_spectral_norm
738+
729739

730740
.. currentmodule:: torch.nn.utils.rnn
731741

test/test_nn.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1324,6 +1324,39 @@ def test_weight_norm_pickle(self):
13241324
m = pickle.loads(pickle.dumps(m))
13251325
self.assertIsInstance(m, nn.Linear)
13261326

1327+
def test_spectral_norm(self):
    # Applying spectral norm must expose the `weight_u` buffer (one entry
    # per row of the weight) and the `weight_org` parameter on the module.
    m = nn.Linear(5, 7)
    m = torch.nn.utils.spectral_norm(m)

    self.assertEqual(m.weight_u.size(), torch.Size([m.weight.size(0)]))
    self.assertTrue(hasattr(m, 'weight_org'))

    # Removing it must drop both again while keeping `weight` itself.
    m = torch.nn.utils.remove_spectral_norm(m)
    self.assertFalse(hasattr(m, 'weight_org'))
    self.assertFalse(hasattr(m, 'weight_u'))
    self.assertTrue(hasattr(m, 'weight'))
1338+
1339+
def test_spectral_norm_forward(self):
    # Compare the hooked forward pass against a hand-written single power
    # iteration followed by a plain linear layer.
    x = torch.randn(3, 5)
    layer = torch.nn.utils.spectral_norm(nn.Linear(5, 7))

    # Reference computation: one power-iteration step starting from the
    # module's current `weight_u`, then an in-place rescale of the weight.
    w, b, u = layer.weight_org, layer.bias, layer.weight_u
    w_mat = w.view(w.size(0), -1)
    v = F.normalize(torch.mv(w_mat.t(), u), dim=0, eps=1e-12)
    u = F.normalize(torch.mv(w_mat, v), dim=0, eps=1e-12)
    sigma = torch.dot(u, torch.matmul(w_mat, v))
    w.data /= sigma
    reference = torch.nn.functional.linear(x, w, b)

    # The hooked forward must run only after the reference rescale above.
    actual = layer(x)
    self.assertAlmostEqual(actual, reference)
1354+
1355+
def test_spectral_norm_pickle(self):
    # A module carrying the spectral norm hook must survive a pickle
    # round trip and come back as the same module type.
    wrapped = torch.nn.utils.spectral_norm(nn.Linear(5, 7))
    restored = pickle.loads(pickle.dumps(wrapped))
    self.assertIsInstance(restored, nn.Linear)
1359+
13271360
def test_embedding_sparse_basic(self):
13281361
embedding = nn.Embedding(10, 20, sparse=True)
13291362
input = Variable(torch.LongTensor([[0, 2, 4, 5], [4, 3, 0, 9]]))

torch/nn/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
33
from .weight_norm import weight_norm, remove_weight_norm
44
from .convert_parameters import parameters_to_vector, vector_to_parameters
5+
from .spectral_norm import spectral_norm, remove_spectral_norm

torch/nn/utils/spectral_norm.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""
2+
Spectral Normalization from https://arxiv.org/abs/1802.05957
3+
"""
4+
import torch
5+
from torch.nn.functional import normalize
6+
from torch.nn.parameter import Parameter
7+
8+
9+
class SpectralNorm(object):
    """Forward pre-hook that rescales a module parameter by its spectral norm.

    The hook keeps two extra entries on the module: the parameter
    ``<name>_org`` and the buffer ``<name>_u`` holding the current estimate of
    the first left singular vector. Before every forward call the estimate is
    refined by power iteration and the parameter is divided by the resulting
    spectral norm estimate.

    NOTE: ``<name>_org`` and ``<name>`` refer to the same ``Parameter`` object
    and the rescale is applied in place through ``.data``, so the stored
    parameter itself is modified and no gradient flows through ``sigma``.

    Args:
        name (str): name of the parameter to normalize
        n_power_iterations (int): power iteration steps per forward call;
            must be positive
        eps (float): epsilon passed to ``normalize`` for numerical stability

    Raises:
        ValueError: if ``n_power_iterations`` is not positive
    """

    def __init__(self, name='weight', n_power_iterations=1, eps=1e-12):
        # With zero iterations `v` would never be computed, and
        # `compute_weight` would fail with a NameError on first use.
        if n_power_iterations <= 0:
            raise ValueError(
                "Expected n_power_iterations to be positive, but got "
                "n_power_iterations={}".format(n_power_iterations))
        self.name = name
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def compute_weight(self, module):
        """Runs the power iteration and rescales the parameter in place.

        Returns:
            tuple: the (rescaled) parameter and the updated ``u`` vector
        """
        weight = module._parameters[self.name + '_org']
        u = module._buffers[self.name + '_u']
        height = weight.size(0)
        # Tensors with more than two dimensions are flattened to a matrix.
        weight_mat = weight.view(height, -1)
        # The result is only ever written through `.data`, so autograd is not
        # needed here. Disabling it keeps `u` a plain tensor; otherwise the
        # buffer would hold on to a graph that grows with every forward call.
        with torch.no_grad():
            for _ in range(self.n_power_iterations):
                # Spectral norm of weight equals `u^T W v`, where `u` and `v`
                # are the first left and right singular vectors.
                # This power iteration produces approximations of `u` and `v`.
                v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
                u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

            sigma = torch.dot(u, torch.matmul(weight_mat, v))
            weight.data /= sigma
        return weight, u

    def remove(self, module):
        """Drops the ``_org`` parameter and ``_u`` buffer from ``module``.

        The parameter is re-registered under its original name. Any rescaling
        already applied in place is not undone.
        """
        weight = module._parameters[self.name + '_org']
        del module._parameters[self.name]
        del module._buffers[self.name + '_u']
        del module._parameters[self.name + '_org']
        module.register_parameter(self.name, weight)

    def __call__(self, module, inputs):
        # Invoked as a forward pre-hook; `inputs` is unused.
        weight, u = self.compute_weight(module)
        setattr(module, self.name, weight)
        setattr(module, self.name + '_u', u)

    @staticmethod
    def apply(module, name, n_power_iterations, eps):
        """Attaches a ``SpectralNorm`` hook for ``name`` to ``module``.

        Returns:
            SpectralNorm: the hook instance that was registered
        """
        fn = SpectralNorm(name, n_power_iterations, eps)
        weight = module._parameters[name]
        height = weight.size(0)

        # Random unit vector as the starting point of the power iteration.
        u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps)
        module.register_parameter(fn.name + "_org", weight)
        module.register_buffer(fn.name + "_u", u)

        module.register_forward_pre_hook(fn)
        return fn
56+
57+
58+
def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
         \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
         \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated
    using the power iteration method. If the dimension of the weight tensor is
    greater than 2, it is reshaped to 2D for the power iteration. This is
    implemented via a hook that calculates the spectral norm and rescales the
    weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m.weight_u.size()
        torch.Size([40])

    """
    SpectralNorm.apply(module, name, n_power_iterations, eps)
    return module
98+
99+
100+
def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Looks through the module's forward pre-hooks for the ``SpectralNorm``
    hook attached to ``name``, lets it undo its registrations, and
    unregisters it.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Returns:
        The module that was passed in

    Raises:
        ValueError: if no spectral norm hook for ``name`` is found

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    hooks = module._forward_pre_hooks
    for key in list(hooks):
        candidate = hooks[key]
        if not isinstance(candidate, SpectralNorm):
            continue
        if candidate.name != name:
            continue
        candidate.remove(module)
        del hooks[key]
        return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))

0 commit comments

Comments
 (0)