Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
comment fixes per review
  • Loading branch information
Mike Ruberry committed Sep 14, 2019
commit e53e9637dcd0e19022edb6f0a9b51bfbb3cc761e
10 changes: 6 additions & 4 deletions test/common_device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# (1) Only define test methods in the test class itself. Helper methods
# and non-methods must be inherited. This limitation is for Python2
# compatibility.
# (2) Each test method should have have the signature
# (2) Each test method should have the signature
# testX(self, device)
# The device argument will be a string like 'cpu' or 'cuda.'
# (3) Prefer using test decorators defined in this file to others.
Expand Down Expand Up @@ -87,6 +87,8 @@ class DeviceTypeTestBase(TestCase):
def instantiate_test(cls, test):
test_name = test.__name__ + "_" + cls.device_type

assert not hasattr(cls, test_name), "Redefinition of test {0}".format(test_name)

@wraps(test)
def instantiated_test(self, test=test):
return test(self, cls.device_type)
Expand Down Expand Up @@ -121,7 +123,7 @@ def setUpClass(cls):
# generic_test_class.
# See note "Generic Device Type Testing."
def instantiate_device_type_tests(generic_test_class, scope):
# Removes the generic test class from its enclosing scope so it's tests
# Removes the generic test class from its enclosing scope so its tests
# are not discoverable.
del scope[generic_test_class.__name__]

Expand Down Expand Up @@ -177,7 +179,7 @@ def instantiate_device_type_tests(generic_test_class, scope):
# (3) Prefer the existing decorators to defining the 'device_type' kwarg.
class skipIf(object):

def __init__(self, dep, reason, device_type='all'):
def __init__(self, dep, reason, device_type=None):
self.dep = dep
self.reason = reason
self.device_type = device_type
Expand All @@ -186,7 +188,7 @@ def __call__(self, fn):

@wraps(fn)
def dep_fn(slf, device, *args, **kwargs):
if self.device_type == 'all' or self.device_type == slf.device_type:
if self.device_type is None or self.device_type == slf.device_type:
if not self.dep or (isinstance(self.dep, str) and not getattr(slf, self.dep, False)):
raise unittest.SkipTest(self.reason)

Expand Down
2 changes: 1 addition & 1 deletion test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13326,7 +13326,7 @@ def add_neg_dim_tests():
assert not hasattr(_TestTorchMixin, test_name), "Duplicated test name: " + test_name
setattr(_TestTorchMixin, test_name, make_neg_dim_test(name, tensor_arg, arg_constr, types, extra_dim))

# Device-agnostic tests. Instantiated below and not run directly.
# Device-generic tests. Instantiated below and not run directly.
class TestTorchDeviceType(TestCase):
def test_diagonal(self, device):
x = torch.randn((100, 100), device=device)
Expand Down