7 changes: 4 additions & 3 deletions test/expect/TestJit.test_batchnorm.expect
@@ -2,7 +2,8 @@ graph(%0 : Double(2, 2, 2, 2)
       %1 : Double(2)
       %2 : Double(2)
       %3 : Double(2)
-      %4 : Double(2)) {
-  %5 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
-  return (%5);
+      %4 : Double(2)
+      %5 : Long(1)) {
+  %6 : Double(2, 2, 2, 2) = aten::batch_norm[training=1, momentum=0.1, eps=1e-05, cudnn_enabled=1](%0, %1, %2, %3, %4), scope: BatchNorm2d
+  return (%6);
 }
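Note: the expected trace gains a %5 : Long(1) input because _BatchNorm now registers a third buffer, num_batches_tracked, and traced modules receive their buffers as graph inputs; the batch_norm op itself is unchanged. A minimal sketch of inspecting the new buffer (not part of the diff, and assuming a build where nn.Module.named_buffers() is available):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(2)
    # With this patch, three buffers are registered on the module.
    print([name for name, _ in bn.named_buffers()])
    # expected: ['running_mean', 'running_var', 'num_batches_tracked']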
81 changes: 80 additions & 1 deletion test/test_nn.py
@@ -1779,6 +1779,13 @@ def test_batchnorm_eval(self):
     def test_batchnorm_eval_cuda(self):
         self._test_batchnorm_eval(torch.cuda.FloatTensor)
 
+    def test_batchnorm_simple_average(self):
+        self._test_batchnorm_simple_average()
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_batchnorm_simple_average_cuda(self):
+        self._test_batchnorm_simple_average(torch.cuda.FloatTensor)
+
     def test_MaxPool1d_indices(self):
         self._test_maxpool_indices(1)
 
@@ -1917,6 +1924,7 @@ def test_replicate_buffers(self):
         for i, replica in enumerate(replicas):
             self.assertEqual(replica.bn.running_mean.get_device(), i, 'buffer on wrong device')
             self.assertEqual(replica.bn.running_var.get_device(), i, 'buffer on wrong device')
+            self.assertEqual(replica.bn.num_batches_tracked.get_device(), i, 'buffer on wrong device')
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_parallel_apply(self):
@@ -2233,7 +2241,7 @@ def test_state_dict(self):
         net.add_module('empty', None)
 
         state_dict = net.state_dict()
-        self.assertEqual(len(state_dict), 9)
+        self.assertEqual(len(state_dict), 10)
         self.assertIn('linear1.weight', state_dict)
         self.assertIn('linear1.bias', state_dict)
         self.assertIn('linear2.weight', state_dict)
@@ -2245,6 +2253,7 @@
         self.assertIn('bn.bias', state_dict)
         self.assertIn('bn.running_var', state_dict)
         self.assertIn('bn.running_mean', state_dict)
+        self.assertIn('bn.num_batches_tracked', state_dict)
         self.assertFalse(any(map(lambda k: k.startswith('empty'), state_dict.keys())))
         for k, v in state_dict.items():
             param = net
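A quick sketch (not part of the diff) of why the expected state_dict length goes from 9 to 10: the new buffer serializes alongside the running statistics.

    import torch.nn as nn

    bn = nn.BatchNorm1d(3)
    print('num_batches_tracked' in bn.state_dict())  # True with this patch applied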
@@ -3691,17 +3700,21 @@ def _test_batchnorm_update_stats(self, test_type=torch.FloatTensor):
         # training pass
         old_running_mean = module.running_mean.clone()
         old_running_var = module.running_var.clone()
+        old_num_batches_tracked = module.num_batches_tracked.clone()
         module(data)
         self.assertNotEqual(old_running_mean, module.running_mean)
         self.assertNotEqual(old_running_var, module.running_var)
+        self.assertEqual(old_num_batches_tracked + 1, module.num_batches_tracked)
 
         # eval pass
         module.eval()
         old_running_mean = module.running_mean.clone()
         old_running_var = module.running_var.clone()
+        old_num_batches_tracked = module.num_batches_tracked.clone()
         module(data)
         self.assertEqual(old_running_mean, module.running_mean)
         self.assertEqual(old_running_var, module.running_var)
+        self.assertEqual(old_num_batches_tracked, module.num_batches_tracked)
 
     def test_batchnorm_update_stats(self):
         self._test_batchnorm_update_stats()
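The assertions above pin down the counter's contract: it increments once per training-mode forward pass and stays frozen in eval mode. A minimal standalone sketch of the same behaviour (assuming this patch is applied):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(3)
    x = torch.randn(4, 3)

    bn.train()
    bn(x)
    print(bn.num_batches_tracked)  # incremented to 1 by the training pass

    bn.eval()
    bn(x)
    print(bn.num_batches_tracked)  # still 1; eval passes leave it untouched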
@@ -3790,6 +3803,48 @@ def _test_batchnorm_eval(self, test_type=torch.FloatTensor):
         self.assertEqual(res1, res2)
         self.assertEqual(grad1, grad2)
 
+    def _test_batchnorm_simple_average(self, test_type=torch.FloatTensor):
+        module = nn.BatchNorm1d(3, momentum=None).type(test_type)
+        zeros = torch.zeros(3).type(test_type)
+        ones = torch.ones(3).type(test_type)
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        data1 = torch.rand(4, 3).type(test_type)
+        data2 = torch.rand(4, 3).type(test_type)
+
+        # 1st pass
+        res1 = module(data1)
+        running_mean1 = module.running_mean.clone()
+        running_var1 = module.running_var.clone()
+        self.assertNotEqual(running_mean1, zeros)
+        self.assertNotEqual(running_var1, ones)
+
+        # reset stats
+        module.reset_running_stats()
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        # 2nd pass
+        res2 = module(data2)
+        running_mean2 = module.running_mean.clone()
+        running_var2 = module.running_var.clone()
+        self.assertNotEqual(running_mean2, zeros)
+        self.assertNotEqual(running_var2, ones)
+
+        # reset stats
+        module.reset_running_stats()
+        self.assertEqual(module.running_mean, zeros)
+        self.assertEqual(module.running_var, ones)
+
+        # 3rd (combined) pass
+        res3 = module(data1)
+        res4 = module(data2)
+        self.assertEqual(res3, res1)
+        self.assertEqual(res4, res2)
+        self.assertAlmostEqual(module.running_mean, (running_mean1 + running_mean2) / 2)
+        self.assertAlmostEqual(module.running_var, (running_var1 + running_var2) / 2)
+
     def test_pairwise_distance(self):
         input1 = Variable(torch.randn(4, 4), requires_grad=True)
         input2 = Variable(torch.randn(4, 4), requires_grad=True)
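Why the combined pass in _test_batchnorm_simple_average reproduces the stored statistics: with momentum=None the n-th update uses the factor 1/n, i.e. running = (1 - 1/n) * running + (1/n) * batch_stat, which by induction equals the plain average of the first n per-batch statistics. A hedged arithmetic sketch (assuming this patch is applied):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm1d(3, momentum=None)
    a = torch.randn(8, 3)
    b = torch.randn(8, 3)
    bn(a)
    bn(b)
    # running_mean should be the simple average of the two batch means
    expected = (a.mean(0) + b.mean(0)) / 2
    print(torch.allclose(bn.running_mean, expected, atol=1e-6))  # True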
@@ -5446,6 +5501,14 @@ def multimarginloss_weights_no_reduce_test():
         check_eval=True,
         desc='3d_input',
     ),
+    dict(
+        module_name='BatchNorm1d',
+        constructor_args=(10, 1e-3, None),
+        input_size=(4, 10),
+        cudnn=True,
+        check_eval=True,
+        desc='affine_simple_average',
+    ),
     dict(
         module_name='BatchNorm1d',
         constructor_args=(10, 1e-3, 0.3, False),
@@ -5477,6 +5540,14 @@ def multimarginloss_weights_no_reduce_test():
         cudnn=True,
         check_eval=True,
     ),
+    dict(
+        module_name='BatchNorm2d',
+        constructor_args=(3, 1e-3, None),
+        input_size=(2, 3, 6, 6),
+        cudnn=True,
+        check_eval=True,
+        desc='2d_simple_average',
+    ),
     dict(
         module_name='BatchNorm2d',
         constructor_args=(3, 1e-3, 0.8),
@@ -5508,6 +5579,14 @@ def multimarginloss_weights_no_reduce_test():
         cudnn=True,
         check_eval=True,
     ),
+    dict(
+        module_name='BatchNorm3d',
+        constructor_args=(3, 1e-3, None),
+        input_size=(2, 3, 4, 4, 4),
+        cudnn=True,
+        check_eval=True,
+        desc='3d_simple_average',
+    ),
     dict(
         module_name='BatchNorm3d',
         constructor_args=(3, 1e-3, 0.7),
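These common-test entries exercise the momentum=None path through the shared harness. The positional constructor_args map onto (num_features, eps, momentum), so the BatchNorm3d entry above is equivalent to this call (shown only for illustration):

    import torch.nn as nn

    nn.BatchNorm3d(3, eps=1e-3, momentum=None)  # cumulative (simple) averaging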
29 changes: 24 additions & 5 deletions torch/nn/modules/batchnorm.py
@@ -25,15 +25,21 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
         if self.track_running_stats:
             self.register_buffer('running_mean', torch.zeros(num_features))
             self.register_buffer('running_var', torch.ones(num_features))
+            self.register_buffer('num_batches_tracked', torch.LongTensor([0]))
         else:
             self.register_parameter('running_mean', None)
             self.register_parameter('running_var', None)
+            self.register_parameter('num_batches_tracked', None)
         self.reset_parameters()
 
-    def reset_parameters(self):
+    def reset_running_stats(self):
         if self.track_running_stats:
             self.running_mean.zero_()
             self.running_var.fill_(1)
+            self.num_batches_tracked.zero_()
+
+    def reset_parameters(self):
+        self.reset_running_stats()
         if self.affine:
             self.weight.data.uniform_()
             self.bias.data.zero_()
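Splitting reset_running_stats() out of reset_parameters() lets callers clear the running statistics and the new counter without re-randomizing the affine parameters. A hedged usage sketch (not part of the diff):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(3)
    bn(torch.randn(2, 3, 4, 4))  # training-mode pass moves the stats and counter
    bn.reset_running_stats()     # running_mean -> 0, running_var -> 1, counter -> 0
    # bn.reset_parameters() would additionally re-initialize weight and bias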
@@ -44,9 +50,19 @@ def _check_input_dim(self, input):
     def forward(self, input):
         self._check_input_dim(input)
 
+        exponential_average_factor = 0.0
+
+        if self.training and self.track_running_stats:
+            self.num_batches_tracked += 1
+            if self.momentum is None:  # use cumulative moving average
+                exponential_average_factor = 1.0 / max(1, self.num_batches_tracked.item())
+            else:  # use exponential moving average
+                exponential_average_factor = self.momentum
+
         return F.batch_norm(
             input, self.running_mean, self.running_var, self.weight, self.bias,
-            self.training or not self.track_running_stats, self.momentum, self.eps)
+            self.training or not self.track_running_stats,
+            exponential_average_factor, self.eps)
 
     def __repr__(self):
         return ('{name}({num_features}, eps={eps}, momentum={momentum},'
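The factor handed to F.batch_norm is therefore self.momentum for the usual exponential moving average, or 1/n on the n-th tracked batch for the cumulative average. A hedged contrast of the two modes after a single batch (assuming this patch is applied):

    import torch
    import torch.nn as nn

    ema = nn.BatchNorm1d(3)                 # default momentum=0.1
    cma = nn.BatchNorm1d(3, momentum=None)  # cumulative moving average
    x = torch.randn(16, 3)
    ema(x)
    cma(x)
    # EMA: running = 0.9 * 0 + 0.1 * batch_mean; CMA: running = batch_mean
    print(torch.allclose(ema.running_mean, 0.1 * x.mean(0), atol=1e-6))  # True
    print(torch.allclose(cma.running_mean, x.mean(0), atol=1e-6))        # True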
@@ -93,7 +109,8 @@ class BatchNorm1d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
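For reference, with a numeric momentum the rule implemented above is x_new = (1 - momentum) * x_hat + momentum * x_t, where x_hat is the running estimate and x_t the statistic of the current batch; momentum=None swaps the fixed factor for 1/num_batches_tracked, as contrasted in the sketch after the forward hunk above.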
@@ -162,7 +179,8 @@ class BatchNorm2d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this
@@ -232,7 +250,8 @@ class BatchNorm3d(_BatchNorm):
         eps: a value added to the denominator for numerical stability.
             Default: 1e-5
         momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
+            computation. Can be set to ``None`` for cumulative moving average
+            (i.e. simple average). Default: 0.1
         affine: a boolean value that when set to ``True``, this module has
             learnable affine parameters. Default: ``True``
         track_running_stats: a boolean value that when set to ``True``, this