
Commit b6b2b4c

Mike Ruberry authored and facebook-github-bot committed
Refines test_torch.py generic device testing (#26244)
Summary:

- Adds skipCUDAIfRocm and skipCPUIfNoMkl decorators, and ports the corresponding tests
- Changes "skipIf" input semantics for consistency
- Removes torchtest, which has been replaced by this new generic framework
- Refactors some common parts of the CUDA tests into TestTorchDeviceType
- Ensures all MAGMA tests run on the default stream by putting skipCUDANonDefaultStreamIf inside the skipCUDAIfNoMagma decorator

Pull Request resolved: #26244
Differential Revision: D17389060
Pulled By: mruberry
fbshipit-source-id: 1375774f24c2266049e6d4b899e7300ddf32eac8
1 parent 26d537d commit b6b2b4c

File tree: 4 files changed, +588 / -695 lines
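Before the per-file diffs, a hedged, self-contained sketch of the composition point in the summary (illustrative names only, not the framework's API): placing the default-stream guard inside the MAGMA skip decorator gives every MAGMA-dependent test both behaviors from a single decorator.

    import functools
    import unittest

    HAS_MAGMA = False  # placeholder flag; the real framework reads torch.cuda.has_magma


    def skip_if(should_skip, reason):
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                if should_skip:
                    raise unittest.SkipTest(reason)
                return fn(*args, **kwargs)
            return wrapper
        return decorator


    def default_stream_only(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Stand-in for skipCUDANonDefaultStreamIf(True): the real decorator keeps
            # the test on the default CUDA stream; this sketch only forwards the call.
            return fn(*args, **kwargs)
        return wrapper


    def skip_if_no_magma(fn):
        # Same composition shape as skipCUDAIfNoMagma in the diff below:
        # the stream guard wraps the test first, then the skip check wraps both.
        return skip_if(not HAS_MAGMA, "no MAGMA library detected")(default_stream_only(fn))


    class ExampleTest(unittest.TestCase):
        @skip_if_no_magma
        def test_needs_magma(self):
            self.assertTrue(True)


    if __name__ == '__main__':
        unittest.main()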

test/common_device_type.py

Lines changed: 16 additions & 5 deletions
@@ -2,7 +2,8 @@
 from functools import wraps
 import unittest
 import torch
-from common_utils import TestCase
+from common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
+    skipCUDANonDefaultStreamIf

 # Note: Generic Device-Type Testing
 #
@@ -109,7 +110,7 @@ class CUDATestBase(DeviceTypeTestBase):
     def setUpClass(cls):
         # has_magma shows up after cuda is initialized
         torch.ones(1).cuda()
-        cls.has_magma = torch.cuda.has_magma
+        cls.no_magma = not torch.cuda.has_magma


 # Adds available device-type-specific test base classes
@@ -189,7 +190,7 @@ def __call__(self, fn):
         @wraps(fn)
         def dep_fn(slf, device, *args, **kwargs):
             if self.device_type is None or self.device_type == slf.device_type:
-                if not self.dep or (isinstance(self.dep, str) and not getattr(slf, self.dep, False)):
+                if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or (isinstance(self.dep, bool) and self.dep):
                     raise unittest.SkipTest(self.reason)

             return fn(slf, device, *args, **kwargs)
@@ -212,9 +213,19 @@ def __init__(self, dep, reason):

 # Specifies LAPACK as a CPU dependency.
 def skipCPUIfNoLapack(fn):
-    return skipCPUIf(torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
+    return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
+
+
+# Specifies MKL as a CPU dependency.
+def skipCPUIfNoMkl(fn):
+    return skipCPUIf(not TEST_MKL, "PyTorch is built without MKL support")(fn)


 # Specifies MAGMA as a CUDA dependency.
 def skipCUDAIfNoMagma(fn):
-    return skipCUDAIf('has_magma', "no MAGMA library detected")(fn)
+    return skipCUDAIf('no_magma', "no MAGMA library detected")(skipCUDANonDefaultStreamIf(True)(fn))
+
+
+# Skips this test when the CUDA device type is actually ROCm.
+def skipCUDAIfRocm(fn):
+    return skipCUDAIf(TEST_WITH_ROCM, "test doesn't currently work on the ROCm stack")(fn)
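A hedged usage sketch of the decorators above (class and test names are hypothetical): under the new "skip-if" semantics, a bool dep skips when it is True, and a string dep skips when the attribute of that name on the test class is truthy, which is why has_magma is inverted to no_magma.

    import torch
    from common_utils import TestCase, run_tests
    from common_device_type import (instantiate_device_type_tests, skipCPUIfNoLapack,
                                    skipCPUIfNoMkl, skipCUDAIfNoMagma, skipCUDAIfRocm)


    class TestExampleDeviceType(TestCase):
        @skipCUDAIfRocm          # bool dep: skips on CUDA when TEST_WITH_ROCM is True
        def test_sum(self, device):
            self.assertEqual(torch.ones(3, device=device).sum().item(), 3)

        @skipCUDAIfNoMagma       # string dep: skips when cls.no_magma is truthy;
        @skipCPUIfNoLapack       # also pins the CUDA run to the default stream
        def test_inverse(self, device):
            a = torch.eye(4, device=device)
            self.assertEqual(torch.inverse(a), a)

        @skipCPUIfNoMkl          # skips on CPU when the build has no MKL
        def test_fft(self, device):
            x = torch.randn(8, 2, device=device)
            torch.fft(x, 1)


    # Generates per-device classes (e.g. TestExampleDeviceTypeCPU, TestExampleDeviceTypeCUDA).
    instantiate_device_type_tests(TestExampleDeviceType, globals())

    if __name__ == '__main__':
        run_tests()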

test/common_utils.py

Lines changed: 0 additions & 50 deletions
@@ -208,60 +208,11 @@ def wrapper(*args, **kwargs):
         fn(*args, **kwargs)
     return wrapper

-
-
 def _test_function(fn, device):
     def run_test_function(self):
         return fn(self, device)
     return run_test_function

-
-class torchtest():
-    """Allows to generate and run per-device unittests.
-
-    This decorator class allows to generate and run per-device unittest.
-
-    Example:
-
-    class _TestTorchMixin(torchtest):
-
-        @torchtest.for_all_device_types()
-        def test_zeros_like(self, device):
-            expected = torch.zeros((100, 100,), device=device)
-
-    Will execute:
-
-    test_zeros_like (__main__.TestTorch) ... skipped 'Look at test_zeros_like_cpu, test_zeros_like_cuda results.'
-    test_zeros_like_cpu (__main__.TestTorch) ... ok
-    test_zeros_like_cuda (__main__.TestTorch) ... ok
-
-    To work properly, test class should be inherited from `torchtest`.
-    for_all_device_types decorator does not guarantee proper functionality in
-    combination with other decorators.
-
-    Please do not extend this decorator to support other cases (such as dtype,
-    layouts, etc) without consulting with bigger group. Devices is the special
-    case as build flags control additions/removals (see
-    https://github.com/pytorch/pytorch/pull/23824 for the reference).
-    """
-    @classmethod
-    def for_all_device_types(cls):
-        def wrapper(fn):
-            test_names = []
-
-            for device in torch.testing.get_all_device_types():
-                test_name = fn.__name__ + '_' + device
-                assert not hasattr(cls, test_name), "Duplicated test name: " + test_name
-                setattr(cls, test_name, _test_function(fn, device))
-                test_names.append(test_name)
-
-            @wraps(fn)
-            def empty_test(*args, **kwargs):
-                raise unittest.SkipTest("Look at {} results.".format(", ".join(test_names)))
-            return empty_test
-        return wrapper
-
-
 def skipIfNoLapack(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
@@ -271,7 +222,6 @@ def wrapper(*args, **kwargs):
         fn(*args, **kwargs)
     return wrapper

-
 def skipIfNotRegistered(op_name, message):
     """Wraps the decorator to hide the import of the `core`.
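For reference, the test_zeros_like example from the removed torchtest docstring would look roughly like this under the generic device-type framework that replaces it (a hedged sketch; the class name is illustrative):

    import torch
    from common_utils import TestCase, run_tests
    from common_device_type import instantiate_device_type_tests


    class TestZerosLike(TestCase):
        def test_zeros_like(self, device):
            expected = torch.zeros((100, 100), device=device)
            self.assertEqual(torch.zeros_like(expected), expected)


    # Replaces torchtest.for_all_device_types(): produces per-device variants such as
    # TestZerosLikeCPU.test_zeros_like and TestZerosLikeCUDA.test_zeros_like.
    instantiate_device_type_tests(TestZerosLike, globals())

    if __name__ == '__main__':
        run_tests()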

test/test_cuda.py

Lines changed: 4 additions & 14 deletions
@@ -2176,9 +2176,7 @@ def _select_broadcastable_dims(dims_full=None):
         return _TestTorchMixin._select_broadcastable_dims(dims_full)

     @skipIfRocm
-    def test_fft_ifft_rfft_irfft(self):
-        _TestTorchMixin._test_fft_ifft_rfft_irfft(self, device=torch.device('cuda'))
-
+    def test_fft_ifft_rfft_irfft_plan_cache(self):
         @contextmanager
         def plan_cache_max_size(n, device=None):
             if device is None:
@@ -2246,15 +2244,8 @@ def plan_cache_max_size(n, device=None):
         self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)  # default is cuda:0
         self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)  # default is cuda:1

-    # passes on ROCm w/ python 2.7, fails w/ python 3.6
-    @skipIfRocm
-    def test_stft(self):
-        _TestTorchMixin._test_stft(self, device=torch.device('cuda'))
-
-    def test_multinomial(self):
-        _TestTorchMixin._test_multinomial(self, torch.cuda.FloatTensor)
-
-    # Test two corner cases from older PyTorch (Issue #4858)
+    # Tests two corner cases from older PyTorch (Issue #4858).
+    def test_multinomial_corner_cases(self):
         freqs = torch.cuda.FloatTensor([
             0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
             0.03178183361887932, 0.027680952101945877, 0.033176131546497345,
@@ -2573,8 +2564,7 @@ def test_nvtx(self):
         torch.cuda.nvtx.mark("bar")
         torch.cuda.nvtx.range_pop()

-    def test_bincount_cuda(self):
-        _TestTorchMixin._test_bincount(self, device='cuda')
+    def test_bincount_compare(self):
         # ensure CUDA code coverage
         input_size = (5000,)
         w = torch.randn(input_size, device='cuda')
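A minimal sketch of the cuFFT plan-cache knob exercised by test_fft_ifft_rfft_irfft_plan_cache, simplified from the test's plan_cache_max_size helper (assumes a CUDA-enabled build and the torch.rfft API of this era):

    from contextlib import contextmanager

    import torch


    @contextmanager
    def plan_cache_max_size(n):
        # Temporarily cap the cuFFT plan cache for the current device.
        cache = torch.backends.cuda.cufft_plan_cache
        original = cache.max_size
        cache.max_size = n
        try:
            yield
        finally:
            cache.max_size = original


    if torch.cuda.is_available():
        x = torch.randn(64, 3000, device='cuda')
        with plan_cache_max_size(10):
            torch.rfft(x, 1)  # plans are created and cached, bounded by max_size
            print(torch.backends.cuda.cufft_plan_cache.size)  # number of cached plans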
