torch/nn/modules/instancenorm.py (22 changes: 16 additions & 6 deletions)
@@ -6,6 +6,7 @@ class _InstanceNorm(_BatchNorm):
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False):
         super(_InstanceNorm, self).__init__(
             num_features, eps, momentum, affine)
+        self._use_running_stats = False

     def forward(self, input):
         b, c = input.size(0), input.size(1)
@@ -24,16 +25,22 @@ def forward(self, input):

         out = F.batch_norm(
             input_reshaped, running_mean, running_var, weight, bias,
-            True, self.momentum, self.eps)
+            not self._use_running_stats, self.momentum, self.eps)

         # Reshape back
         self.running_mean.copy_(running_mean.view(b, c).mean(0, keepdim=False))
         self.running_var.copy_(running_var.view(b, c).mean(0, keepdim=False))

         return out.view(b, c, *input.size()[2:])

-    def eval(self):
-        return self
+    def use_running_stats(self, mode=True):
+        r"""Switch between running statistics and per-instance statistics.
+
+        Instance normalization normally uses per-instance statistics in both
+        training and evaluation modes; this method makes it normalize with the
+        accumulated running statistics instead, like batch normalization in eval mode.
+        """
+        self._use_running_stats = mode
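
Aside, not part of the diff: the trick `forward` relies on is that instance normalization of a `(b, c, ...)` input is just batch normalization of the same data viewed as one sample with `b * c` channels; the patch then averages the per-`(b, c)` running statistics over the batch back into the `(c,)`-shaped buffers, which is what the `view(b, c).mean(0)` lines above do. A minimal sketch of the equivalence, assuming a recent PyTorch where `F.batch_norm` accepts `None` running buffers in training mode (the patch itself predates that API):

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 10)                   # (b, c, w) input
b, c = x.size(0), x.size(1)

# Fold the batch into the channel dimension, as _InstanceNorm.forward does:
# every (sample, channel) pair becomes a channel of a one-sample batch.
out = F.batch_norm(x.view(1, b * c, *x.size()[2:]), None, None,
                   training=True, eps=1e-5)
out = out.view(b, c, *x.size()[2:])

# Same as normalizing each (sample, channel) slice independently.
mean = x.mean(-1, keepdim=True)
var = x.var(-1, unbiased=False, keepdim=True)
print(torch.allclose(out, (x - mean) / (var + 1e-5).sqrt(), atol=1e-5))  # True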


@@ -52,7 +59,8 @@ class InstanceNorm1d(_InstanceNorm):

     At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
     i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.train(False)` method.
+    mean and variance with the `.use_running_stats(mode=True)` method, and switch back to the
+    normal behaviour with `.use_running_stats(mode=False)`.

     Args:
         num_features: num_features from an expected input of size `batch_size x num_features x width`
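
As a usage sketch for the behaviour the docstring describes (illustrative only: `.use_running_stats()` exists only with this patch applied, and the example is written against the Variable-era API of the time):

import torch
from torch.autograd import Variable
import torch.nn as nn

m = nn.InstanceNorm1d(100)               # running stats are accumulated on every forward pass
x = Variable(torch.randn(20, 100, 40))   # batch_size x num_features x width

out = m(x)                               # default: per-instance statistics

m.use_running_stats()                    # normalize with the stored running statistics
out_running = m(x)

m.use_running_stats(mode=False)          # back to the default per-instance behaviour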
@@ -97,7 +105,8 @@ class InstanceNorm2d(_InstanceNorm):

     At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
     i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.train(False)` method.
+    mean and variance with the `.use_running_stats(mode=True)` method, and switch back to the
+    normal behaviour with `.use_running_stats(mode=False)`.

     Args:
         num_features: num_features from an expected input of size batch_size x num_features x height x width
@@ -142,7 +151,8 @@ class InstanceNorm3d(_InstanceNorm):

     At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
     i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.train(False)` method.
+    mean and variance with the `.use_running_stats(mode=True)` method, and switch back to the
+    normal behaviour with `.use_running_stats(mode=False)`.


     Args: