Skip to content

Commit 08cd7ca

Browse files
Refactor skip logic for test_all_reduce to allow float32 version to run
1 parent 6b5125b commit 08cd7ca

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

test/distributed/test_nccl.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,13 @@ def test_reduce(self, device, dtype):
6262

6363
self.assertEqual(tensors[0], expected)
6464

65-
@unittest.skipIf(TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16,
66-
"Skip bfloat16 testing for ROCm versions before 3.5")
6765
@unittest.skipIf(IS_WINDOWS, "NCCL doesn't support Windows")
6866
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
6967
@dtypes(*datatypes)
7068
def test_all_reduce(self, device, dtype):
69+
if TEST_WITH_ROCM and HIP_VERSION < 3.5 and dtype == torch.bfloat16:
70+
raise unittest.SkipTest("Skip bfloat16 test for ROCm < 3.5")
71+
7172
tensors = [torch.zeros(128).uniform_().to(dtype=dtype) for i in range(nGPUs)]
7273
expected = torch.zeros(128, dtype=dtype)
7374
for t in tensors:

0 commit comments

Comments
 (0)