
Commit fbf991d

Mike Ruberry authored and facebook-github-bot committed
Creates generic device type testing framework (#25967)
Summary: This PR addresses #24851 by creating a framework that:

1. lets device types easily register themselves for testing
2. lets tests be written to run on multiple devices and with multiple dtypes
3. provides a mechanism to instantiate those tests so they are discoverable and filterable by unittest and pytest

It refactors three tests from test_torch.py to demonstrate how to use it.

`test_diagonal` is the simplest example. Most tests just need to be modified to accept 'device' as an argument. The framework will then instantiate `test_diagonal_cpu` and `test_diagonal_cuda` (when CUDA is available), which call `test_diagonal` with the appropriate 'device' argument.

`test_neg` also has dtype variants. It accepts both 'device' and 'dtype' as arguments, and the dtypes it runs with are specified with the 'dtypes' decorator. Dtypes can be specified for all device types or for particular device types. The framework instantiates tests like `test_neg_cpu_torch.float`.

`test_inverse` has device-specific dependencies. These dependencies are expressed with the sugary 'skipCUDAIfNoMagma' and 'skipCPUIfNoLapack' decorators. These decorators are device-specific, so CPU testing is not skipped if MAGMA is not installed, and their conditions may be checked before or after the test case has been initialized. This means that skipCUDAIfNoMagma does not initialize CUDA. In fact, CUDA is only initialized if a CUDA test is run.

These instantiated tests may be run as usual, and with pytest filtering it's easy to run one test on all device types, run all the tests for a particular device type, or run a device type and dtype combination. See the note "Generic Device-Type Testing" for more detail.

Pull Request resolved: #25967
Differential Revision: D17381987
Pulled By: mruberry
fbshipit-source-id: 4a639641130f0a59d22da0efe0951b24b5bc4bfb
1 parent dc6939e commit fbf991d
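As an illustrative sketch (not part of the commit), a generic test class in the style the summary describes might look like the following. The imports come from the new test/common_device_type.py shown below; the class name and the tensor checks are arbitrary, and the `dtypes` decorator mentioned in the summary is omitted since it does not appear in this excerpt.

```python
import torch
from common_device_type import (instantiate_device_type_tests,
                                skipCPUIfNoLapack, skipCUDAIfNoMagma)
from common_utils import TestCase


class TestExampleDeviceType(TestCase):
    # Instantiated as test_diagonal_cpu and, when CUDA is available,
    # test_diagonal_cuda; 'device' is passed as 'cpu' or 'cuda'.
    def test_diagonal(self, device):
        x = torch.randn(5, 5, device=device)
        diag = torch.diagonal(x)
        for i in range(5):
            self.assertEqual(diag[i].item(), x[i, i].item())

    # Device-specific dependencies: skipped on CPU only if LAPACK is
    # missing and on CUDA only if MAGMA is missing.
    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    def test_inverse(self, device):
        x = torch.eye(4, device=device)
        self.assertTrue(torch.allclose(torch.inverse(x), x))


# Creates TestExampleDeviceTypeCPU (and TestExampleDeviceTypeCUDA when CUDA
# is available) and hides the generic class from test discovery.
instantiate_device_type_tests(TestExampleDeviceType, globals())
```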

3 files changed: +341, -138 lines

test/common_device_type.py

Lines changed: 220 additions & 0 deletions
```python
import inspect
from functools import wraps
import unittest
import torch
from common_utils import TestCase

# Note: Generic Device-Type Testing
#
# [WRITING TESTS]
#
# Write your test class as usual except:
#   (1) Only define test methods in the test class itself. Helper methods
#       and non-methods must be inherited. This limitation is for Python2
#       compatibility.
#   (2) Each test method should have the signature
#         testX(self, device)
#       The device argument will be a string like 'cpu' or 'cuda'.
#   (3) Prefer using test decorators defined in this file to others.
#       For example, using the @skipIfNoLapack decorator instead of
#       @skipCPUIfNoLapack will cause the test to not run on CUDA if
#       LAPACK is not available, which is wrong. If you need to use a decorator
#       you may want to ask about porting it to this framework.
#
# See the TestTorchDeviceType class in test_torch.py for an example.
#
# [RUNNING TESTS]
#
# After defining your test class, call instantiate_device_type_tests on it
# and pass in globals() as the second argument. This will instantiate
# discoverable device-specific test classes from your generic class. It will
# also hide the tests in your generic class so they're not run directly.
#
# For each generic testX, a new test testX_<device_type> will be created.
# These tests will be put in classes named GenericTestClassName<DEVICE_TYPE>.
# For example, test_diagonal in TestTorchDeviceType becomes test_diagonal_cpu
# in TestTorchDeviceTypeCPU and test_diagonal_cuda in TestTorchDeviceTypeCUDA.
#
# In short, if you write a test signature like
#   def testX(self, device)
# you are effectively writing
#   def testX_cpu(self, device='cpu')
#   def testX_cuda(self, device='cuda')
#   def testX_xla(self, device='xla')
#   ...
#
# These tests can be run directly like normal tests:
#   "python test_torch.py TestTorchDeviceTypeCPU.test_diagonal_cpu"
#
# Collections of tests can be run using pytest filtering. For example,
#   "pytest test_torch.py -k 'test_diag'"
# will run test_diag on every available device.
# To specify particular device types the 'and' keyword can be used:
#   "pytest test_torch.py -k 'test_diag and cpu'"
# pytest filtering also makes it easy to run all tests on a particular device
# type.
#
# [ADDING A DEVICE TYPE]
#
# To add a device type:
#
#   (1) Create a new "TestBase" extending DeviceTypeTestBase.
#       See CPUTestBase and CUDATestBase below.
#   (2) Define the "device_type" attribute of the base to be the
#       appropriate string.
#   (3) Add logic to this file that appends your base class to
#       device_type_test_bases when your device type is available.
#   (4) (Optional) Write setUpClass/tearDownClass class methods that
#       instantiate dependencies (see MAGMA in CUDATestBase).
#   (5) (Optional) Override the "instantiate_test" method for total
#       control over how your class creates tests.
#
# setUpClass is called AFTER tests have been created and BEFORE and ONLY IF
# they are run. This makes it useful for initializing devices and dependencies.


# List of device type test bases that can be used to instantiate tests.
# See below for how this list is populated. If you're adding a device type
# you should check if it's available and (if it is) add it to this list.
device_type_test_bases = []


class DeviceTypeTestBase(TestCase):
    device_type = "generic_device_type"

    # Creates device-specific tests.
    @classmethod
    def instantiate_test(cls, test):
        test_name = test.__name__ + "_" + cls.device_type

        assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)

        @wraps(test)
        def instantiated_test(self, test=test):
            return test(self, cls.device_type)

        setattr(cls, test_name, instantiated_test)


class CPUTestBase(DeviceTypeTestBase):
    device_type = "cpu"


class CUDATestBase(DeviceTypeTestBase):
    device_type = "cuda"
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @classmethod
    def setUpClass(cls):
        # has_magma shows up after cuda is initialized
        torch.ones(1).cuda()
        cls.has_magma = torch.cuda.has_magma


# Adds available device-type-specific test base classes
device_type_test_bases.append(CPUTestBase)
if torch.cuda.is_available():
    device_type_test_bases.append(CUDATestBase)


# Adds 'instantiated' device-specific test cases to the given scope.
# The tests in these test cases are derived from the generic tests in
# generic_test_class.
# See note "Generic Device-Type Testing."
def instantiate_device_type_tests(generic_test_class, scope):
    # Removes the generic test class from its enclosing scope so its tests
    # are not discoverable.
    del scope[generic_test_class.__name__]

    # Creates an 'empty' version of the generic_test_class
    # Note: we don't inherit from the generic_test_class directly because
    # that would add its tests to our test classes and they would be
    # discovered (despite not being runnable). Inherited methods also
    # can't be removed later, and we can't rely on load_tests because
    # pytest doesn't support it (as of this writing).
    empty_name = generic_test_class.__name__ + "_base"
    empty_class = type(empty_name, generic_test_class.__bases__, {})

    # Acquires member names
    generic_members = set(dir(generic_test_class)) - set(dir(empty_class))
    generic_tests = [x for x in generic_members if x.startswith('test')]

    # Checks that the generic test suite only has test members
    # Note: for Python2 compat.
    # Note: Nontest members can be inherited, so if you want to use a helper
    # function you can put it in a base class.
    generic_nontests = generic_members - set(generic_tests)
    assert len(generic_nontests) == 0, "Generic device class has non-test members"

    for base in device_type_test_bases:
        # Creates the device-specific test case
        class_name = generic_test_class.__name__ + base.device_type.upper()
        device_type_test_class = type(class_name, (base, empty_class), {})

        for name in generic_tests:
            # Attempts to acquire a function from the attribute
            test = getattr(generic_test_class, name)
            if hasattr(test, '__func__'):
                test = test.__func__
            assert inspect.isfunction(test), "Couldn't extract function from '{0}'".format(name)
            # Instantiates the device-specific tests
            device_type_test_class.instantiate_test(test)

        # Mimics defining the instantiated class in the caller's file
        # by setting its module to the given class's and adding
        # the module to the given scope.
        # This lets the instantiated class be discovered by unittest.
        device_type_test_class.__module__ = generic_test_class.__module__
        scope[class_name] = device_type_test_class


# Decorator that specifies a test dependency.
# Notes:
#   (1) Dependencies stack. Multiple dependencies are all evaluated.
#   (2) Dependencies can either be bools or strings. If a string, the
#       test base must have defined the corresponding attribute to be True
#       for the test to run. If you want to use a string argument you should
#       probably define a new decorator instead (see below).
#   (3) Prefer the existing decorators to defining the 'device_type' kwarg.
class skipIf(object):

    def __init__(self, dep, reason, device_type=None):
        self.dep = dep
        self.reason = reason
        self.device_type = device_type

    def __call__(self, fn):

        @wraps(fn)
        def dep_fn(slf, device, *args, **kwargs):
            if self.device_type is None or self.device_type == slf.device_type:
                if not self.dep or (isinstance(self.dep, str) and not getattr(slf, self.dep, False)):
                    raise unittest.SkipTest(self.reason)

            return fn(slf, device, *args, **kwargs)
        return dep_fn


# Specifies a CPU dependency.
class skipCPUIf(skipIf):

    def __init__(self, dep, reason):
        super(skipCPUIf, self).__init__(dep, reason, device_type='cpu')


# Specifies a CUDA dependency.
class skipCUDAIf(skipIf):

    def __init__(self, dep, reason):
        super(skipCUDAIf, self).__init__(dep, reason, device_type='cuda')


# Specifies LAPACK as a CPU dependency.
def skipCPUIfNoLapack(fn):
    return skipCPUIf(torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)


# Specifies MAGMA as a CUDA dependency.
def skipCUDAIfNoMagma(fn):
    return skipCUDAIf('has_magma', "no MAGMA library detected")(fn)
```
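As a sketch of the "[ADDING A DEVICE TYPE]" steps in the note above, a hypothetical new device type could be registered as follows. The `XLATestBase` name and the `torch_xla` import are illustrative assumptions only, not part of this commit.

```python
# Hypothetical sketch: registering a new device type per the note above.
class XLATestBase(DeviceTypeTestBase):      # step (1): extend the base
    device_type = "xla"                     # step (2): set the device string

    @classmethod
    def setUpClass(cls):                    # step (4), optional
        # Initialize the device and record dependency flags here, the
        # same way CUDATestBase records has_magma.
        pass


# Step (3): append the base only when the device type is available.
try:
    import torch_xla  # noqa: F401  (assumed availability check)
    device_type_test_bases.append(XLATestBase)
except ImportError:
    pass
```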

test/test_cuda.py

Lines changed: 0 additions & 10 deletions
```diff
@@ -1086,9 +1086,6 @@ def test_abs_zero(self):
         for num in abs_zeros:
             self.assertGreater(math.copysign(1.0, num), 0.0)
 
-    def test_neg(self):
-        _TestTorchMixin._test_neg(self, lambda t: t.cuda())
-
     def test_bitwise_not(self):
         _TestTorchMixin._test_bitwise_not(self, 'cuda')
 
@@ -2198,10 +2195,6 @@ def test_prod_large(self):
     def _select_broadcastable_dims(dims_full=None):
         return _TestTorchMixin._select_broadcastable_dims(dims_full)
 
-    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
-    def test_inverse(self):
-        _TestTorchMixin._test_inverse(self, lambda t: t.cuda())
-
     @slowTest
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_inverse_many_batches(self):
 
@@ -2748,9 +2741,6 @@ def test_logspace(self):
     def test_lerp(self):
         _TestTorchMixin._test_lerp(self, lambda t: t.cuda())
 
-    def test_diagonal(self):
-        _TestTorchMixin._test_diagonal(self, dtype=torch.float32, device='cuda')
-
     def test_diagflat(self):
         _TestTorchMixin._test_diagflat(self, dtype=torch.float32, device='cuda')
```
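With these per-device wrappers removed, the same coverage comes from the instantiated tests. For example, assuming the refactored generic tests live in test_torch.py as the summary states, `python test_torch.py TestTorchDeviceTypeCUDA.test_diagonal_cuda` runs one instantiated test directly, and `pytest test_torch.py -k 'test_inverse and cuda'` selects the CUDA variant by pytest filtering.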
