
Commit b4b8f53

Mike Ruberry authored and facebook-github-bot committed
Ports most of test_torch.py to generic device type framework (#26232)
Summary:
This PR moves many tests in test_torch.py to the generic device type framework. This means that many CUDA tests now run from test_torch.py, and there is greater consistency in how tests for multiple device types are written.

One change is that all MAGMA tests now run on the default stream, due to intermittent instability when running MAGMA on a non-default stream. This is a known issue.

Pull Request resolved: #26232

Test Plan: Although this PR edits the tests themselves, it was validated using two independent methods:
(1) The code was reviewed and it was verified that all deleted functions were actually moved.
(2) The output of the TestTorch CI was reviewed and test outputs were matched before and after this PR.

Differential Revision: D17386370

Pulled By: mruberry

fbshipit-source-id: 843d14911bbd52e8aac6861c0d9bc3d0d9418219
1 parent 9f6b6b8 commit b4b8f53
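
For context, the "generic device type framework" referenced in the summary lets a test be written once against a device string and then instantiated per device type (CPU, CUDA, ...). The sketch below is illustrative only and is not code from this commit; the helper names (common_device_type, instantiate_device_type_tests) are assumptions about the framework as it existed around this change.

# Minimal sketch of a device-generic test, assuming the common_device_type
# helpers available around this commit (illustrative, not part of this diff).
import torch
from common_utils import TestCase, run_tests
from common_device_type import instantiate_device_type_tests


class TestExample(TestCase):
    # The framework passes a device string ('cpu', 'cuda', ...) to each test,
    # so a single body replaces the hand-written CUDA copies in test_cuda.py.
    def test_clamp(self, device):
        x = torch.randn(10, device=device)
        self.assertTrue((x.clamp(min=0) >= 0).all())


# Generates TestExampleCPU and TestExampleCUDA from the template above.
instantiate_device_type_tests(TestExample, globals())

if __name__ == '__main__':
    run_tests()

Under that scheme, device-specific constraints (such as requiring MAGMA) move from hand-written unittest.skipIf wrappers into per-device decorators, which is why so many CUDA-only wrappers can simply be deleted from test_cuda.py below.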

2 files changed: +10484 additions, −11003 deletions

test/test_cuda.py

Lines changed: 1 addition & 252 deletions
@@ -23,7 +23,7 @@
 from common_methods_invocations import tri_tests_args, tri_large_tests_args, \
     _compare_trilu_indices, _compare_large_trilu_indices
 from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \
-    PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_SCIPY, \
+    PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, \
     TEST_WITH_ROCM, load_tests, slowTest, skipCUDANonDefaultStreamIf
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -1086,18 +1086,6 @@ def test_abs_zero(self):
         for num in abs_zeros:
             self.assertGreater(math.copysign(1.0, num), 0.0)
 
-    def test_bitwise_not(self):
-        _TestTorchMixin._test_bitwise_not(self, 'cuda')
-
-    def test_logical_not(self):
-        _TestTorchMixin._test_logical_not(self, 'cuda')
-
-    def test_logical_xor(self):
-        _TestTorchMixin._test_logical_xor(self, 'cuda')
-
-    def test_isinf(self):
-        _TestTorchMixin._test_isinf(self, lambda t: t.cuda())
-
     @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     def test_arithmetic_large_tensor(self):
         x = torch.empty(2**30, device='cuda')
@@ -1390,9 +1378,6 @@ def test_cat_autogpu(self):
         z = torch.cat([x, y], 0)
         self.assertEqual(z.get_device(), x.get_device())
 
-    def test_clamp(self):
-        _TestTorchMixin._test_clamp(self, 'cuda')
-
     def test_cat(self):
         SIZE = 10
         for dim in range(-3, 3):
@@ -1414,12 +1399,6 @@ def test_cat(self):
         z = torch.cat([x, y])
         self.assertEqual(z.size(), (21, SIZE, SIZE))
 
-    def test_cat_empty_legacy(self):
-        _TestTorchMixin._test_cat_empty_legacy(self, use_cuda=True)
-
-    def test_cat_empty(self):
-        _TestTorchMixin._test_cat_empty(self, use_cuda=True)
-
     def test_bernoulli(self):
         _TestTorchMixin._test_bernoulli(self, torch.float32, torch.float64, 'cuda')
         _TestTorchMixin._test_bernoulli(self, torch.float32, torch.float16, 'cuda')
@@ -2196,109 +2175,6 @@ def test_prod_large(self):
     def _select_broadcastable_dims(dims_full=None):
         return _TestTorchMixin._select_broadcastable_dims(dims_full)
 
-    @slowTest
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_inverse_many_batches(self):
-        _TestTorchMixin._test_inverse_slow(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_pinverse(self):
-        _TestTorchMixin._test_pinverse(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_matrix_rank(self):
-        _TestTorchMixin._test_matrix_rank(self, lambda x: x.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_matrix_power(self):
-        _TestTorchMixin._test_matrix_power(self, conv_fn=lambda t: t.cuda())
-
-    def test_chain_matmul(self):
-        _TestTorchMixin._test_chain_matmul(self, cast=lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_det_logdet_slogdet(self):
-        _TestTorchMixin._test_det_logdet_slogdet(self, 'cuda')
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_det_logdet_slogdet_batched(self):
-        _TestTorchMixin._test_det_logdet_slogdet_batched(self, 'cuda')
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_solve(self):
-        _TestTorchMixin._test_solve(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_solve_batched(self):
-        _TestTorchMixin._test_solve_batched(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_solve_batched_non_contiguous(self):
-        _TestTorchMixin._test_solve_batched_non_contiguous(self, lambda t: t.cuda())
-
-    @slowTest
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_solve_batched_many_batches(self):
-        _TestTorchMixin._test_solve_batched_many_batches(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_solve_batched_broadcasting(self):
-        _TestTorchMixin._test_solve_batched_broadcasting(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_cholesky_solve(self):
-        _TestTorchMixin._test_cholesky_solve(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_cholesky_solve_batched(self):
-        _TestTorchMixin._test_cholesky_solve_batched(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_cholesky_solve_batched_non_contiguous(self):
-        _TestTorchMixin._test_cholesky_solve_batched_non_contiguous(self, lambda t: t.cuda())
-
-    @slowTest
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_cholesky_solve_batched_many_batches(self):
-        _TestTorchMixin._test_cholesky_solve_batched_many_batches(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_cholesky_solve_batched_broadcasting(self):
-        _TestTorchMixin._test_cholesky_solve_batched_broadcasting(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_cholesky_inverse(self):
-        _TestTorchMixin._test_cholesky_inverse(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_cholesky(self):
-        _TestTorchMixin._test_cholesky(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_cholesky_batched(self):
-        _TestTorchMixin._test_cholesky_batched(self, lambda t: t.cuda())
-
-    @slowTest
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_cholesky_batched_many_batches(self):
-        _TestTorchMixin._test_cholesky_batched_many_batches(self, lambda t: t.cuda())
-
-    def test_view(self):
-        _TestTorchMixin._test_view(self, lambda t: t.cuda())
-
-    def test_flip(self):
-        _TestTorchMixin._test_flip(self, use_cuda=True)
-
-    def test_rot90(self):
-        _TestTorchMixin._test_rot90(self, use_cuda=True)
-
-    def test_signal_window_functions(self):
-        _TestTorchMixin._test_signal_window_functions(self, device=torch.device('cuda'))
-
     @skipIfRocm
     def test_fft_ifft_rfft_irfft(self):
         _TestTorchMixin._test_fft_ifft_rfft_irfft(self, device=torch.device('cuda'))
@@ -2411,10 +2287,6 @@ def test_multinomial(self):
         samples = probs.multinomial(1000000, replacement=True)
         self.assertGreater(probs[samples].min().item(), 0)
 
-    @skipCUDANonDefaultStreamIf(True)
-    def test_multinomial_alias(self):
-        _TestTorchMixin._test_multinomial_alias(self, lambda t: t.cuda())
-
     @staticmethod
     def mute():
         os.dup2(os.open(os.devnull, os.O_WRONLY), sys.stderr.fileno())
@@ -2452,25 +2324,6 @@ def test_multinomial_invalid_probs_cuda(self):
         self._spawn_method(test_method, torch.Tensor([1, 1, nan]))
         self._spawn_method(test_method, torch.Tensor([0, 1, 0]))
 
-    def test_broadcast(self):
-        _TestTorchMixin._test_broadcast(self, lambda t: t.cuda())
-
-    def test_contiguous(self):
-        _TestTorchMixin._test_contiguous(self, lambda t: t.cuda())
-
-    def test_broadcast_fused_matmul(self):
-        _TestTorchMixin._test_broadcast_fused_matmul(self, lambda t: t.cuda())
-
-    def test_broadcast_batched_matmul(self):
-        _TestTorchMixin._test_broadcast_batched_matmul(self, lambda t: t.cuda())
-
-    def test_index(self):
-        _TestTorchMixin._test_index(self, lambda t: t.cuda())
-
-    @skipCUDANonDefaultStreamIf(True)
-    def test_advancedindex(self):
-        _TestTorchMixin._test_advancedindex(self, lambda t: t.cuda())
-
     def test_advancedindex_mixed_cpu_cuda(self):
         def test(x, ia, ib):
             # test getitem
@@ -2519,9 +2372,6 @@ def test(x, ia, ib):
             ib = ib.to(other_device)
             test(x, ia, ib)
 
-    def test_advancedindex_big(self):
-        _TestTorchMixin._test_advancedindex_big(self, lambda t: t.cuda())
-
     @slowTest
     @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     def test_huge_index(self):
@@ -2531,9 +2381,6 @@ def test_huge_index(self):
         res_cpu = src.cpu()[idx.cpu()]
         self.assertEqual(res.cpu(), res_cpu)
 
-    def test_kthvalue(self):
-        _TestTorchMixin._test_kthvalue(self, device='cuda')
-
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_lu(self):
         _TestTorchMixin._test_lu(self, lambda t: t.cuda(), pivot=False)
@@ -2549,29 +2396,11 @@ def test_lu_solve_batched(self):
         _TestTorchMixin._test_lu_solve_batched(self, lambda t: t.cuda(), pivot=False)
         _TestTorchMixin._test_lu_solve_batched(self, lambda t: t.cuda(), pivot=True)
 
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_lu_solve_batched_non_contiguous(self):
-        _TestTorchMixin._test_lu_solve_batched_non_contiguous(self, lambda t: t.cuda())
-
-    @slowTest
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_lu_solve_batched_many_batches(self):
-        _TestTorchMixin._test_lu_solve_batched_many_batches(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_lu_solve_batched_broadcasting(self):
-        _TestTorchMixin._test_lu_solve_batched_broadcasting(self, lambda t: t.cuda())
-
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_lu_unpack(self):
         _TestTorchMixin._test_lu_unpack(self, lambda t: t.cuda(), pivot=False)
         _TestTorchMixin._test_lu_unpack(self, lambda t: t.cuda(), pivot=True)
 
-    def test_dim_reduction(self):
-        _TestTorchMixin._test_dim_reduction(self, lambda t: t.cuda())
-
     def test_tensor_gather(self):
         _TestTorchMixin._test_gather(self, lambda t: t.cuda(), False)
 
@@ -2603,12 +2432,6 @@ def test_max_with_inf(self):
     def test_min_with_inf(self):
         _TestTorchMixin._test_min_with_inf(self, (torch.half, torch.float, torch.double), 'cuda')
 
-    def test_rpow(self):
-        _TestTorchMixin._test_rpow(self, lambda x: x.cuda())
-
-    def test_remainder_overflow(self):
-        _TestTorchMixin._test_remainder_overflow(self, dtype=torch.int64, device='cuda')
-
     def test_var(self):
         cpu_tensor = torch.randn(2, 3, 3)
         gpu_tensor = cpu_tensor.cuda()
@@ -2704,18 +2527,6 @@ def test(use_double=False):
         test(True)
         test(False)
 
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_symeig(self):
-        _TestTorchMixin._test_symeig(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_svd(self):
-        _TestTorchMixin._test_svd(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_svd_no_singularvectors(self):
-        _TestTorchMixin._test_svd_no_singularvectors(self, lambda t: t.cuda())
-
     def test_arange(self):
         for t in ['IntTensor', 'LongTensor', 'FloatTensor', 'DoubleTensor']:
             a = torch.cuda.__dict__[t]()
@@ -2739,64 +2550,11 @@ def test_logspace(self):
         b = torch.logspace(1, 10, 10, 2)
         self.assertEqual(a, b.cuda())
 
-    def test_lerp(self):
-        _TestTorchMixin._test_lerp(self, lambda t: t.cuda())
-
-    def test_diagflat(self):
-        _TestTorchMixin._test_diagflat(self, dtype=torch.float32, device='cuda')
-
-    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @skipCUDANonDefaultStreamIf(True)
-    def test_norm(self):
-        _TestTorchMixin._test_norm(self, device='cuda')
-
-    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @skipCUDANonDefaultStreamIf(True)
-    def test_nuclear_norm_axes_small_brute_force(self):
-        _TestTorchMixin._test_nuclear_norm_axes(self, device='cuda')
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @skipCUDANonDefaultStreamIf(True)
-    def test_nuclear_norm_exceptions(self):
-        _TestTorchMixin._test_nuclear_norm_exceptions(self, device='cuda')
-
-    def test_dist(self):
-        _TestTorchMixin._test_dist(self, device='cuda')
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_geqrf(self):
-        _TestTorchMixin._test_geqrf(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @skipCUDANonDefaultStreamIf(True)
-    def test_triangular_solve(self):
-        _TestTorchMixin._test_triangular_solve(self, lambda t: t.cuda())
-
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     @unittest.skip("Spuriously failing")
     def test_triangular_solve_batched(self):
         _TestTorchMixin._test_triangular_solve_batched(self, lambda t: t.cuda())
 
-    @slowTest
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_triangular_solve_batched_many_batches(self):
-        _TestTorchMixin._test_triangular_solve_batched_many_batches(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
-    def test_triangular_solve_batched_broadcasting(self):
-        _TestTorchMixin._test_triangular_solve_batched_broadcasting(self, lambda t: t.cuda())
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_lstsq(self):
-        _TestTorchMixin._test_lstsq(self, 'cuda')
-
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_qr(self):
-        _TestTorchMixin._test_qr(self, lambda t: t.cuda())
-
     @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
     def test_get_set_rng_state_all(self):
         states = torch.cuda.get_rng_state_all()
@@ -2815,12 +2573,6 @@ def test_nvtx(self):
         torch.cuda.nvtx.mark("bar")
         torch.cuda.nvtx.range_pop()
 
-    def test_randperm_cuda(self):
-        _TestTorchMixin._test_randperm(self, device='cuda')
-
-    def test_random_neg_values(self):
-        _TestTorchMixin._test_random_neg_values(self, use_cuda=True)
-
     def test_bincount_cuda(self):
         _TestTorchMixin._test_bincount(self, device='cuda')
         # ensure CUDA code coverage
@@ -2911,9 +2663,6 @@ def test_large_trilu_indices(self):
         for test_args in tri_large_tests_args:
             _compare_large_trilu_indices(self, *test_args, device='cuda')
 
-    def test_triu_tril(self):
-        _TestTorchMixin._test_triu_tril(self, lambda t: t.cuda())
-
     def test_cuda_round(self):
         # test half-to-even
         a = [-5.8, -3.5, -2.3, -1.5, -0.5, 0.5, 1.5, 2.3, 3.5, 5.8]
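
To make the mapping concrete, here is a hedged before/after sketch for one of the MAGMA-gated wrappers deleted above. The skipCUDAIfNoMagma decorator and the device argument are assumptions about the generic framework's API at the time; the actual ported test in test_torch.py may differ.

# Hedged sketch of how a deleted CUDA-only wrapper maps onto the generic
# framework (illustrative; not code taken from this commit).
import torch
from common_utils import TestCase, run_tests
from common_device_type import instantiate_device_type_tests, skipCUDAIfNoMagma


class TestPortedExample(TestCase):
    # Old test_cuda.py style (removed in this commit):
    #     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
    #     def test_cholesky(self):
    #         _TestTorchMixin._test_cholesky(self, lambda t: t.cuda())
    #
    # Device-generic style: one test body, with the MAGMA requirement
    # expressed as a CUDA-only decorator (assumed name: skipCUDAIfNoMagma).
    @skipCUDAIfNoMagma
    def test_cholesky(self, device):
        a = torch.eye(5, device=device)
        self.assertEqual(torch.cholesky(a), a)


instantiate_device_type_tests(TestPortedExample, globals())

if __name__ == '__main__':
    run_tests()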
