Skip to content

Commit 1e3f728

Browse files
committed
Merge criterion and new criterion tests.
They didn't actually do anything different. [ghstack-poisoned]
1 parent 42269a8 commit 1e3f728

File tree

2 files changed

+32
-74
lines changed

2 files changed

+32
-74
lines changed

test/test_nn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@
3838
get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
3939
ALL_TENSORTYPES2, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC
4040
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
41-
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, NewCriterionTest, \
42-
module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
41+
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
42+
module_tests, criterion_tests, loss_reference_fns, \
4343
ctcloss_reference, new_module_tests
4444
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
4545
dtypesIfCUDA, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, \
@@ -9010,10 +9010,10 @@ def reference_fn(i, p, m):
90109010

90119011
add_test(test, decorator)
90129012

9013-
for test_params in criterion_tests + new_criterion_tests:
9013+
for test_params in criterion_tests:
90149014
name = test_params.pop('module_name')
90159015
test_params['constructor'] = getattr(nn, name)
9016-
test = NewCriterionTest(**test_params)
9016+
test = CriterionTest(**test_params)
90179017
decorator = test_params.pop('decorator', None)
90189018
add_test(test, decorator)
90199019
if 'check_sum_reduction' in test_params:
@@ -9028,7 +9028,7 @@ def sum_reduction_constructor(*args, **kwargs):
90289028
return sum_reduction_constructor
90299029

90309030
test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor'])
9031-
test = NewCriterionTest(**test_params)
9031+
test = CriterionTest(**test_params)
90329032
add_test(test, decorator)
90339033

90349034

torch/testing/_internal/common_nn.py

Lines changed: 27 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4086,9 +4086,6 @@ def padding3d_circular(input, pad):
40864086
desc='margin',
40874087
check_sum_reduction=True,
40884088
),
4089-
]
4090-
4091-
new_criterion_tests = [
40924089
dict(
40934090
module_name='BCEWithLogitsLoss',
40944091
input_fn=lambda: torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2),
@@ -4789,71 +4786,6 @@ def test_cuda(self, test_case):
47894786
raise
47904787

47914788

4792-
class CriterionTest(TestBase):
4793-
4794-
_required_arg_names = TestBase._required_arg_names.union({'target'})
4795-
4796-
def __init__(self, *args, **kwargs):
4797-
super().__init__(*args, **kwargs)
4798-
self.should_test_cuda = kwargs.get('test_cuda', True)
4799-
self.check_forward_only = kwargs.get('check_forward_only', True)
4800-
4801-
def _get_target(self):
4802-
return self._get_arg('target', True)
4803-
4804-
def __call__(self, test_case):
4805-
module = self.constructor(*self.constructor_args)
4806-
input = self._get_input()
4807-
4808-
# Check that these methods don't raise errors
4809-
module.__repr__()
4810-
str(module)
4811-
4812-
target = self._get_target()
4813-
4814-
if self.reference_fn is not None:
4815-
out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
4816-
ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
4817-
expected_out = self.reference_fn(*ref_args)
4818-
test_case.assertEqual(out, expected_out)
4819-
4820-
if self.check_forward_only:
4821-
return
4822-
4823-
test_case.check_criterion_jacobian(module, input, target)
4824-
self._do_extra_tests(test_case, module, input, target)
4825-
4826-
def test_cuda(self, test_case):
4827-
if not TEST_CUDA or not self.should_test_cuda:
4828-
raise unittest.SkipTest('Excluded from CUDA tests')
4829-
try:
4830-
cpu_input = self._get_input()
4831-
type_map = {
4832-
'torch.DoubleTensor': torch.cuda.FloatTensor,
4833-
}
4834-
gpu_input = to_gpu(cpu_input, type_map=type_map)
4835-
4836-
cpu_target = self._get_target()
4837-
gpu_target = to_gpu(cpu_target, type_map=type_map)
4838-
4839-
cpu_module = self.constructor(*self.constructor_args)
4840-
gpu_module = self.constructor(*self.constructor_args).float().cuda()
4841-
4842-
cpu_output = test_case._forward_criterion(cpu_module, cpu_input, cpu_target)
4843-
gpu_output = test_case._forward_criterion(gpu_module, gpu_input, gpu_target)
4844-
test_case.assertEqual(cpu_output, gpu_output, atol=4e-4, rtol=0)
4845-
4846-
gradOutput = torch.randn(())
4847-
cpu_gradInput = test_case._backward_criterion(cpu_module, cpu_input, cpu_target, gradOutput)
4848-
gpu_gradInput = test_case._backward_criterion(gpu_module, gpu_input, gpu_target, gradOutput)
4849-
test_case.assertEqual(cpu_gradInput, gpu_gradInput, atol=4e-4, rtol=0)
4850-
except NotImplementedError:
4851-
pass
4852-
4853-
def _do_extra_tests(self, test_case, module, input, target):
4854-
pass
4855-
4856-
48574789
class InputVariableMixin(object):
48584790
def _get_input(self):
48594791
input = TestBase._get_input(self, False)
@@ -5024,17 +4956,43 @@ def constructor_args(self):
50244956
return self._get_arg('constructor_args', False)
50254957

50264958

5027-
class NewCriterionTest(InputVariableMixin, CriterionTest):
4959+
class CriterionTest(InputVariableMixin, TestBase):
50284960
# TODO: check that criterions don't ignore grad_output
50294961

4962+
_required_arg_names = TestBase._required_arg_names.union({'target'})
4963+
50304964
def __init__(self, *args, **kwargs):
50314965
super().__init__(*args, **kwargs)
4966+
self.should_test_cuda = kwargs.get('test_cuda', True)
4967+
self.check_forward_only = kwargs.get('check_forward_only', True)
50324968
self.check_gradgrad = kwargs.get('check_gradgrad', True)
50334969
self.check_half = kwargs.get('check_half', True)
50344970
self.check_bfloat16 = kwargs.get('check_bfloat16', False)
50354971
self.convert_target = kwargs.get('convert_target', True)
50364972
self.test_cpu = kwargs.get('test_cpu', True)
50374973

4974+
def __call__(self, test_case):
4975+
module = self.constructor(*self.constructor_args)
4976+
input = self._get_input()
4977+
4978+
# Check that these methods don't raise errors
4979+
module.__repr__()
4980+
str(module)
4981+
4982+
target = self._get_target()
4983+
4984+
if self.reference_fn is not None:
4985+
out = test_case._forward_criterion(module, input, target, extra_args=self.extra_args)
4986+
ref_args = (deepcopy(input), deepcopy(target)) + self.extra_args + (module,)
4987+
expected_out = self.reference_fn(*ref_args)
4988+
test_case.assertEqual(out, expected_out)
4989+
4990+
if self.check_forward_only:
4991+
return
4992+
4993+
test_case.check_criterion_jacobian(module, input, target)
4994+
self._do_extra_tests(test_case, module, input, target)
4995+
50384996
def _do_extra_tests(self, test_case, module, input, target):
50394997
if not self.check_gradgrad:
50404998
return

0 commit comments

Comments
 (0)