Skip to content

Commit 560c524

Browse files
committed
Update test case for any/all with hardcode input matrix.
Signed-off-by: HE, Tao <sighingnow@gmail.com>
1 parent 6286ce7 commit 560c524

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

test/test_torch.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -463,18 +463,20 @@ def test_all_any_empty(self):
463463

464464
def test_all_any_with_dim(self):
465465
def test(x):
466-
r1 = x.prod(dim=1, keepdim=False)
467-
r2 = x.all(dim=1, keepdim=False)
466+
r1 = x.prod(dim=0, keepdim=False)
467+
r2 = x.all(dim=0, keepdim=False)
468468
self.assertEqual(r1.shape, r2.shape)
469469
self.assertTrue((r1 == r2).all())
470470

471-
r3 = x.sum(dim=2, keepdim=True).clamp(0, 1)
472-
r4 = x.any(dim=2, keepdim=True)
471+
r3 = x.sum(dim=1, keepdim=True).clamp(0, 1)
472+
r4 = x.any(dim=1, keepdim=True)
473473
self.assertEqual(r3.shape, r4.shape)
474474
self.assertTrue((r3 == r4).all())
475475

476-
test(torch.rand((1, 2, 3, 4)).round().byte())
477-
test(torch.rand((4, 3, 2, 1)).round().byte())
476+
test(torch.ByteTensor([[0, 0, 0],
477+
[0, 0, 1],
478+
[0, 1, 1],
479+
[1, 1, 1]]))
478480

479481
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
480482
def test_all_any_empty_cuda(self):

0 commit comments

Comments
 (0)