@@ -4086,9 +4086,6 @@ def padding3d_circular(input, pad):
40864086 desc = 'margin' ,
40874087 check_sum_reduction = True ,
40884088 ),
4089- ]
4090-
4091- new_criterion_tests = [
40924089 dict (
40934090 module_name = 'BCEWithLogitsLoss' ,
40944091 input_fn = lambda : torch .rand (15 , 10 ).clamp_ (1e-2 , 1 - 1e-2 ),
@@ -4789,71 +4786,6 @@ def test_cuda(self, test_case):
47894786 raise
47904787
47914788
4792- class CriterionTest (TestBase ):
4793-
4794- _required_arg_names = TestBase ._required_arg_names .union ({'target' })
4795-
4796- def __init__ (self , * args , ** kwargs ):
4797- super ().__init__ (* args , ** kwargs )
4798- self .should_test_cuda = kwargs .get ('test_cuda' , True )
4799- self .check_forward_only = kwargs .get ('check_forward_only' , True )
4800-
4801- def _get_target (self ):
4802- return self ._get_arg ('target' , True )
4803-
4804- def __call__ (self , test_case ):
4805- module = self .constructor (* self .constructor_args )
4806- input = self ._get_input ()
4807-
4808- # Check that these methods don't raise errors
4809- module .__repr__ ()
4810- str (module )
4811-
4812- target = self ._get_target ()
4813-
4814- if self .reference_fn is not None :
4815- out = test_case ._forward_criterion (module , input , target , extra_args = self .extra_args )
4816- ref_args = (deepcopy (input ), deepcopy (target )) + self .extra_args + (module ,)
4817- expected_out = self .reference_fn (* ref_args )
4818- test_case .assertEqual (out , expected_out )
4819-
4820- if self .check_forward_only :
4821- return
4822-
4823- test_case .check_criterion_jacobian (module , input , target )
4824- self ._do_extra_tests (test_case , module , input , target )
4825-
4826- def test_cuda (self , test_case ):
4827- if not TEST_CUDA or not self .should_test_cuda :
4828- raise unittest .SkipTest ('Excluded from CUDA tests' )
4829- try :
4830- cpu_input = self ._get_input ()
4831- type_map = {
4832- 'torch.DoubleTensor' : torch .cuda .FloatTensor ,
4833- }
4834- gpu_input = to_gpu (cpu_input , type_map = type_map )
4835-
4836- cpu_target = self ._get_target ()
4837- gpu_target = to_gpu (cpu_target , type_map = type_map )
4838-
4839- cpu_module = self .constructor (* self .constructor_args )
4840- gpu_module = self .constructor (* self .constructor_args ).float ().cuda ()
4841-
4842- cpu_output = test_case ._forward_criterion (cpu_module , cpu_input , cpu_target )
4843- gpu_output = test_case ._forward_criterion (gpu_module , gpu_input , gpu_target )
4844- test_case .assertEqual (cpu_output , gpu_output , atol = 4e-4 , rtol = 0 )
4845-
4846- gradOutput = torch .randn (())
4847- cpu_gradInput = test_case ._backward_criterion (cpu_module , cpu_input , cpu_target , gradOutput )
4848- gpu_gradInput = test_case ._backward_criterion (gpu_module , gpu_input , gpu_target , gradOutput )
4849- test_case .assertEqual (cpu_gradInput , gpu_gradInput , atol = 4e-4 , rtol = 0 )
4850- except NotImplementedError :
4851- pass
4852-
4853- def _do_extra_tests (self , test_case , module , input , target ):
4854- pass
4855-
4856-
48574789class InputVariableMixin (object ):
48584790 def _get_input (self ):
48594791 input = TestBase ._get_input (self , False )
@@ -5024,17 +4956,43 @@ def constructor_args(self):
50244956 return self ._get_arg ('constructor_args' , False )
50254957
50264958
5027- class NewCriterionTest (InputVariableMixin , CriterionTest ):
4959+ class CriterionTest (InputVariableMixin , TestBase ):
50284960 # TODO: check that criterions don't ignore grad_output
50294961
4962+ _required_arg_names = TestBase ._required_arg_names .union ({'target' })
4963+
50304964 def __init__ (self , * args , ** kwargs ):
50314965 super ().__init__ (* args , ** kwargs )
4966+ self .should_test_cuda = kwargs .get ('test_cuda' , True )
4967+ self .check_forward_only = kwargs .get ('check_forward_only' , True )
50324968 self .check_gradgrad = kwargs .get ('check_gradgrad' , True )
50334969 self .check_half = kwargs .get ('check_half' , True )
50344970 self .check_bfloat16 = kwargs .get ('check_bfloat16' , False )
50354971 self .convert_target = kwargs .get ('convert_target' , True )
50364972 self .test_cpu = kwargs .get ('test_cpu' , True )
50374973
4974+ def __call__ (self , test_case ):
4975+ module = self .constructor (* self .constructor_args )
4976+ input = self ._get_input ()
4977+
4978+ # Check that these methods don't raise errors
4979+ module .__repr__ ()
4980+ str (module )
4981+
4982+ target = self ._get_target ()
4983+
4984+ if self .reference_fn is not None :
4985+ out = test_case ._forward_criterion (module , input , target , extra_args = self .extra_args )
4986+ ref_args = (deepcopy (input ), deepcopy (target )) + self .extra_args + (module ,)
4987+ expected_out = self .reference_fn (* ref_args )
4988+ test_case .assertEqual (out , expected_out )
4989+
4990+ if self .check_forward_only :
4991+ return
4992+
4993+ test_case .check_criterion_jacobian (module , input , target )
4994+ self ._do_extra_tests (test_case , module , input , target )
4995+
50384996 def _do_extra_tests (self , test_case , module , input , target ):
50394997 if not self .check_gradgrad :
50404998 return
0 commit comments