
Commit d30e280

Revert "Enable resetting of batchnorm running moments and cumulative ("simple") moving average (#5766)"
This reverts commit 99b1f6c.
1 parent d0eddf1 commit d30e280
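For context: the reverted change let the batchnorm modules accept momentum=None, which switched the running-stat update from an exponential moving average to a cumulative ("simple") moving average driven by a num_batches_tracked counter, and it added a reset_running_stats() method. A minimal sketch of the two update rules, in plain Python rather than the PyTorch source:

    def ema_update(running, batch_stat, momentum=0.1):
        # exponential moving average: every new batch gets the same fixed weight
        return (1.0 - momentum) * running + momentum * batch_stat

    def cma_update(running, batch_stat, num_batches_tracked):
        # cumulative ("simple") moving average: the N-th batch gets weight 1/N,
        # so the running stat converges to the plain mean over all batches seen
        factor = 1.0 / max(1, num_batches_tracked)
        return (1.0 - factor) * running + factor * batch_stat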

File tree

3 files changed (+9 / -108 lines)


test/expect/TestJit.test_batchnorm.expect

Lines changed: 3 additions & 4 deletions
@@ -2,8 +2,7 @@ graph(%0 : Double(2, 2, 2, 2)
       %1 : Double(2)
       %2 : Double(2)
       %3 : Double(2)
-      %4 : Double(2)
-      %5 : Long(1)) {
-  %6 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
-  return (%6);
+      %4 : Double(2)) {
+  %5 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
+  return (%5);
 }
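The revert shrinks the traced graph by one input: with the Long(1) num_batches_tracked buffer gone, %5 disappears and the batch_norm output is renumbered from %6 to %5. A rough way to reproduce such a dump, using today's torch.jit.trace API rather than this era's expect-test harness (on a current build the buffer exists again, so the exact input list will differ):

    import torch
    import torch.nn as nn

    # trace a BatchNorm2d with the shapes from the expect file and dump its graph
    bn = nn.BatchNorm2d(2).double().train()
    example = torch.randn(2, 2, 2, 2, dtype=torch.double)
    traced = torch.jit.trace(bn, example)
    print(traced.graph)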

test/test_nn.py

Lines changed: 1 addition & 80 deletions
@@ -1781,13 +1781,6 @@ def test_batchnorm_eval(self):
     def test_batchnorm_eval_cuda(self):
         self._test_batchnorm_eval(torch.cuda.FloatTensor)
 
-    def test_batchnorm_simple_average(self):
-        self._test_batchnorm_simple_average()
-
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_batchnorm_simple_average_cuda(self):
-        self._test_batchnorm_simple_average(torch.cuda.FloatTensor)
-
     def test_MaxPool1d_indices(self):
         self._test_maxpool_indices(1)
 
@@ -1926,7 +1919,6 @@ def test_replicate_buffers(self):
         for i, replica in enumerate(replicas):
             self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
             self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
-            self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_parallel_apply(self):
@@ -2243,7 +2235,7 @@ def test_state_dict(self):
         net.add_module('empty', None)
 
         state_dict = net.state_dict()
-        self.assertEqual(len(state_dict), 10)
+        self.assertEqual(len(state_dict), 9)
         self.assertIn('linear1.weight', state_dict)
         self.assertIn('linear1.bias', state_dict)
         self.assertIn('linear2.weight', state_dict)
@@ -2255,7 +2247,6 @@ def test_state_dict(self):
         self.assertIn('bn.bias', state_dict)
         self.assertIn('bn.running_var', state_dict)
         self.assertIn('bn.running_mean', state_dict)
-        self.assertIn('bn.num_batches_tracked', state_dict)
         self.assertFalse(any(map(lambda k: k.startswith('empty'), state_dict.keys())))
         for k, v in state_dict.items():
             param = net
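The 10 → 9 length change is exactly the dropped buffer. A quick way to see which persistent entries a batchnorm module contributes to a state_dict (note this is a sketch: on current PyTorch builds num_batches_tracked exists again, so five keys print rather than the four expected right after this revert):

    import torch.nn as nn

    bn = nn.BatchNorm1d(3)
    # right after this revert: ['bias', 'running_mean', 'running_var', 'weight']
    print(sorted(bn.state_dict().keys()))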
@@ -3702,21 +3693,17 @@ def _test_batchnorm_update_stats(self, test_type=torch.FloatTensor):
         # training pass
         old_running_mean = module.running_mean.clone()
         old_running_var = module.running_var.clone()
-        old_num_batches_tracked = module.num_batches_tracked.clone()
         module(data)
         self.assertNotEqual(old_running_mean, module.running_mean)
         self.assertNotEqual(old_running_var, module.running_var)
-        self.assertEqual(old_num_batches_tracked + 1, module.num_batches_tracked)
 
         # eval pass
         module.eval()
         old_running_mean = module.running_mean.clone()
         old_running_var = module.running_var.clone()
-        old_num_batches_tracked = module.num_batches_tracked.clone()
         module(data)
         self.assertEqual(old_running_mean, module.running_mean)
         self.assertEqual(old_running_var, module.running_var)
-        self.assertEqual(old_num_batches_tracked, module.num_batches_tracked)
 
     def test_batchnorm_update_stats(self):
         self._test_batchnorm_update_stats()
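Restated with today's tensor API (the original parametrizes over torch.FloatTensor types), the invariant this helper pins down is that running stats move during a training-mode forward and stay frozen in eval mode:

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(3)
    data = torch.randn(8, 3)

    before = bn.running_mean.clone()
    bn.train()
    bn(data)
    assert not torch.equal(before, bn.running_mean)  # updated in train()

    before = bn.running_mean.clone()
    bn.eval()
    bn(data)
    assert torch.equal(before, bn.running_mean)      # frozen in eval()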
@@ -3805,48 +3792,6 @@ def _test_batchnorm_eval(self, test_type=torch.FloatTensor):
         self.assertEqual(res1, res2)
         self.assertEqual(grad1, grad2)
 
-    def _test_batchnorm_simple_average(self, test_type=torch.FloatTensor):
-        module = nn.BatchNorm1d(3, momentum=None).type(test_type)
-        zeros = torch.zeros(3).type(test_type)
-        ones = torch.ones(3).type(test_type)
-        self.assertEqual(module.running_mean, zeros)
-        self.assertEqual(module.running_var, ones)
-
-        data1 = torch.rand(4, 3).type(test_type)
-        data2 = torch.rand(4, 3).type(test_type)
-
-        # 1st pass
-        res1 = module(data1)
-        running_mean1 = module.running_mean.clone()
-        running_var1 = module.running_var.clone()
-        self.assertNotEqual(running_mean1, zeros)
-        self.assertNotEqual(running_var1, ones)
-
-        # reset stats
-        module.reset_running_stats()
-        self.assertEqual(module.running_mean, zeros)
-        self.assertEqual(module.running_var, ones)
-
-        # 2nd pass
-        res2 = module(data2)
-        running_mean2 = module.running_mean.clone()
-        running_var2 = module.running_var.clone()
-        self.assertNotEqual(running_mean2, zeros)
-        self.assertNotEqual(running_var2, ones)
-
-        # reset stats
-        module.reset_running_stats()
-        self.assertEqual(module.running_mean, zeros)
-        self.assertEqual(module.running_var, ones)
-
-        # 3rd (combined) pass
-        res3 = module(data1)
-        res4 = module(data2)
-        self.assertEqual(res3, res1)
-        self.assertEqual(res4, res2)
-        self.assertAlmostEqual(module.running_mean, (running_mean1 + running_mean2) / 2)
-        self.assertAlmostEqual(module.running_var, (running_var1 + running_var2) / 2)
-
     def test_pairwise_distance(self):
         input1 = Variable(torch.randn(4, 4), requires_grad=True)
         input2 = Variable(torch.randn(4, 4), requires_grad=True)
@@ -5532,14 +5477,6 @@ def multimarginloss_weights_no_reduce_test():
         check_eval=True,
         desc='3d_input',
     ),
-    dict(
-        module_name='BatchNorm1d',
-        constructor_args=(10, 1e-3, None),
-        input_size=(4, 10),
-        cudnn=True,
-        check_eval=True,
-        desc='affine_simple_average',
-    ),
     dict(
         module_name='BatchNorm1d',
         constructor_args=(10, 1e-3, 0.3, False),
@@ -5571,14 +5508,6 @@ def multimarginloss_weights_no_reduce_test():
         cudnn=True,
         check_eval=True,
     ),
-    dict(
-        module_name='BatchNorm2d',
-        constructor_args=(3, 1e-3, None),
-        input_size=(2, 3, 6, 6),
-        cudnn=True,
-        check_eval=True,
-        desc='2d_simple_average',
-    ),
     dict(
         module_name='BatchNorm2d',
         constructor_args=(3, 1e-3, 0.8),
@@ -5610,14 +5539,6 @@ def multimarginloss_weights_no_reduce_test():
         cudnn=True,
         check_eval=True,
     ),
-    dict(
-        module_name='BatchNorm3d',
-        constructor_args=(3, 1e-3, None),
-        input_size=(2, 3, 4, 4, 4),
-        cudnn=True,
-        check_eval=True,
-        desc='3d_simple_average',
-    ),
     dict(
         module_name='BatchNorm3d',
         constructor_args=(3, 1e-3, 0.7),
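In these test dicts, constructor_args is applied positionally against the _BatchNorm signature shown in the next file, so the removed (num_features, eps, None) tuples were precisely the momentum=None configurations. The surviving entry above, for instance, corresponds to:

    import torch.nn as nn

    bn = nn.BatchNorm3d(3, eps=1e-3, momentum=0.7)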

torch/nn/modules/batchnorm.py

Lines changed: 5 additions & 24 deletions
@@ -25,21 +25,15 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
         if self.track_running_stats:
             self.register_buffer('running_mean', torch.zeros(num_features))
             self.register_buffer('running_var', torch.ones(num_features))
-            self.register_buffer('num_batches_tracked', torch.LongTensor([0]))
         else:
             self.register_parameter('running_mean', None)
             self.register_parameter('running_var', None)
-            self.register_parameter('num_batches_tracked', None)
         self.reset_parameters()
 
-    def reset_running_stats(self):
+    def reset_parameters(self):
         if self.track_running_stats:
             self.running_mean.zero_()
             self.running_var.fill_(1)
-            self.num_batches_tracked.zero_()
-
-    def reset_parameters(self):
-        self.reset_running_stats()
         if self.affine:
             self.weight.data.uniform_()
             self.bias.data.zero_()
@@ -50,19 +44,9 @@ def _check_input_dim(self, input):
     def forward(self, input):
         self._check_input_dim(input)
 
-        exponential_average_factor = 0.0
-
-        if self.training and self.track_running_stats:
-            self.num_batches_tracked += 1
-            if self.momentum is None:  # use cumulative moving average
-                exponential_average_factor = 1.0 / max(1, self.num_batches_tracked.item())
-            else:  # use exponential moving average
-                exponential_average_factor = self.momentum
-
         return F.batch_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
-            self.training or not self.track_running_stats,
-            exponential_average_factor, self.eps)
+            self.training or not self.track_running_stats, self.momentum, self.eps)
 
     def __repr__(self):
         return ('{name}({num_features}, eps={eps}, momentum={momentum},'
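After the revert, forward() hands self.momentum straight to F.batch_norm, so momentum=None no longer has a meaning. If cumulative averaging is still wanted, one workaround is to re-derive the factor in user code before each training batch; a sketch under that assumption, with `batches` standing in for any iterable of inputs:

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(3).train()
    batches = [torch.randn(4, 3) for _ in range(10)]
    for n, batch in enumerate(batches, start=1):
        bn.momentum = 1.0 / n  # weight 1/N on the N-th batch, as in the reverted code
        bn(batch)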
@@ -109,8 +93,7 @@ class BatchNorm1d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Can be set to ``None`` for cumulative moving average
-            (i.e. simple average). Default: 0.1
+            computation. Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
@@ -179,8 +162,7 @@ class BatchNorm2d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Can be set to ``None`` for cumulative moving average
-            (i.e. simple average). Default: 0.1
+            computation. Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
@@ -250,8 +232,7 @@ class BatchNorm3d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Can be set to ``None`` for cumulative moving average
-            (i.e. simple average). Default: 0.1
+            computation. Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
