|
| 1 | +""" |
| 2 | +Spectral Normalization from https://arxiv.org/abs/1802.05957 |
| 3 | +""" |
| 4 | +import torch |
| 5 | +from torch.nn.functional import normalize |
| 6 | +from torch.nn.parameter import Parameter |
| 7 | + |
| 8 | + |
| 9 | +class SpectralNorm(object): |
| 10 | + |
| 11 | + def __init__(self, name='weight', n_power_iterations=1, eps=1e-12): |
| 12 | + self.name = name |
| 13 | + self.n_power_iterations = n_power_iterations |
| 14 | + self.eps = eps |
| 15 | + |
| 16 | + def compute_weight(self, module): |
| 17 | + weight = module._parameters[self.name + '_org'] |
| 18 | + u = module._buffers[self.name + '_u'] |
| 19 | + height = weight.size(0) |
| 20 | + weight_mat = weight.view(height, -1) |
| 21 | + for _ in range(self.n_power_iterations): |
| 22 | + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` |
| 23 | + # are the first left and right singular vectors. |
| 24 | + # This power iteration produces approximations of `u` and `v`. |
| 25 | + v = normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) |
| 26 | + u = normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) |
| 27 | + |
| 28 | + sigma = torch.dot(u, torch.matmul(weight_mat, v)) |
| 29 | + weight.data /= sigma |
| 30 | + return weight, u |
| 31 | + |
| 32 | + def remove(self, module): |
| 33 | + weight = module._parameters[self.name + '_org'] |
| 34 | + del module._parameters[self.name] |
| 35 | + del module._buffers[self.name + '_u'] |
| 36 | + del module._parameters[self.name + '_org'] |
| 37 | + module.register_parameter(self.name, weight) |
| 38 | + |
| 39 | + def __call__(self, module, inputs): |
| 40 | + weight, u = self.compute_weight(module) |
| 41 | + setattr(module, self.name, weight) |
| 42 | + setattr(module, self.name + '_u', u) |
| 43 | + |
| 44 | + @staticmethod |
| 45 | + def apply(module, name, n_power_iterations, eps): |
| 46 | + fn = SpectralNorm(name, n_power_iterations, eps) |
| 47 | + weight = module._parameters[name] |
| 48 | + height = weight.size(0) |
| 49 | + |
| 50 | + u = normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) |
| 51 | + module.register_parameter(fn.name + "_org", weight) |
| 52 | + module.register_buffer(fn.name + "_u", u) |
| 53 | + |
| 54 | + module.register_forward_pre_hook(fn) |
| 55 | + return fn |
| 56 | + |
| 57 | + |
| 58 | +def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12): |
| 59 | + r"""Applies spectral normalization to a parameter in the given module. |
| 60 | +
|
| 61 | + .. math:: |
| 62 | + \mathbf{W} &= \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\ |
| 63 | + \sigma(\mathbf{W}) &= \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} |
| 64 | +
|
| 65 | + Spectral normalization stabilizes the training of discriminators (critics) |
| 66 | + in Generaive Adversarial Networks (GANs) by rescaling the weight tensor |
| 67 | + with spectral norm :math:`\sigma` of the weight matrix calculated using |
| 68 | + power iteration method. If the dimension of the weight tensor is greater |
| 69 | + than 2, it is reshaped to 2D in power iteration method to get spectral |
| 70 | + norm. This is implemented via a hook that calculates spectral norm and |
| 71 | + rescales weight before every :meth:`~Module.forward` call. |
| 72 | +
|
| 73 | + See `Spectral Normalization for Generative Adversarial Networks`_ . |
| 74 | +
|
| 75 | + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 |
| 76 | +
|
| 77 | + Args: |
| 78 | + module (nn.Module): containing module |
| 79 | + name (str, optional): name of weight parameter |
| 80 | + n_power_iterations (int, optional): number of power iterations to |
| 81 | + calculate spectal norm |
| 82 | + eps (float, optional): epsilon for numerical stability in |
| 83 | + calculating norms |
| 84 | +
|
| 85 | + Returns: |
| 86 | + The original module with the spectal norm hook |
| 87 | +
|
| 88 | + Example:: |
| 89 | +
|
| 90 | + >>> m = spectral_norm(nn.Linear(20, 40)) |
| 91 | + Linear (20 -> 40) |
| 92 | + >>> m.weight_u.size() |
| 93 | + torch.Size([20]) |
| 94 | +
|
| 95 | + """ |
| 96 | + SpectralNorm.apply(module, name, n_power_iterations, eps) |
| 97 | + return module |
| 98 | + |
| 99 | + |
| 100 | +def remove_spectral_norm(module, name='weight'): |
| 101 | + r"""Removes the spectral normalization reparameterization from a module. |
| 102 | +
|
| 103 | + Args: |
| 104 | + module (nn.Module): containing module |
| 105 | + name (str, optional): name of weight parameter |
| 106 | +
|
| 107 | + Example: |
| 108 | + >>> m = spectral_norm(nn.Linear(40, 10)) |
| 109 | + >>> remove_spectral_norm(m) |
| 110 | + """ |
| 111 | + for k, hook in module._forward_pre_hooks.items(): |
| 112 | + if isinstance(hook, SpectralNorm) and hook.name == name: |
| 113 | + hook.remove(module) |
| 114 | + del module._forward_pre_hooks[k] |
| 115 | + return module |
| 116 | + |
| 117 | + raise ValueError("spectral_norm of '{}' not found in {}".format( |
| 118 | + name, module)) |
0 commit comments