Conversation

@crcrpar (Collaborator) commented Apr 25, 2018

Related to #5027.
This PR aims to implement spectral normalization in a way similar to torch.nn.utils.weight_norm.

@crcrpar changed the title from "add spectral normalization" to "add spectral normalization [pytorch]" Apr 25, 2018
@soumith (Contributor) commented Apr 25, 2018

@pytorchbot retest this please

    return v / denom


class SpectralNorm(object):

@ezyang (Contributor) commented Apr 25, 2018

@soumith Do we want this one?

@soumith (Contributor) commented Apr 25, 2018

We want spectral norm, yea. It's being used quite a lot these days.

\mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\
\sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
Spectral normalization stabilizes the training of discriminators (critics)
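A quick numerical illustration of this definition (not part of the PR; torch.linalg.svdvals is assumed from a recent PyTorch): the maximum of ||W h|| / ||h|| over random nonzero h approaches, and never exceeds, the largest singular value of W.

import torch

torch.manual_seed(0)
W = torch.randn(5, 3)
H = torch.randn(10000, 3)                         # rows are random nonzero h
ratios = (W @ H.t()).norm(dim=0) / H.norm(dim=1)  # ||W h|| / ||h|| per row
print(ratios.max().item(), torch.linalg.svdvals(W)[0].item())  # max <= sigma(W)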

Spectral normalization stabilizes the training of discriminators (critics)
in GANs by rescaling the weight tensor with the spectral norm "sigma" of the
weight matrix, calculated by the power iteration method. If the dimension of the
weight tensor is greater than 2, it is reshaped to 2D in the power iteration method.
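For example, the reshape for a convolution weight looks like this (illustrative sketch, not the PR's code):

import torch

# a conv weight of shape (out_channels, in_channels, kh, kw) is viewed
# as a 2-D matrix of shape (out_channels, in_channels * kh * kw)
w = torch.randn(16, 3, 3, 3)
w_mat = w.view(w.size(0), -1)
print(w_mat.shape)  # torch.Size([16, 27])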

def compute_weight(self, module):
    weight = module._parameters[self.name + '_org']
    u = module._buffers[self.name + '_u']
    height, _cuda = weight.size(0), weight.is_cuda

from torch.nn.parameter import Parameter


def l2normalize(v, eps=1e-12):
    # eps keeps the denominator nonzero for near-zero vectors
    denom = v.norm() + eps
    return v / denom
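The commit notes below mention that l2normalize was later replaced by F.normalize; a minimal sketch of the near-equivalence, assuming a 1-D input:

import torch
from torch.nn.functional import normalize

v = torch.randn(10)
# F.normalize divides by max(||v||, eps) while l2normalize divides by
# ||v|| + eps; for ||v|| far above eps the two agree closely.
print(torch.allclose(normalize(v, dim=0), v / (v.norm() + 1e-12)))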

for _ in range(self.n_power_iterations):
    v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps)
    u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)
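A self-contained sketch of this power iteration (illustrative only; torch.linalg.svdvals is assumed from a recent PyTorch for the check): alternately applying W^T and W to a random vector converges to the top right/left singular vectors, and u^T W v then estimates sigma(W).

import torch
from torch.nn.functional import normalize

torch.manual_seed(0)
W = torch.randn(8, 5)
u = normalize(torch.randn(8), dim=0)  # random initial left vector
for _ in range(50):
    v = normalize(W.t() @ u, dim=0)   # approximate right singular vector
    u = normalize(W @ v, dim=0)       # approximate left singular vector
sigma = torch.dot(u, W @ v)           # estimate of the top singular value
print(sigma.item(), torch.linalg.svdvals(W)[0].item())  # should nearly match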

Spectral normalization stabilizes the training of discriminators (critics)
in GANs by rescaling the weight tensor with the spectral norm "sigma" of the
weight matrix, calculated by the power iteration method. If the dimension of the
weight tensor is greater than 2, it is reshaped to 2D in the power iteration method.

@ezyang (Contributor) commented Apr 30, 2018

@fmassa @ssnl is this ok to merge now?

@ssnl (Collaborator) left a comment

Code looks solid. There are still some minor nits around doc and comments. I think this is good to go once those are fixed.

Spectral normalization stabilizes the training of discriminators (critics)
in GANs by rescaling the weight tensor with spectral norm :math:`\sigma` of
the weight matrix calculated using the power iteration method. If the dimension
of the weight tensor is greater than 2, it is reshaped to 2D in the power iteration method.

spectral norm and rescales the weight before every :meth:`~Module.forward`
call.
See https://arxiv.org/abs/1802.05957
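For reference, a usage sketch along the lines of the docstring example shipped with this feature:

import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

m = spectral_norm(nn.Linear(20, 40))
print(m)                   # Linear(in_features=20, out_features=40, bias=True)
print(m.weight_u.size())   # torch.Size([40]) -- the persistent `u` buffer
y = m(torch.randn(2, 20))  # the weight is rescaled by sigma on this call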

r"""Removes the spectral normalization reparameterization from a module.
Args:
module (nn.Modue): containing module
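A short usage sketch (mirroring the remove_weight_norm counterpart):

from torch import nn
from torch.nn.utils import spectral_norm, remove_spectral_norm

m = spectral_norm(nn.Linear(20, 40))
remove_spectral_norm(m)  # re-registers a plain `weight` parameter and
                         # drops the `weight_u` buffer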

weight = module._parameters[name]
height = weight.size(0)

# start the power iteration from a random unit vector
u = normalize(weight.data.new(height).normal_(0, 1), dim=0, eps=fn.eps)

\sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

Spectral normalization stabilizes the training of discriminators (critics)
in GANs by rescaling the weight tensor with spectral norm :math:`\sigma` of
the weight matrix.

u = module._buffers[self.name + '_u']
height = weight.size(0)
# weights with more than 2 dimensions are flattened to a 2-D matrix
weight_mat = weight.view(height, -1)
for _ in range(self.n_power_iterations):

name (str, optional): name of the weight parameter
n_power_iterations (int, optional): number of power iterations to
    calculate spectral norm
eps (float, optional): epsilon for numerical stability
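Spelled out as keyword arguments (a hedged sketch; the defaults shown match the signature in this PR):

from torch import nn
from torch.nn.utils import spectral_norm

m = spectral_norm(nn.Conv2d(3, 16, kernel_size=3),
                  name='weight', n_power_iterations=1, eps=1e-12)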

fix typos
add comments related to power iteration and epsilon
update link to the paper
make some comments specific

@ssnl (Collaborator) left a comment

LGTM except the typo

weight_mat = weight.view(height, -1)
for _ in range(self.n_power_iterations):
    # Spectral norm of weight equals `u^T W v`, where `u` and `v`
    # are the first left and right singular vectors.

@ssnl (Collaborator) left a comment

LGTM. @fmassa , do you want to take an extra look?

@ssnl merged commit ba04633 into pytorch:master May 1, 2018
@ssnl (Collaborator) commented May 1, 2018

Thank you! @crcrpar

@crcrpar (Collaborator, Author) commented May 1, 2018

Thank you for your very kind reviews and suggestions!

Jorghi12 pushed a commit to wsttiger/pytorch that referenced this pull request May 10, 2018
* initial commit for spectral norm

* fix comment

* edit rst

* fix doc

* remove redundant empty line

* fix nit mistakes in doc

* replace l2normalize with F.normalize

* fix chained `by`

* fix docs

fix typos
add comments related to power iteration and epsilon
update link to the paper
make some comments specific

* fix typo
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018, with the same commit message as above.
    u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps)

# sigma = u^T W v estimates the largest singular value; rescale in place
sigma = torch.dot(u, torch.matmul(weight_mat, v))
weight.data /= sigma
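An illustrative check of this rescaling step, using exact singular vectors from SVD in place of power iteration (torch.linalg API assumed from a recent PyTorch):

import torch

torch.manual_seed(0)
W = torch.randn(6, 4)
U, S, Vh = torch.linalg.svd(W, full_matrices=False)
u, v = U[:, 0], Vh[0]                      # exact top singular vectors
sigma = torch.dot(u, W @ v)                # equals S[0]
print(torch.linalg.svdvals(W / sigma)[0])  # ~1.0 after rescaling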

@crcrpar deleted the spectral-norm branch April 1, 2019 23:57