-
Notifications
You must be signed in to change notification settings - Fork 26.3k
add spectral normalization [pytorch] #6929
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@pytorchbot retest this please |
| return v / denom | ||
|
|
||
|
|
||
| class SpectralNorm(object): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@soumith Do we want this one? |
|
We want spectral norm, yea. It's being used quite a lot these days. |
torch/nn/utils/spectral_norm.py
Outdated
| \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\ | ||
| \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} | ||
| Spectral normalization stabilize the training of discriminators(critics) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| Spectral normalization stabilize the training of discriminators(critics) | ||
| in GANs by rescaling the weight tensor by spectral norm "sigma" of the | ||
| weight matrix calculated by power iteration method. If the dimension of the | ||
| weight tensor is greater than 2, reahaped to 2D in power iteration method |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| def compute_weight(self, module): | ||
| weight = module._parameters[self.name + '_org'] | ||
| u = module._buffers[self.name + '_u'] | ||
| height, _cuda = weight.size(0), weight.is_cuda |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| from torch.nn.parameter import Parameter | ||
|
|
||
|
|
||
| def l2normalize(v, eps=1e-12): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| for _ in range(self.n_power_iterations): | ||
| v = l2normalize(torch.matmul(weight_mat.t(), u), self.eps) | ||
| u = l2normalize(torch.matmul(weight_mat, v), self.eps) | ||
| v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| Spectral normalization stabilizes the training of discriminators (critics) | ||
| in GANs by rescaling the weight tensor by spectral norm "sigma" of the | ||
| weight matrix calculated by power iteration method. If the dimension of the |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
ssnl
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code looks solid. There are still some minor nits around doc and comments. I think this is good to go once those are fixed.
torch/nn/utils/spectral_norm.py
Outdated
| Spectral normalization stabilizes the training of discriminators (critics) | ||
| in GANs by rescaling the weight tensor with spectral norm :math:`\sigma` of | ||
| the weight matrix calculated using power iteration method. If the dimension | ||
| of the weight tensor is greater than 2, reshaped to 2D in power iteration method |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| spectral norm and rescales weight before every :meth:`~Module.forward` | ||
| call. | ||
| See https://arxiv.org/abs/1802.05957 |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| r"""Removes the spectral normalization reparameterization from a module. | ||
| Args: | ||
| module (nn.Modue): containing module |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| weight = module._parameters[name] | ||
| height = weight.size(0) | ||
|
|
||
| u = normalize(weight.data.new(height).normal_(0, 1), dim=0, eps=fn.eps) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} | ||
| Spectral normalization stabilizes the training of discriminators (critics) | ||
| in GANs by rescaling the weight tensor with spectral norm :math:`\sigma` of |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
| u = module._buffers[self.name + '_u'] | ||
| height = weight.size(0) | ||
| weight_mat = weight.view(height, -1) | ||
| for _ in range(self.n_power_iterations): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/spectral_norm.py
Outdated
| name (str, optional): name of weight parameter | ||
| n_power_iterations (int, optional): number of power iterations to | ||
| calculate spectal norm | ||
| eps (float, optional): epsilon for numerical stability |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
ssnl
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM except the typo
torch/nn/utils/spectral_norm.py
Outdated
| weight_mat = weight.view(height, -1) | ||
| for _ in range(self.n_power_iterations): | ||
| # Spectral norm of weight equals to `u^T W v`, where `u` and `v` | ||
| # are the first left and right singular vecors. |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
ssnl
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. @fmassa , do you want to take an extra look?
|
Thank you! @crcrpar |
|
Thank you for your very kind reviews and suggestions! |
* initial commit for spectral norm * fix comment * edit rst * fix doc * remove redundant empty line * fix nit mistakes in doc * replace l2normalize with F.normalize * fix chained `by` * fix docs fix typos add comments related to power iteration and epsilon update link to the paper make some comments specific * fix typo
* initial commit for spectral norm * fix comment * edit rst * fix doc * remove redundant empty line * fix nit mistakes in doc * replace l2normalize with F.normalize * fix chained `by` * fix docs fix typos add comments related to power iteration and epsilon update link to the paper make some comments specific * fix typo
| u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) | ||
|
|
||
| sigma = torch.dot(u, torch.matmul(weight_mat, v)) | ||
| weight.data /= sigma |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
related to #5027.
This PR aims at implementing Spectral Normalization in a way similar to
torch.nn.utils.weight_norm.