@@ -6,6 +6,7 @@ class _InstanceNorm(_BatchNorm):
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=False):
         super(_InstanceNorm, self).__init__(
             num_features, eps, momentum, affine)
+        self._use_running_stats = False  # leading underscore avoids shadowing the use_running_stats() method below
 
     def forward(self, input):
         b, c = input.size(0), input.size(1)
@@ -24,16 +25,22 @@ def forward(self, input):
 
         out = F.batch_norm(
             input_reshaped, running_mean, running_var, weight, bias,
-            True, self.momentum, self.eps)
+            not self._use_running_stats, self.momentum, self.eps)
 
         # Reshape back
         self.running_mean.copy_(running_mean.view(b, c).mean(0, keepdim=False))
         self.running_var.copy_(running_var.view(b, c).mean(0, keepdim=False))
 
         return out.view(b, c, *input.size()[2:])
 
-    def eval(self):
-        return self
+    def use_running_stats(self, mode=True):
+        r"""Set whether to use running statistics or instance statistics.
+
+        Instance normalization usually uses instance statistics in both
+        training and evaluation modes, but this method lets users switch to
+        running statistics, in a fashion similar to batch norm in eval mode.
+        """
+        self._use_running_stats = mode
 
 
 class InstanceNorm1d(_InstanceNorm):
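
As a usage sketch of the method added above (hedged: `use_running_stats` exists only with this patch applied and is not part of the released torch.nn API; this also assumes a PyTorch version where plain tensors can be passed to modules):

    import torch
    import torch.nn as nn

    m = nn.InstanceNorm1d(3)
    x = torch.randn(4, 3, 10)

    for _ in range(100):
        m(x)          # every forward pass updates running_mean / running_var

    m.eval()
    y = m(x)          # default: still normalized with per-instance statistics

    m.use_running_stats()             # switch to the stored running statistics
    y = m(x)          # now normalized with running_mean / running_var

    m.use_running_stats(mode=False)   # back to the default instance statistics
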
@@ -52,7 +59,8 @@ class InstanceNorm1d(_InstanceNorm):
 
     At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
     i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.train(False)` method.
+    mean and variance with the `.use_running_stats(mode=True)` method, and switch back to the
+    normal behaviour with `.use_running_stats(mode=False)`.
 
     Args:
         num_features: num_features from an expected input of size `batch_size x num_features x width`
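
One way to check which statistics were used is to look at the per-(sample, channel) moments of the output; a minimal sketch, under the same assumptions as the example above:

    import torch
    import torch.nn as nn

    m = nn.InstanceNorm1d(3)
    x = torch.randn(4, 3, 10) * 5 + 2     # deliberately far from zero mean / unit variance

    y = m(x)
    # Instance statistics normalize each (sample, channel) slice on its own,
    # so every slice of the output has mean ~0 and (biased) variance ~1.
    print(y.mean(dim=2))                  # ~0 everywhere
    print(y.var(dim=2, unbiased=False))   # ~1 everywhere

    m.use_running_stats()
    y = m(x)
    # Running statistics come from the accumulated averages instead, so the
    # per-slice moments are generally no longer 0 and 1.
    print(y.mean(dim=2))
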
@@ -97,7 +105,8 @@ class InstanceNorm2d(_InstanceNorm):
 
     At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
     i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.train(False)` method.
+    mean and variance with the `.use_running_stats(mode=True)` method, and switch back to the
+    normal behaviour with `.use_running_stats(mode=False)`.
 
     Args:
         num_features: num_features from an expected input of size batch_size x num_features x height x width
@@ -142,7 +151,8 @@ class InstanceNorm3d(_InstanceNorm):
 
     At evaluation time (`.eval()`), the default behaviour of the InstanceNorm module stays the same
     i.e. running mean/variance is NOT used for normalization. One can force using stored
-    mean and variance with `.train(False)` method.
+    mean and variance with the `.use_running_stats(mode=True)` method, and switch back to the
+    normal behaviour with `.use_running_stats(mode=False)`.
 
 
     Args: