Skip to content

Commit c44c4c0

Browse files
committed
fix lin issues
1 parent caa6a32 commit c44c4c0

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

test/test_tensor_creation_ops.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -345,11 +345,11 @@ def block_diag_workaround(*arrs):
345345

346346
if device != 'cpu':
347347
with self.assertRaisesRegex(
348-
RuntimeError,
349-
(
350-
"torch.block_diag: input tensors must all be on the same device."
351-
" Input 0 is on device cpu and input 1 is on device "
352-
)
348+
RuntimeError,
349+
(
350+
"torch.block_diag: input tensors must all be on the same device."
351+
" Input 0 is on device cpu and input 1 is on device "
352+
)
353353
):
354354
torch.block_diag(torch.ones(2, 2).cpu(), torch.ones(2, 2, device=device))
355355

@@ -474,8 +474,7 @@ def complex_dtype_name(dtype):
474474
out = torch.zeros(2, device=device, dtype=dtype)
475475
expected_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128
476476
error = "Expected object of scalar type {} but got scalar type " \
477-
"{} for argument 'out'".format(
478-
complex_dtype_name(expected_dtype), dtype_name(dtype))
477+
"{} for argument 'out'".format(complex_dtype_name(expected_dtype), dtype_name(dtype))
479478
with self.assertRaisesRegex(RuntimeError, error):
480479
op(a, b, out=out)
481480

@@ -2995,8 +2994,10 @@ def test_logspace_special_steps(self, device, dtype):
29952994
self._test_logspace_base2(device, dtype, steps=steps)
29962995

29972996
@dtypes(*all_types_and(torch.bfloat16))
2998-
@dtypesIfCUDA(*integral_types_and(torch.half, torch.bfloat16, torch.float32, torch.float64) if TEST_WITH_ROCM else
2999-
all_types_and(torch.half, torch.bfloat16))
2997+
@dtypesIfCUDA(
2998+
*integral_types_and(torch.half, torch.bfloat16, torch.float32, torch.float64) if TEST_WITH_ROCM else
2999+
all_types_and(torch.half, torch.bfloat16)
3000+
)
30003001
def test_logspace(self, device, dtype):
30013002
_from = random.random()
30023003
to = _from + random.random()

0 commit comments

Comments
 (0)