
Commit adc14ad

janeyx99 authored and pytorchmergebot committed
Fix flakiness with test_binary_op_list_error_cases (#129003)
So how does this PR fix the flakiness? Following my investigation (read part 1 in the linked ghstack PR below), I realized that this test only errors consistently after another test has been found flaky. Why? Because since #119408, TORCH_SHOW_CPP_STACKTRACES=1 gets turned on for _every_ test after _any_ test reruns. This test checked for an exact error-message match, which no longer matches once a C++ stacktrace is appended, since the stacktrace for a foreach function is obviously going to differ from that of a non-foreach function. So we improve the test to assert on a stable error-message pattern instead.

Pull Request resolved: #129003
Approved by: https://github.com/soulitzer
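For context, here is a minimal sketch of the failure mode (stand-in functions and messages, not the actual PyTorch test code): two ops raise the same logical error, but with different C++ stacktraces appended, as happens when TORCH_SHOW_CPP_STACKTRACES=1 is set.

import re
import unittest

STABLE_MSG = (
    "The size of tensor a (10) must match the size of tensor b (11) "
    "at non-singleton dimension 1"
)


def foreach_op():
    # Hypothetical stacktrace suffix; the frames differ per kernel.
    raise RuntimeError(STABLE_MSG + "\nException raised from foreach_tensor_add ...")


def ref_op():
    raise RuntimeError(STABLE_MSG + "\nException raised from infer_size ...")


class Sketch(unittest.TestCase):
    def test_old_pattern_is_fragile(self):
        # The old test caught the foreach error and asserted that the
        # reference op raised the *exact* same string. With stacktraces
        # appended, the strings differ, so the inner assertion fails.
        try:
            foreach_op()
        except RuntimeError as e:
            with self.assertRaises(AssertionError):
                with self.assertRaisesRegex(type(e), re.escape(str(e))):
                    ref_op()

    def test_new_pattern_is_robust(self):
        # The fix: assert on the stable message prefix for each op directly.
        # assertRaisesRegex uses re.search, so a prefix match suffices.
        for op in (foreach_op, ref_op):
            with self.assertRaisesRegex(RuntimeError, re.escape(STABLE_MSG)):
                op()


if __name__ == "__main__":
    unittest.main()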
1 parent 61fa3de commit adc14ad

1 file changed (+14, −19 lines)


test/test_foreach.py

Lines changed: 14 additions & 19 deletions
@@ -579,10 +579,6 @@ def test_binary_op_scalar_with_different_tensor_dtypes(self, device, dtype, op):
         filter(lambda op: op.supports_out, foreach_binary_op_db),
         dtypes=OpDTypes.supported,
     )
-    @unittest.skipIf(
-        torch.cuda.is_available() and not torch.cuda.get_device_capability(0) == (8, 6),
-        "failing flakily on non sm86 cuda jobs, ex https://github.com/pytorch/pytorch/issues/125035",
-    )
     def test_binary_op_list_error_cases(self, device, dtype, op):
         foreach_op, foreach_op_, ref, ref_ = (
             op.method_variant,
@@ -630,28 +626,27 @@ def test_binary_op_list_error_cases(self, device, dtype, op):
         # to be the same as torch regular function.
         tensors1 = [torch.zeros(10, 10, device=device, dtype=dtype) for _ in range(10)]
         tensors2 = [torch.ones(11, 11, device=device, dtype=dtype) for _ in range(10)]
-        try:
+
+        if dtype == torch.bool and foreach_op == torch._foreach_sub:
+            for fop in ops_to_test:
+                with self.assertRaisesRegex(RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)):
+                    fop(tensors1, tensors2)
+            return
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
+        ):
             foreach_op(tensors1, tensors2)
-        except RuntimeError as e:
-            with self.assertRaisesRegex(type(e), re.escape(str(e))):
-                [ref(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
-        try:
+        with self.assertRaisesRegex(
+            RuntimeError,
+            r"The size of tensor a \(10\) must match the size of tensor b \(11\) at non-singleton dimension 1",
+        ):
             foreach_op_(tensors1, tensors2)
-        except RuntimeError as e:
-            with self.assertRaisesRegex(type(e), re.escape(str(e))):
-                [ref_(t1, t2) for t1, t2 in zip(tensors1, tensors2)]
 
         # different devices
         if self.device_type == "cuda" and torch.cuda.device_count() > 1:
             tensor1 = torch.zeros(10, 10, device="cuda:0", dtype=dtype)
             tensor2 = torch.ones(10, 10, device="cuda:1", dtype=dtype)
-            if dtype == torch.bool and foreach_op == torch._foreach_sub:
-                for fop in ops_to_test:
-                    with self.assertRaisesRegex(
-                        RuntimeError, re.escape(_BOOL_SUB_ERR_MSG)
-                    ):
-                        fop([tensor1], [tensor2])
-                return
             with self.assertRaisesRegex(
                 RuntimeError, "Expected all tensors to be on the same device"
             ):
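As an aside, the hand-written patterns in the new assertions escape the parentheses because assertRaisesRegex treats its second argument as a regular expression; unescaped, "(10)" would be parsed as a capture group and fail to match the literal text. A quick standalone check (not part of the PR):

import re
import unittest


class EscapeCheck(unittest.TestCase):
    MSG = "The size of tensor a (10) must match the size of tensor b (11)"

    def test_escaped_parens_match_literally(self):
        # re.escape produces the same kind of pattern as the hand-escaped
        # r"...tensor a \(10\)..." literals in the diff above.
        with self.assertRaisesRegex(RuntimeError, re.escape(self.MSG)):
            raise RuntimeError(self.MSG)


if __name__ == "__main__":
    unittest.main()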
