 from typing import Dict, List, Tuple, Union
 import torch.backends.quantized
 import torch.testing._internal.data
-from torch.testing._internal.common_cuda import tf32_on_and_off


 # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
@@ -8441,7 +8440,6 @@ def dims_full_for_fn():
         r1 = fntorch(t0_full, t1, t2)
         self.assertEqual(r0, r1)

-    @tf32_on_and_off(0.001)
     def test_broadcast_batched_matmul(self, device):
         n_dim = random.randint(1, 8)
         m_dim = random.randint(1, 8)
@@ -10431,7 +10429,6 @@ def check_norm(a, b, expected_norm, gels_result):

     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
-    @tf32_on_and_off(0.001)
     def test_qr(self, device):
         def run_test(tensor_dims, some):
             A = torch.randn(*tensor_dims, device=device)
@@ -11511,7 +11508,6 @@ def test_cdist_norm_batch(self, device):
             expected = self._brute_cdist(x, y, p=p)
             self.assertEqual(expected, actual)

-    @tf32_on_and_off(0.005)
     def test_cdist_large(self, device):
         for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
             x = torch.randn(1000, 10, device=device)
@@ -11521,7 +11517,6 @@ def test_cdist_large(self, device):
             self.assertEqual(expected, actual)

     @slowTest
-    @tf32_on_and_off(0.01)
     def test_cdist_large_batch(self, device):
         for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
             x = torch.randn(4, 3, 1000, 10, device=device)
@@ -11530,7 +11525,6 @@ def test_cdist_large_batch(self, device):
             expected = self._brute_cdist(x, y, p=2)
             self.assertEqual(expected, actual)

-    @tf32_on_and_off(0.005)
     def test_cdist_non_contiguous(self, device):
         for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
             x = torch.randn(5, 7, device=device).transpose(-1, -2)
@@ -11557,7 +11551,6 @@ def test_cdist_non_contiguous(self, device):
             self.assertTrue(y.is_contiguous())
             self.assertEqual(expected, actual)

-    @tf32_on_and_off()
     def test_cdist_non_contiguous_batch(self, device):
         for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
             x = torch.randn(4, 3, 2, 5, 7, device=device).transpose(-1, -2)
@@ -12394,7 +12387,6 @@ def test_empty_tensor_props(self, device):
         self.assertEqual(x.stride(), y.stride())

     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
-    @tf32_on_and_off(0.005)
     def test_tensordot(self, device):
         a = torch.arange(60., device=device).reshape(3, 4, 5)
         b = torch.arange(24., device=device).reshape(4, 3, 2)
@@ -16478,7 +16470,6 @@ def test_addmm(self, device):
     @dtypes(torch.float, torch.double)
     @dtypesIfCUDA(*([torch.float, torch.double] +
                     ([] if TEST_WITH_ROCM else torch.testing.get_all_complex_dtypes())))
-    @tf32_on_and_off(0.005)
     def test_addmm_sizes(self, device, dtype):
         for m in [0, 1, 25]:
             for n in [0, 1, 10]:
@@ -16928,7 +16919,6 @@ def test_remainder_edge_cases(self, device, dtype):
     @onlyOnCPUAndCUDA
     @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.int32, torch.int64, torch.cfloat, torch.cdouble)
     @dtypesIfCUDA(torch.float32, torch.float64)
-    @tf32_on_and_off(0.01)
     def test_mm(self, device, dtype):
         def _test_mm(n, m, p, dtype, genf):
             # helper function
@@ -17974,7 +17964,6 @@ def test_pickle_gradscaler(self, device):
         self.assertEqual(b.scale(torch.tensor([4.0], dtype=torch.float32, device=device)), 12.0)

     @onlyCUDA
-    @tf32_on_and_off(0.005)
     def test_mv_stride_0(self, device):
         # Reference: https://github.com/pytorch/pytorch/issues/38315
         mat = torch.randn(2, 2, device=device)
@@ -18930,6 +18919,8 @@ def test_split_view(self, device):

 _float_types_no_half = [torch.float, torch.double]

+_complex_types = [torch.cfloat, torch.cdouble]
+
 # _float_types2 adds bfloat16 type to _float_types only on ROCm. Should eventually be unified
 # with _float_types when bfloat16 bringup is complete on all platforms
 _float_types2 = _float_types + [torch.bfloat16] if TEST_WITH_ROCM else _float_types
@@ -19104,13 +19095,13 @@ def inner(self, device, dtype):
     ('pow', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d).abs()],
         1e-1, 1e-1, 1e-5, _float_types2),
     ('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
-        1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)] ),
+        1e-1, 1e-1, 1e-4, _float_types2),
     ('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
     ('addbmm', 'two_scalars', _small_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
     ('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2),
     ('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
@@ -19135,26 +19126,25 @@ def inner(self, device, dtype):
         1e-1, 1e-5, _types2, _cpu_types, True,
         [_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
     ('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)],
-        1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)] ),
+        1e-1, 1e-1, 1e-4, _float_types2),
     ('addmm', 'scalar', _medium_2d,
         lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmm_? is deprecated")]),
     ('addmm', 'two_scalars', _medium_2d,
         lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)],
         1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmm_? is deprecated")]),
     ('addmv', '', _medium_1d, lambda t, d: [_medium_2d(t, d), _medium_1d(t, d)],
-        1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types,
-        True, [tf32_on_and_off(0.005)]),
+        1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm),
     ('addmv', 'scalar', _medium_1d,
         lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
     ('addmv', 'two_scalars', _medium_1d,
         lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
     ('addr', '', _medium_2d, lambda t, d: [_medium_1d(t, d), _medium_1d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2),
     ('addr', 'scalar', _medium_2d,
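
For context on the decorator this diff removes: `tf32_on_and_off` (from `torch.testing._internal.common_cuda`) reruns a CUDA test with TF32 matmuls disabled and then enabled, loosening the comparison tolerance for the TF32 pass, since TF32 keeps roughly 10 bits of mantissa. Below is a minimal sketch of that pattern, assuming the public `torch.backends.cuda.matmul.allow_tf32` flag and the internal TestCase's `precision` attribute; the helper name and the exact tolerance handling are illustrative, not the actual implementation.

```python
import functools

import torch


def tf32_on_and_off_sketch(tf32_precision=1e-5):
    """Sketch: run a device test once without TF32 and once with it enabled."""
    def decorator(test_fn):
        @functools.wraps(test_fn)
        def wrapper(self, device, *args, **kwargs):
            old_allow = torch.backends.cuda.matmul.allow_tf32
            old_precision = self.precision
            try:
                for allow_tf32 in (False, True):
                    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
                    # TF32 truncates the FP32 mantissa, so the TF32 pass needs
                    # a looser tolerance than the full-precision pass.
                    self.precision = tf32_precision if allow_tf32 else old_precision
                    test_fn(self, device, *args, **kwargs)
            finally:
                torch.backends.cuda.matmul.allow_tf32 = old_allow
                self.precision = old_precision
        return wrapper
    return decorator
```

The real decorator additionally gates on the device actually being a TF32-capable GPU before toggling anything; on CPU (or pre-Ampere GPUs) the test simply runs once.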