-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Creates generic device type testing framework #25967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
ee570a9
Generic device type testing framework
5bfe18e
removes unused imports
acfa454
python2 syntax, cuda mem leaks
4bfa0fe
actually fix lint
f0ecae1
python 2 compat
c5b6525
simplifies import and attempts to acquire function from generic descr…
7179ffc
updated per review
2feb039
removes dtype filtering, nit fixes, code golfing
bb16aad
test torch fixes
4d7ab84
Removes dtype variant instantiation
24a1062
Removes dtype variant instantiation
e53e963
comment fixes per review
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,220 @@ | ||
| import inspect | ||
| from functools import wraps | ||
| import unittest | ||
| import torch | ||
| from common_utils import TestCase | ||
|
|
||
| # Note: Generic Device-Type Testing | ||
| # | ||
| # [WRITING TESTS] | ||
| # | ||
| # Write your test class as usual except: | ||
| # (1) Only define test methods in the test class itself. Helper methods | ||
| # and non-methods must be inherited. This limitation is for Python2 | ||
| # compatibility. | ||
| # (2) Each test method should have the signature | ||
| # testX(self, device) | ||
| # The device argument will be a string like 'cpu' or 'cuda.' | ||
| # (3) Prefer using test decorators defined in this file to others. | ||
| # For example, using the @skipIfNoLapack decorator instead of the | ||
| # @skipCPUIfNoLapack will cause the test to not run on CUDA if | ||
| # LAPACK is not available, which is wrong. If you need to use a decorator | ||
| # you may want to ask about porting it to this framework. | ||
| # | ||
| # See the TestTorchDeviceType class in test_torch.py for an example. | ||
| # | ||
| # [RUNNING TESTS] | ||
| # | ||
| # After defining your test class call instantiate_device_type_tests on it | ||
| # and pass in globals() for the second argument. This will instantiate | ||
| # discoverable device-specific test classes from your generic class. It will | ||
| # also hide the tests in your generic class so they're not run directly. | ||
| # | ||
| # For each generic testX, a new test textX_<device_type> will be created. | ||
| # These tests will be put in classes named GenericTestClassName<DEVICE_TYPE>. | ||
| # For example, test_diagonal in TestTorchDeviceType becomes test_diagonal_cpu | ||
| # in TestTorchDeviceTypeCPU and test_diagonal_cuda in TestTorchDeviceTypeCUDA. | ||
| # | ||
| # In short, if you write a test signature like | ||
| # def textX(self, device) | ||
| # You are effectively writing | ||
| # def testX_cpu(self, device='cpu') | ||
| # def textX_cuda(self, device='cuda') | ||
| # def testX_xla(self, device='xla') | ||
| # ... | ||
| # | ||
| # These tests can be run directly like normal tests: | ||
| # "python test_torch.py TestTorchDeviceTypeCPU.test_diagonal_cpu" | ||
| # | ||
| # Collections of tests can be run using pytest filtering. For example, | ||
| # "pytest test_torch.py -k 'test_diag'" | ||
| # will run test_diag on every available device. | ||
| # To specify particular device types the 'and' keyword can be used: | ||
| # "pytest test_torch.py -k 'test_diag and cpu'" | ||
| # pytest filtering also makes it easy to run all tests on a particular device | ||
| # type. | ||
| # | ||
| # [ADDING A DEVICE TYPE] | ||
| # | ||
| # To add a device type: | ||
| # | ||
| # (1) Create a new "TestBase" extending DeviceTypeTestBase. | ||
| # See CPUTestBase and CUDATestBase below. | ||
| # (2) Define the "device_type" attribute of the base to be the | ||
| # appropriate string. | ||
| # (3) Add logic to this file that appends your base class to | ||
| # device_type_test_bases when your device type is available. | ||
| # (4) (Optional) Write setUpClass/tearDownClass class methods that | ||
| # instantiate dependencies (see MAGMA in CUDATestBase). | ||
| # (5) (Optional) Override the "instantiate_test" method for total | ||
| # control over how your class creates tests. | ||
| # | ||
| # setUpClass is called AFTER tests have been created and BEFORE and ONLY IF | ||
| # they are run. This makes it useful for initializing devices and dependencies. | ||
| # | ||
|
|
||
| # List of device type test bases that can be used to instantiate tests. | ||
| # See below for how this list is populated. If you're adding a device type | ||
| # you should check if it's available and (if it is) add it to this list. | ||
| device_type_test_bases = [] | ||
|
|
||
|
|
||
| class DeviceTypeTestBase(TestCase): | ||
| device_type = "generic_device_type" | ||
|
|
||
| # Creates device-specific tests. | ||
| @classmethod | ||
| def instantiate_test(cls, test): | ||
| test_name = test.__name__ + "_" + cls.device_type | ||
|
|
||
| assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name) | ||
|
|
||
| @wraps(test) | ||
| def instantiated_test(self, test=test): | ||
| return test(self, cls.device_type) | ||
|
|
||
| setattr(cls, test_name, instantiated_test) | ||
|
|
||
|
|
||
| class CPUTestBase(DeviceTypeTestBase): | ||
| device_type = "cpu" | ||
|
|
||
|
|
||
| class CUDATestBase(DeviceTypeTestBase): | ||
| device_type = "cuda" | ||
| _do_cuda_memory_leak_check = True | ||
| _do_cuda_non_default_stream = True | ||
|
|
||
| @classmethod | ||
| def setUpClass(cls): | ||
| # has_magma shows up after cuda is initialized | ||
| torch.ones(1).cuda() | ||
| cls.has_magma = torch.cuda.has_magma | ||
|
|
||
|
|
||
| # Adds available device-type-specific test base classes | ||
| device_type_test_bases.append(CPUTestBase) | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if torch.cuda.is_available(): | ||
| device_type_test_bases.append(CUDATestBase) | ||
|
|
||
|
|
||
| # Adds 'instantiated' device-specific test cases to the given scope. | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # The tests in these test cases are derived from the generic tests in | ||
| # generic_test_class. | ||
| # See note "Generic Device Type Testing." | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| def instantiate_device_type_tests(generic_test_class, scope): | ||
| # Removes the generic test class from its enclosing scope so its tests | ||
| # are not discoverable. | ||
| del scope[generic_test_class.__name__] | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Creates an 'empty' version of the generic_test_class | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # Note: we don't inherit from the generic_test_class directly because | ||
| # that would add its tests to our test classes and they would be | ||
| # discovered (despite not being runnable). Inherited methods also | ||
| # can't be removed later, and we can't rely on load_tests because | ||
| # pytest doesn't support it (as of this writing). | ||
| empty_name = generic_test_class.__name__ + "_base" | ||
| empty_class = type(empty_name, generic_test_class.__bases__, {}) | ||
|
|
||
| # Acquires members names | ||
| generic_members = set(dir(generic_test_class)) - set(dir(empty_class)) | ||
| generic_tests = [x for x in generic_members if x.startswith('test')] | ||
|
|
||
| # Checks that the generic test suite only has test members | ||
| # Note: for Python2 compat. | ||
| # Note: Nontest members can be inherited, so if you want to use a helper | ||
| # function you can put it in a base class. | ||
| generic_nontests = generic_members - set(generic_tests) | ||
| assert len(generic_nontests) == 0, "Generic device class has non-test members" | ||
|
|
||
| for base in device_type_test_bases: | ||
| # Creates the device-specific test case | ||
| class_name = generic_test_class.__name__ + base.device_type.upper() | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| device_type_test_class = type(class_name, (base, empty_class), {}) | ||
|
|
||
| for name in generic_tests: | ||
| # Attempts to acquire a function from the attribute | ||
| test = getattr(generic_test_class, name) | ||
| if hasattr(test, '__func__'): | ||
| test = test.__func__ | ||
| assert inspect.isfunction(test), "Couldn't extract function from '{0}'".format(name) | ||
| # Instantiates the device-specific tests | ||
| device_type_test_class.instantiate_test(test) | ||
|
|
||
| # Mimics defining the instantiated class in the caller's file | ||
| # by setting its module to the given class's and adding | ||
| # the module to the given scope. | ||
| # This lets the instantiated class be discovered by unittest. | ||
| device_type_test_class.__module__ = generic_test_class.__module__ | ||
| scope[class_name] = device_type_test_class | ||
|
|
||
|
|
||
| # Decorator that specifies a test dependency. | ||
| # Notes: | ||
| # (1) Dependencies stack. Multiple dependencies are all evaluated. | ||
| # (2) Dependencies can either be bools or strings. If a string the | ||
| # test base must have defined the corresponding attribute to be True | ||
| # for the test to run. If you want to use a string argument you should | ||
| # probably define a new decorator instead (see below). | ||
| # (3) Prefer the existing decorators to defining the 'device_type' kwarg. | ||
| class skipIf(object): | ||
|
|
||
| def __init__(self, dep, reason, device_type=None): | ||
| self.dep = dep | ||
| self.reason = reason | ||
| self.device_type = device_type | ||
|
|
||
| def __call__(self, fn): | ||
|
|
||
| @wraps(fn) | ||
| def dep_fn(slf, device, *args, **kwargs): | ||
| if self.device_type is None or self.device_type == slf.device_type: | ||
| if not self.dep or (isinstance(self.dep, str) and not getattr(slf, self.dep, False)): | ||
| raise unittest.SkipTest(self.reason) | ||
|
|
||
| return fn(slf, device, *args, **kwargs) | ||
| return dep_fn | ||
|
|
||
|
|
||
| # Specifies a CPU dependency. | ||
| class skipCPUIf(skipIf): | ||
|
|
||
| def __init__(self, dep, reason): | ||
| super(skipCPUIf, self).__init__(dep, reason, device_type='cpu') | ||
|
|
||
|
|
||
| # Specifies a CUDA dependency. | ||
| class skipCUDAIf(skipIf): | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def __init__(self, dep, reason): | ||
| super(skipCUDAIf, self).__init__(dep, reason, device_type='cuda') | ||
|
|
||
|
|
||
| # Specifies LAPACK as a CPU dependency. | ||
| def skipCPUIfNoLapack(fn): | ||
mruberry marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return skipCPUIf(torch._C.has_lapack, "PyTorch compiled without Lapack")(fn) | ||
|
|
||
|
|
||
| # Specifies MAGMA as a CUDA dependency. | ||
| def skipCUDAIfNoMagma(fn): | ||
| return skipCUDAIf('has_magma', "no MAGMA library detected")(fn) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are we blocked on just calling these "TestTorchCPU" and similar?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can call them whatever you like but there already is a distinct TestTorch class. Maybe there won't be after we refactor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good question, but there isn't a TestTorchCPU class in any case :).