Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions test/common_device_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import inspect
from functools import wraps
import unittest
import torch
from common_utils import TestCase

# Note: Generic Device-Type Testing
#
# [WRITING TESTS]
#
# Write your test class as usual except:
# (1) Only define test methods in the test class itself. Helper methods
# and non-methods must be inherited. This limitation is for Python2
# compatibility.
# (2) Each test method should have the signature
# testX(self, device)
# The device argument will be a string like 'cpu' or 'cuda.'
# (3) Prefer using test decorators defined in this file to others.
# For example, using the @skipIfNoLapack decorator instead of the
# @skipCPUIfNoLapack will cause the test to not run on CUDA if
# LAPACK is not available, which is wrong. If you need to use a decorator
# you may want to ask about porting it to this framework.
#
# See the TestTorchDeviceType class in test_torch.py for an example.
#
# [RUNNING TESTS]
#
# After defining your test class call instantiate_device_type_tests on it
# and pass in globals() for the second argument. This will instantiate
# discoverable device-specific test classes from your generic class. It will
# also hide the tests in your generic class so they're not run directly.
#
# For each generic testX, a new test textX_<device_type> will be created.
# These tests will be put in classes named GenericTestClassName<DEVICE_TYPE>.
# For example, test_diagonal in TestTorchDeviceType becomes test_diagonal_cpu
# in TestTorchDeviceTypeCPU and test_diagonal_cuda in TestTorchDeviceTypeCUDA.
#
# In short, if you write a test signature like
# def textX(self, device)
# You are effectively writing
# def testX_cpu(self, device='cpu')
# def textX_cuda(self, device='cuda')
# def testX_xla(self, device='xla')
# ...
#
# These tests can be run directly like normal tests:
# "python test_torch.py TestTorchDeviceTypeCPU.test_diagonal_cpu"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are we blocked on just calling these "TestTorchCPU" and similar?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can call them whatever you like but there already is a distinct TestTorch class. Maybe there won't be after we refactor?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question, but there isn't a TestTorchCPU class in any case :).

#
# Collections of tests can be run using pytest filtering. For example,
# "pytest test_torch.py -k 'test_diag'"
# will run test_diag on every available device.
# To specify particular device types the 'and' keyword can be used:
# "pytest test_torch.py -k 'test_diag and cpu'"
# pytest filtering also makes it easy to run all tests on a particular device
# type.
#
# [ADDING A DEVICE TYPE]
#
# To add a device type:
#
# (1) Create a new "TestBase" extending DeviceTypeTestBase.
# See CPUTestBase and CUDATestBase below.
# (2) Define the "device_type" attribute of the base to be the
# appropriate string.
# (3) Add logic to this file that appends your base class to
# device_type_test_bases when your device type is available.
# (4) (Optional) Write setUpClass/tearDownClass class methods that
# instantiate dependencies (see MAGMA in CUDATestBase).
# (5) (Optional) Override the "instantiate_test" method for total
# control over how your class creates tests.
#
# setUpClass is called AFTER tests have been created and BEFORE and ONLY IF
# they are run. This makes it useful for initializing devices and dependencies.
#

# List of device type test bases that can be used to instantiate tests.
# See below for how this list is populated. If you're adding a device type
# you should check if it's available and (if it is) add it to this list.
device_type_test_bases = []


class DeviceTypeTestBase(TestCase):
device_type = "generic_device_type"

# Creates device-specific tests.
@classmethod
def instantiate_test(cls, test):
test_name = test.__name__ + "_" + cls.device_type

assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)

@wraps(test)
def instantiated_test(self, test=test):
return test(self, cls.device_type)

setattr(cls, test_name, instantiated_test)


class CPUTestBase(DeviceTypeTestBase):
device_type = "cpu"


class CUDATestBase(DeviceTypeTestBase):
device_type = "cuda"
_do_cuda_memory_leak_check = True
_do_cuda_non_default_stream = True

@classmethod
def setUpClass(cls):
# has_magma shows up after cuda is initialized
torch.ones(1).cuda()
cls.has_magma = torch.cuda.has_magma


# Adds available device-type-specific test base classes
device_type_test_bases.append(CPUTestBase)
if torch.cuda.is_available():
device_type_test_bases.append(CUDATestBase)


# Adds 'instantiated' device-specific test cases to the given scope.
# The tests in these test cases are derived from the generic tests in
# generic_test_class.
# See note "Generic Device Type Testing."
def instantiate_device_type_tests(generic_test_class, scope):
# Removes the generic test class from its enclosing scope so its tests
# are not discoverable.
del scope[generic_test_class.__name__]

