@@ -1781,6 +1781,13 @@ def test_batchnorm_eval(self):
17811781 def test_batchnorm_eval_cuda (self ):
17821782 self ._test_batchnorm_eval (torch .cuda .FloatTensor )
17831783
1784+ def test_batchnorm_simple_average (self ):
1785+ self ._test_batchnorm_simple_average ()
1786+
1787+ @unittest .skipIf (not TEST_CUDA , "CUDA unavailable" )
1788+ def test_batchnorm_simple_average_cuda (self ):
1789+ self ._test_batchnorm_simple_average (torch .cuda .FloatTensor )
1790+
17841791 def test_MaxPool1d_indices (self ):
17851792 self ._test_maxpool_indices (1 )
17861793
@@ -1919,6 +1926,7 @@ def test_replicate_buffers(self):
19191926 for i , replica in enumerate (replicas ):
19201927 self .assertEqual (replica .bn .running_mean .get_device (), i , 'buffer on wrong device' )
19211928 self .assertEqual (replica .bn .running_var .get_device (), i , 'buffer on wrong device' )
1929+ self .assertEqual (replica .bn .num_batches_tracked .get_device (), i , 'buffer on wrong device' )
19221930
19231931 @unittest .skipIf (not TEST_MULTIGPU , "multi-GPU not supported" )
19241932 def test_parallel_apply (self ):
@@ -2235,7 +2243,7 @@ def test_state_dict(self):
22352243 net .add_module ('empty' , None )
22362244
22372245 state_dict = net .state_dict ()
2238- self .assertEqual (len (state_dict ), 9 )
2246+ self .assertEqual (len (state_dict ), 10 )
22392247 self .assertIn ('linear1.weight' , state_dict )
22402248 self .assertIn ('linear1.bias' , state_dict )
22412249 self .assertIn ('linear2.weight' , state_dict )
@@ -2247,6 +2255,7 @@ def test_state_dict(self):
22472255 self .assertIn ('bn.bias' , state_dict )
22482256 self .assertIn ('bn.running_var' , state_dict )
22492257 self .assertIn ('bn.running_mean' , state_dict )
2258+ self .assertIn ('bn.num_batches_tracked' , state_dict )
22502259 self .assertFalse (any (map (lambda k : k .startswith ('empty' ), state_dict .keys ())))
22512260 for k , v in state_dict .items ():
22522261 param = net
@@ -3693,17 +3702,21 @@ def _test_batchnorm_update_stats(self, test_type=torch.FloatTensor):
36933702 # training pass
36943703 old_running_mean = module .running_mean .clone ()
36953704 old_running_var = module .running_var .clone ()
3705+ old_num_batches_tracked = module .num_batches_tracked .clone ()
36963706 module (data )
36973707 self .assertNotEqual (old_running_mean , module .running_mean )
36983708 self .assertNotEqual (old_running_var , module .running_var )
3709+ self .assertEqual (old_num_batches_tracked + 1 , module .num_batches_tracked )
36993710
37003711 # eval pass
37013712 module .eval ()
37023713 old_running_mean = module .running_mean .clone ()
37033714 old_running_var = module .running_var .clone ()
3715+ old_num_batches_tracked = module .num_batches_tracked .clone ()
37043716 module (data )
37053717 self .assertEqual (old_running_mean , module .running_mean )
37063718 self .assertEqual (old_running_var , module .running_var )
3719+ self .assertEqual (old_num_batches_tracked , module .num_batches_tracked )
37073720
37083721 def test_batchnorm_update_stats (self ):
37093722 self ._test_batchnorm_update_stats ()
@@ -3792,6 +3805,48 @@ def _test_batchnorm_eval(self, test_type=torch.FloatTensor):
37923805 self .assertEqual (res1 , res2 )
37933806 self .assertEqual (grad1 , grad2 )
37943807
3808+ def _test_batchnorm_simple_average (self , test_type = torch .FloatTensor ):
3809+ module = nn .BatchNorm1d (3 , momentum = None ).type (test_type )
3810+ zeros = torch .zeros (3 ).type (test_type )
3811+ ones = torch .ones (3 ).type (test_type )
3812+ self .assertEqual (module .running_mean , zeros )
3813+ self .assertEqual (module .running_var , ones )
3814+
3815+ data1 = torch .rand (4 , 3 ).type (test_type )
3816+ data2 = torch .rand (4 , 3 ).type (test_type )
3817+
3818+ # 1st pass
3819+ res1 = module (data1 )
3820+ running_mean1 = module .running_mean .clone ()
3821+ running_var1 = module .running_var .clone ()
3822+ self .assertNotEqual (running_mean1 , zeros )
3823+ self .assertNotEqual (running_var1 , ones )
3824+
3825+ # reset stats
3826+ module .reset_running_stats ()
3827+ self .assertEqual (module .running_mean , zeros )
3828+ self .assertEqual (module .running_var , ones )
3829+
3830+ # 2nd pass
3831+ res2 = module (data2 )
3832+ running_mean2 = module .running_mean .clone ()
3833+ running_var2 = module .running_var .clone ()
3834+ self .assertNotEqual (running_mean2 , zeros )
3835+ self .assertNotEqual (running_var2 , ones )
3836+
3837+ # reset stats
3838+ module .reset_running_stats ()
3839+ self .assertEqual (module .running_mean , zeros )
3840+ self .assertEqual (module .running_var , ones )
3841+
3842+ # 3rd (combined) pass
3843+ res3 = module (data1 )
3844+ res4 = module (data2 )
3845+ self .assertEqual (res3 , res1 )
3846+ self .assertEqual (res4 , res2 )
3847+ self .assertAlmostEqual (module .running_mean , (running_mean1 + running_mean2 ) / 2 )
3848+ self .assertAlmostEqual (module .running_var , (running_var1 + running_var2 ) / 2 )
3849+
37953850 def test_pairwise_distance (self ):
37963851 input1 = Variable (torch .randn (4 , 4 ), requires_grad = True )
37973852 input2 = Variable (torch .randn (4 , 4 ), requires_grad = True )
@@ -5477,6 +5532,14 @@ def multimarginloss_weights_no_reduce_test():
54775532 check_eval = True ,
54785533 desc = '3d_input' ,
54795534 ),
5535+ dict (
5536+ module_name = 'BatchNorm1d' ,
5537+ constructor_args = (10 , 1e-3 , None ),
5538+ input_size = (4 , 10 ),
5539+ cudnn = True ,
5540+ check_eval = True ,
5541+ desc = 'affine_simple_average' ,
5542+ ),
54805543 dict (
54815544 module_name = 'BatchNorm1d' ,
54825545 constructor_args = (10 , 1e-3 , 0.3 , False ),
@@ -5508,6 +5571,14 @@ def multimarginloss_weights_no_reduce_test():
55085571 cudnn = True ,
55095572 check_eval = True ,
55105573 ),
5574+ dict (
5575+ module_name = 'BatchNorm2d' ,
5576+ constructor_args = (3 , 1e-3 , None ),
5577+ input_size = (2 , 3 , 6 , 6 ),
5578+ cudnn = True ,
5579+ check_eval = True ,
5580+ desc = '2d_simple_average' ,
5581+ ),
55115582 dict (
55125583 module_name = 'BatchNorm2d' ,
55135584 constructor_args = (3 , 1e-3 , 0.8 ),
@@ -5539,6 +5610,14 @@ def multimarginloss_weights_no_reduce_test():
55395610 cudnn = True ,
55405611 check_eval = True ,
55415612 ),
5613+ dict (
5614+ module_name = 'BatchNorm3d' ,
5615+ constructor_args = (3 , 1e-3 , None ),
5616+ input_size = (2 , 3 , 4 , 4 , 4 ),
5617+ cudnn = True ,
5618+ check_eval = True ,
5619+ desc = '3d_simple_average' ,
5620+ ),
55425621 dict (
55435622 module_name = 'BatchNorm3d' ,
55445623 constructor_args = (3 , 1e-3 , 0.7 ),
0 commit comments