
Commit 99b1f6c

jma127 authored and soumith committed
Enable resetting of batchnorm running moments and cumulative ("simple") moving average (#5766)
1 parent 5014adf commit 99b1f6c

File tree

3 files changed: +108 -9 lines changed
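
Taken together, the change adds a num_batches_tracked buffer to _BatchNorm, a reset_running_stats() method, and momentum=None as a switch from the exponential to a cumulative ("simple") moving average. A minimal usage sketch of the post-commit interface (illustrative only; the printed counts assume the LongTensor([0]) buffer registered in this patch):

import torch
import torch.nn as nn

# momentum=None selects the new cumulative ("simple") moving average for the
# running statistics, instead of the default exponential average (0.1).
bn = nn.BatchNorm1d(3, momentum=None)

for _ in range(5):
    bn(torch.rand(4, 3))                  # training mode: stats are updated

print(bn.num_batches_tracked.item())      # 5: batches seen since last reset

# Discard the accumulated statistics and restart the simple average.
bn.reset_running_stats()
print(bn.num_batches_tracked.item())      # 0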

test/expect/TestJit.test_batchnorm.expect

Lines changed: 4 additions & 3 deletions
@@ -2,7 +2,8 @@ graph(%0 : Double(2, 2, 2, 2)
       %1 : Double(2)
       %2 : Double(2)
       %3 : Double(2)
-      %4 : Double(2)) {
-  %5 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
-  return (%5);
+      %4 : Double(2)
+      %5 : Long(1)) {
+  %6 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
+  return (%6);
 }

test/test_nn.py

Lines changed: 80 additions & 1 deletion
@@ -1781,6 +1781,13 @@ def test_batchnorm_eval(self):
     def test_batchnorm_eval_cuda(self):
         self._test_batchnorm_eval(torch.cuda.FloatTensor)
 
+    def test_batchnorm_simple_average(self):
+        self._test_batchnorm_simple_average()
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_batchnorm_simple_average_cuda(self):
+        self._test_batchnorm_simple_average(torch.cuda.FloatTensor)
+
     def test_MaxPool1d_indices(self):
         self._test_maxpool_indices(1)
@@ -1919,6 +1926,7 @@ def test_replicate_buffers(self):
         for i, replica in enumerate(replicas):
             self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
             self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
+            self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_parallel_apply(self):
@@ -2235,7 +2243,7 @@ def test_state_dict(self):
         net.add_module('empty', None)
 
         state_dict = net.state_dict()
-        self.assertEqual(len(state_dict), 9)
+        self.assertEqual(len(state_dict), 10)
         self.assertIn('linear1.weight', state_dict)
         self.assertIn('linear1.bias', state_dict)
         self.assertIn('linear2.weight', state_dict)
@@ -2247,6 +2255,7 @@ def test_state_dict(self):
         self.assertIn('bn.bias', state_dict)
         self.assertIn('bn.running_var', state_dict)
         self.assertIn('bn.running_mean', state_dict)
+        self.assertIn('bn.num_batches_tracked', state_dict)
         self.assertFalse(any(map(lambda k: k.startswith('empty'), state_dict.keys())))
         for k, v in state_dict.items():
             param = net
@@ -3693,17 +3702,21 @@ def _test_batchnorm_update_stats(self, test_type=torch.FloatTensor):
         # training pass
         old_running_mean = module.running_mean.clone()
         old_running_var = module.running_var.clone()
+        old_num_batches_tracked = module.num_batches_tracked.clone()
         module(data)
         self.assertNotEqual(old_running_mean, module.running_mean)
         self.assertNotEqual(old_running_var, module.running_var)
+        self.assertEqual(old_num_batches_tracked + 1, module.num_batches_tracked)
 
         # eval pass
         module.eval()
         old_running_mean = module.running_mean.clone()
         old_running_var = module.running_var.clone()
+        old_num_batches_tracked = module.num_batches_tracked.clone()
         module(data)
         self.assertEqual(old_running_mean, module.running_mean)
         self.assertEqual(old_running_var, module.running_var)
+        self.assertEqual(old_num_batches_tracked, module.num_batches_tracked)
 
     def test_batchnorm_update_stats(self):
         self._test_batchnorm_update_stats()
@@ -3792,6 +3805,48 @@ def _test_batchnorm_eval(self, test_type=torch.FloatTensor):
         self.assertEqual(res1, res2)
         self.assertEqual(grad1, grad2)
 
+    def _test_batchnorm_simple_average(self, test_type=torch.FloatTensor):
+        module = nn.BatchNorm1d(3, momentum=None).type(test_type)
+        zeros = torch.zeros(3).type(test_type)
+        ones = torch.ones(3).type(test_type)
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        data1 = torch.rand(4, 3).type(test_type)
+        data2 = torch.rand(4, 3).type(test_type)
+
+        # 1st pass
+        res1 = module(data1)
+        running_mean1 = module.running_mean.clone()
+        running_var1 = module.running_var.clone()
+        self.assertNotEqual(running_mean1, zeros)
+        self.assertNotEqual(running_var1, ones)
+
+        # reset stats
+        module.reset_running_stats()
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        # 2nd pass
+        res2 = module(data2)
+        running_mean2 = module.running_mean.clone()
+        running_var2 = module.running_var.clone()
+        self.assertNotEqual(running_mean2, zeros)
+        self.assertNotEqual(running_var2, ones)
+
+        # reset stats
+        module.reset_running_stats()
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        # 3rd (combined) pass
+        res3 = module(data1)
+        res4 = module(data2)
+        self.assertEqual(res3, res1)
+        self.assertEqual(res4, res2)
+        self.assertAlmostEqual(module.running_mean, (running_mean1 + running_mean2) / 2)
+        self.assertAlmostEqual(module.running_var, (running_var1 + running_var2) / 2)
+
     def test_pairwise_distance(self):
         input1 = Variable(torch.randn(4, 4), requires_grad=True)
         input2 = Variable(torch.randn(4, 4), requires_grad=True)
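
The final assertAlmostEqual checks in _test_batchnorm_simple_average hold because the cumulative update telescopes into a plain arithmetic mean. A standalone sketch of that arithmetic (plain Python; cumulative_average is an illustrative helper, not part of this patch):

# With momentum=None the update factor is 1/n for the n-th tracked batch:
#     running = (1 - 1/n) * running + (1/n) * batch_stat
# Expanding the recursion shows `running` equals the mean of all batch stats;
# the factor of 1 on the first batch also overwrites the initial value, which
# is why running_var's starting value of 1 never leaks into the average.
def cumulative_average(batch_stats):
    running = 0.0
    for n, stat in enumerate(batch_stats, start=1):
        factor = 1.0 / n
        running = (1 - factor) * running + factor * stat
    return running

stats = [2.0, 4.0]
assert cumulative_average(stats) == sum(stats) / len(stats)  # (2 + 4) / 2 = 3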
@@ -5477,6 +5532,14 @@ def multimarginloss_weights_no_reduce_test():
         check_eval=True,
         desc='3d_input',
     ),
+    dict(
+        module_name='BatchNorm1d',
+        constructor_args=(10, 1e-3, None),
+        input_size=(4, 10),
+        cudnn=True,
+        check_eval=True,
+        desc='affine_simple_average',
+    ),
     dict(
         module_name='BatchNorm1d',
         constructor_args=(10, 1e-3, 0.3, False),
@@ -5508,6 +5571,14 @@ def multimarginloss_weights_no_reduce_test():
         cudnn=True,
         check_eval=True,
     ),
+    dict(
+        module_name='BatchNorm2d',
+        constructor_args=(3, 1e-3, None),
+        input_size=(2, 3, 6, 6),
+        cudnn=True,
+        check_eval=True,
+        desc='2d_simple_average',
+    ),
     dict(
         module_name='BatchNorm2d',
         constructor_args=(3, 1e-3, 0.8),
@@ -5539,6 +5610,14 @@ def multimarginloss_weights_no_reduce_test():
         cudnn=True,
         check_eval=True,
     ),
+    dict(
+        module_name='BatchNorm3d',
+        constructor_args=(3, 1e-3, None),
+        input_size=(2, 3, 4, 4, 4),
+        cudnn=True,
+        check_eval=True,
+        desc='3d_simple_average',
+    ),
     dict(
         module_name='BatchNorm3d',
         constructor_args=(3, 1e-3, 0.7),

torch/nn/modules/batchnorm.py

Lines changed: 24 additions & 5 deletions
@@ -25,15 +25,21 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
         if self.track_running_stats:
             self.register_buffer('running_mean', torch.zeros(num_features))
             self.register_buffer('running_var', torch.ones(num_features))
+            self.register_buffer('num_batches_tracked', torch.LongTensor([0]))
         else:
             self.register_parameter('running_mean', None)
             self.register_parameter('running_var', None)
+            self.register_parameter('num_batches_tracked', None)
         self.reset_parameters()
 
-    def reset_parameters(self):
+    def reset_running_stats(self):
         if self.track_running_stats:
             self.running_mean.zero_()
             self.running_var.fill_(1)
+            self.num_batches_tracked.zero_()
+
+    def reset_parameters(self):
+        self.reset_running_stats()
         if self.affine:
             self.weight.data.uniform_()
             self.bias.data.zero_()
@@ -44,9 +50,19 @@ def _check_input_dim(self, input):
     def forward(self, input):
         self._check_input_dim(input)
 
+        exponential_average_factor = 0.0
+
+        if self.training and self.track_running_stats:
+            self.num_batches_tracked += 1
+            if self.momentum is None:  # use cumulative moving average
+                exponential_average_factor = 1.0 / max(1, self.num_batches_tracked.item())
+            else:  # use exponential moving average
+                exponential_average_factor = self.momentum
+
         return F.batch_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
-            self.training or not self.track_running_stats, self.momentum, self.eps)
+            self.training or not self.track_running_stats,
+            exponential_average_factor, self.eps)
 
     def __repr__(self):
         return ('{name}({num_features}, eps={eps}, momentum={momentum},'
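
For intuition about the factor selected in the new forward, here is a sketch of the effective per-batch update factor (average_factor is an illustrative helper mirroring the logic above, not library code):

# Mirrors the factor selection in forward(): 1/n for the cumulative
# average (momentum=None), a constant otherwise.
def average_factor(momentum, num_batches_tracked):
    if momentum is None:                  # cumulative moving average
        return 1.0 / max(1, num_batches_tracked)
    return momentum                       # exponential moving average

for n in range(1, 5):
    print(f"batch {n}: cma={average_factor(None, n):.3f}, "
          f"ema={average_factor(0.1, n):.3f}")
# batch 1: cma=1.000, ema=0.100   (first batch overwrites the initial stats)
# batch 2: cma=0.500, ema=0.100
# batch 3: cma=0.333, ema=0.100   (cma decays as 1/n; ema stays constant)
# batch 4: cma=0.250, ema=0.100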
@@ -93,7 +109,8 @@ class BatchNorm1d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
@@ -162,7 +179,8 @@ class BatchNorm2d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
@@ -232,7 +250,8 @@ class BatchNorm3d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
