Skip to content

Commit e074cec

Browse files
committed
instance norm fix running stats settings
1 parent bf1c7d9 commit e074cec

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

torch/nn/modules/instancenorm.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ class _InstanceNorm(_BatchNorm):
66
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False):
77
super(_InstanceNorm, self).__init__(
88
num_features, eps, momentum, affine)
9+
self.use_running_stats = False
910

1011
def forward(self, input):
1112
b, c = input.size(0), input.size(1)
@@ -24,16 +25,22 @@ def forward(self, input):
2425

2526
out = F.batch_norm(
2627
input_reshaped, running_mean, running_var, weight, bias,
27-
True, self.momentum, self.eps)
28+
not self.use_running_stats, self.momentum, self.eps)
2829

2930
# Reshape back
3031
self.running_mean.copy_(running_mean.view(b, c).mean(0, keepdim=False))
3132
self.running_var.copy_(running_var.view(b, c).mean(0, keepdim=False))
3233

3334
return out.view(b, c, *input.size()[2:])
3435

35-
def eval(self):
36-
return self
36+
def use_running_stats(self, mode=True):
37+
r"""Set using running statistics or instance statistics.
38+
39+
Instance normalization usually use instance statistics in both training
40+
and evaluation modes. But users can set this method to use running
41+
statistics in the fashion similar to batch normalization in eval mode.
42+
"""
43+
self.use_running_stats = mode
3744

3845

3946
class InstanceNorm1d(_InstanceNorm):
@@ -52,7 +59,8 @@ class InstanceNorm1d(_InstanceNorm):
5259
5360
At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
5461
i.e. running mean/variance is NOT used for normalization. One can force using stored
55-
mean and variance with `.train(False)` method.
62+
mean and variance with `.use_running_stats(mode=True)` method, and switch back to normal
63+
behavior with `.use_running_stats(mode=False)` method.
5664
5765
Args:
5866
num_features: num_features from an expected input of size `batch_size x num_features x width`
@@ -97,7 +105,8 @@ class InstanceNorm2d(_InstanceNorm):
97105
98106
At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
99107
i.e. running mean/variance is NOT used for normalization. One can force using stored
100-
mean and variance with `.train(False)` method.
108+
mean and variance with `.use_running_stats(mode=True)` method, and switch back to normal
109+
behavior with `.use_running_stats(mode=False)` method.
101110
102111
Args:
103112
num_features: num_features from an expected input of size batch_size x num_features x height x width
@@ -142,7 +151,8 @@ class InstanceNorm3d(_InstanceNorm):
142151
143152
At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
144153
i.e. running mean/variance is NOT used for normalization. One can force using stored
145-
mean and variance with `.train(False)` method.
154+
mean and variance with `.use_running_stats(mode=True)` method, and switch back to normal
155+
behavior with `.use_running_stats(mode=False)` method.
146156
147157
148158
Args:

0 commit comments

Comments
 (0)