We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6286ce7 · commit 560c524 — Copy full SHA for 560c524
test/test_torch.py
@@ -463,18 +463,20 @@ def test_all_any_empty(self):

     def test_all_any_with_dim(self):
         def test(x):
-            r1 = x.prod(dim=1, keepdim=False)
-            r2 = x.all(dim=1, keepdim=False)
+            r1 = x.prod(dim=0, keepdim=False)
+            r2 = x.all(dim=0, keepdim=False)
             self.assertEqual(r1.shape, r2.shape)
             self.assertTrue((r1 == r2).all())

-            r3 = x.sum(dim=2, keepdim=True).clamp(0, 1)
-            r4 = x.any(dim=2, keepdim=True)
+            r3 = x.sum(dim=1, keepdim=True).clamp(0, 1)
+            r4 = x.any(dim=1, keepdim=True)
             self.assertEqual(r3.shape, r4.shape)
             self.assertTrue((r3 == r4).all())

-        test(torch.rand((1, 2, 3, 4)).round().byte())
-        test(torch.rand((4, 3, 2, 1)).round().byte())
+        test(torch.ByteTensor([[0, 0, 0],
+                               [0, 0, 1],
+                               [0, 1, 1],
+                               [1, 1, 1]]))

     @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
     def test_all_any_empty_cuda(self):
0 commit comments