# Creates an 'empty' version of the generic_test_class
# Note: we don't inherit from the generic_test_class directly because
# that would add its tests to our test classes and they would be
# discovered (despite not being runnable). Inherited methods also
# can't be removed later, and we can't rely on load_tests because
# pytest doesn't support it (as of this writing).
empty_name = generic_test_class.__name__ + "_base"
empty_class = type(empty_name, generic_test_class.__bases__, {})

# Acquires members names
generic_members = set(dir(generic_test_class)) - set(dir(empty_class))
generic_tests = [x for x in generic_members if x.startswith('test')]

# Checks that the generic test suite only has test members
# Note: for Python2 compat.
# Note: Nontest members can be inherited, so if you want to use a helper
# function you can put it in a base class.
generic_nontests = generic_members - set(generic_tests)
assert len(generic_nontests) == 0, "Generic device class has non-test members"

for base in device_type_test_bases:
# Creates the device-specific test case
class_name = generic_test_class.__name__ + base.device_type.upper()
device_type_test_class = type(class_name, (base, empty_class), {})

for name in generic_tests:
# Attempts to acquire a function from the attribute
test = getattr(generic_test_class, name)
if hasattr(test, '__func__'):
test = test.__func__
assert inspect.isfunction(test), "Couldn't extract function from '{0}'".format(name)
# Instantiates the device-specific tests
device_type_test_class.instantiate_test(test)

# Mimics defining the instantiated class in the caller's file
# by setting its module to the given class's and adding
# the module to the given scope.
# This lets the instantiated class be discovered by unittest.
device_type_test_class.__module__ = generic_test_class.__module__
scope[class_name] = device_type_test_class


# Decorator that specifies a test dependency.
# Notes:
# (1) Dependencies stack. Multiple dependencies are all evaluated.
# (2) Dependencies can either be bools or strings. If a string the
# test base must have defined the corresponding attribute to be True
# for the test to run. If you want to use a string argument you should
# probably define a new decorator instead (see below).
# (3) Prefer the existing decorators to defining the 'device_type' kwarg.
class skipIf(object):

def __init__(self, dep, reason, device_type=None):
self.dep = dep
self.reason = reason
self.device_type = device_type

def __call__(self, fn):

@wraps(fn)
def dep_fn(slf, device, *args, **kwargs):
if self.device_type is None or self.device_type == slf.device_type:
if not self.dep or (isinstance(self.dep, str) and not getattr(slf, self.dep, False)):
raise unittest.SkipTest(self.reason)

return fn(slf, device, *args, **kwargs)
return dep_fn


# Specifies a CPU dependency.
class skipCPUIf(skipIf):

def __init__(self, dep, reason):
super(skipCPUIf, self).__init__(dep, reason, device_type='cpu')


# Specifies a CUDA dependency.
class skipCUDAIf(skipIf):

def __init__(self, dep, reason):
super(skipCUDAIf, self).__init__(dep, reason, device_type='cuda')


# Specifies LAPACK as a CPU dependency.
def skipCPUIfNoLapack(fn):
return skipCPUIf(torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)


# Specifies MAGMA as a CUDA dependency.
def skipCUDAIfNoMagma(fn):
return skipCUDAIf('has_magma', "no MAGMA library detected")(fn)
10 changes: 0 additions & 10 deletions test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,9 +1094,6 @@ def test_abs_zero(self):
for num in abs_zeros:
self.assertGreater(math.copysign(1.0, num), 0.0)

def test_neg(self):
_TestTorchMixin._test_neg(self, lambda t: t.cuda())

def test_bitwise_not(self):
_TestTorchMixin._test_bitwise_not(self, 'cuda')

Expand Down Expand Up @@ -2206,10 +2203,6 @@ def test_prod_large(self):
def _select_broadcastable_dims(dims_full=None):
return _TestTorchMixin._select_broadcastable_dims(dims_full)

@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_inverse(self):
_TestTorchMixin._test_inverse(self, lambda t: t.cuda())

@slowTest
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_inverse_many_batches(self):
Expand Down Expand Up @@ -2746,9 +2739,6 @@ def test_logspace(self):
def test_lerp(self):
_TestTorchMixin._test_lerp(self, lambda t: t.cuda())

def test_diagonal(self):
_TestTorchMixin._test_diagonal(self, dtype=torch.float32, device='cuda')

def test_diagflat(self):
_TestTorchMixin._test_diagflat(self, dtype=torch.float32, device='cuda')

Expand Down
Loading