@@ -1781,13 +1781,6 @@ def test_batchnorm_eval(self):
17811781 def test_batchnorm_eval_cuda (self ):
17821782 self ._test_batchnorm_eval (torch .cuda .FloatTensor )
17831783
1784- def test_batchnorm_simple_average (self ):
1785- self ._test_batchnorm_simple_average ()
1786-
1787- @unittest .skipIf (not TEST_CUDA , "CUDA unavailable" )
1788- def test_batchnorm_simple_average_cuda (self ):
1789- self ._test_batchnorm_simple_average (torch .cuda .FloatTensor )
1790-
17911784 def test_MaxPool1d_indices (self ):
17921785 self ._test_maxpool_indices (1 )
17931786
@@ -1926,7 +1919,6 @@ def test_replicate_buffers(self):
19261919 for i , replica in enumerate (replicas ):
19271920 self .assertEqual (replica .bn .running_mean .get_device (), i , 'buffer on wrong device' )
19281921 self .assertEqual (replica .bn .running_var .get_device (), i , 'buffer on wrong device' )
1929- self .assertEqual (replica .bn .num_batches_tracked .get_device (), i , 'buffer on wrong device' )
19301922
19311923 @unittest .skipIf (not TEST_MULTIGPU , "multi-GPU not supported" )
19321924 def test_parallel_apply (self ):
@@ -2243,7 +2235,7 @@ def test_state_dict(self):
22432235 net .add_module ('empty' , None )
22442236
22452237 state_dict = net .state_dict ()
2246- self .assertEqual (len (state_dict ), 10 )
2238+ self .assertEqual (len (state_dict ), 9 )
22472239 self .assertIn ('linear1.weight' , state_dict )
22482240 self .assertIn ('linear1.bias' , state_dict )
22492241 self .assertIn ('linear2.weight' , state_dict )
@@ -2255,7 +2247,6 @@ def test_state_dict(self):
22552247 self .assertIn ('bn.bias' , state_dict )
22562248 self .assertIn ('bn.running_var' , state_dict )
22572249 self .assertIn ('bn.running_mean' , state_dict )
2258- self .assertIn ('bn.num_batches_tracked' , state_dict )
22592250 self .assertFalse (any (map (lambda k : k .startswith ('empty' ), state_dict .keys ())))
22602251 for k , v in state_dict .items ():
22612252 param = net
@@ -3702,21 +3693,17 @@ def _test_batchnorm_update_stats(self, test_type=torch.FloatTensor):
37023693 # training pass
37033694 old_running_mean = module .running_mean .clone ()
37043695 old_running_var = module .running_var .clone ()
3705- old_num_batches_tracked = module .num_batches_tracked .clone ()
37063696 module (data )
37073697 self .assertNotEqual (old_running_mean , module .running_mean )
37083698 self .assertNotEqual (old_running_var , module .running_var )
3709- self .assertEqual (old_num_batches_tracked + 1 , module .num_batches_tracked )
37103699
37113700 # eval pass
37123701 module .eval ()
37133702 old_running_mean = module .running_mean .clone ()
37143703 old_running_var = module .running_var .clone ()
3715- old_num_batches_tracked = module .num_batches_tracked .clone ()
37163704 module (data )
37173705 self .assertEqual (old_running_mean , module .running_mean )
37183706 self .assertEqual (old_running_var , module .running_var )
3719- self .assertEqual (old_num_batches_tracked , module .num_batches_tracked )
37203707
37213708 def test_batchnorm_update_stats (self ):
37223709 self ._test_batchnorm_update_stats ()
@@ -3805,48 +3792,6 @@ def _test_batchnorm_eval(self, test_type=torch.FloatTensor):
38053792 self .assertEqual (res1 , res2 )
38063793 self .assertEqual (grad1 , grad2 )
38073794
3808- def _test_batchnorm_simple_average (self , test_type = torch .FloatTensor ):
3809- module = nn .BatchNorm1d (3 , momentum = None ).type (test_type )
3810- zeros = torch .zeros (3 ).type (test_type )
3811- ones = torch .ones (3 ).type (test_type )
3812- self .assertEqual (module .running_mean , zeros )
3813- self .assertEqual (module .running_var , ones )
3814-
3815- data1 = torch .rand (4 , 3 ).type (test_type )
3816- data2 = torch .rand (4 , 3 ).type (test_type )
3817-
3818- # 1st pass
3819- res1 = module (data1 )
3820- running_mean1 = module .running_mean .clone ()
3821- running_var1 = module .running_var .clone ()
3822- self .assertNotEqual (running_mean1 , zeros )
3823- self .assertNotEqual (running_var1 , ones )
3824-
3825- # reset stats
3826- module .reset_running_stats ()
3827- self .assertEqual (module .running_mean , zeros )
3828- self .assertEqual (module .running_var , ones )
3829-
3830- # 2nd pass
3831- res2 = module (data2 )
3832- running_mean2 = module .running_mean .clone ()
3833- running_var2 = module .running_var .clone ()
3834- self .assertNotEqual (running_mean2 , zeros )
3835- self .assertNotEqual (running_var2 , ones )
3836-
3837- # reset stats
3838- module .reset_running_stats ()
3839- self .assertEqual (module .running_mean , zeros )
3840- self .assertEqual (module .running_var , ones )
3841-
3842- # 3rd (combined) pass
3843- res3 = module (data1 )
3844- res4 = module (data2 )
3845- self .assertEqual (res3 , res1 )
3846- self .assertEqual (res4 , res2 )
3847- self .assertAlmostEqual (module .running_mean , (running_mean1 + running_mean2 ) / 2 )
3848- self .assertAlmostEqual (module .running_var , (running_var1 + running_var2 ) / 2 )
3849-
38503795 def test_pairwise_distance (self ):
38513796 input1 = Variable (torch .randn (4 , 4 ), requires_grad = True )
38523797 input2 = Variable (torch .randn (4 , 4 ), requires_grad = True )
@@ -5532,14 +5477,6 @@ def multimarginloss_weights_no_reduce_test():
55325477 check_eval = True ,
55335478 desc = '3d_input' ,
55345479 ),
5535- dict (
5536- module_name = 'BatchNorm1d' ,
5537- constructor_args = (10 , 1e-3 , None ),
5538- input_size = (4 , 10 ),
5539- cudnn = True ,
5540- check_eval = True ,
5541- desc = 'affine_simple_average' ,
5542- ),
55435480 dict (
55445481 module_name = 'BatchNorm1d' ,
55455482 constructor_args = (10 , 1e-3 , 0.3 , False ),
@@ -5571,14 +5508,6 @@ def multimarginloss_weights_no_reduce_test():
55715508 cudnn = True ,
55725509 check_eval = True ,
55735510 ),
5574- dict (
5575- module_name = 'BatchNorm2d' ,
5576- constructor_args = (3 , 1e-3 , None ),
5577- input_size = (2 , 3 , 6 , 6 ),
5578- cudnn = True ,
5579- check_eval = True ,
5580- desc = '2d_simple_average' ,
5581- ),
55825511 dict (
55835512 module_name = 'BatchNorm2d' ,
55845513 constructor_args = (3 , 1e-3 , 0.8 ),
@@ -5610,14 +5539,6 @@ def multimarginloss_weights_no_reduce_test():
56105539 cudnn = True ,
56115540 check_eval = True ,
56125541 ),
5613- dict (
5614- module_name = 'BatchNorm3d' ,
5615- constructor_args = (3 , 1e-3 , None ),
5616- input_size = (2 , 3 , 4 , 4 , 4 ),
5617- cudnn = True ,
5618- check_eval = True ,
5619- desc = '3d_simple_average' ,
5620- ),
56215542 dict (
56225543 module_name = 'BatchNorm3d' ,
56235544 constructor_args = (3 , 1e-3 , 0.7 ),
0 commit comments