Skip to content

Commit bc1b4c8

Browse files
cpuhrschezyang
authored andcommitted
ByteTensor sum test (#6042)
1 parent 60a16e5 commit bc1b4c8

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

test/test_torch.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -438,52 +438,52 @@ def test_min(self):
438438
def _test_dim_reduction(self, cast):
439439
example = [[-1, 2, 1], [5, 3, 6]]
440440

441-
types = ['torch.DoubleTensor',
442-
'torch.FloatTensor',
443-
'torch.LongTensor',
444-
'torch.IntTensor',
445-
'torch.ShortTensor',
446-
'torch.ByteTensor']
441+
types = [torch.double,
442+
torch.float,
443+
torch.int64,
444+
torch.int32,
445+
torch.int16,
446+
torch.uint8]
447447

448448
# This won't test for 256bit instructions, since we usually
449449
# only work on 1 cacheline (1024bit) at a time and these
450450
# examples aren't big enough to trigger that.
451-
for tname in types:
452-
x = cast(torch.FloatTensor(example).type(tname))
451+
for dtype in types:
452+
x = cast(torch.tensor(example, dtype=dtype))
453453
self.assertEqual(x.sum().item(), 16)
454454
self.assertEqual(x.sum(0), torch.FloatTensor([4, 5, 7]))
455455
self.assertEqual(x.sum(1), torch.FloatTensor([2, 14]))
456-
y = cast(torch.FloatTensor(example).type(tname))
456+
y = cast(torch.tensor(example, dtype=dtype))
457457
torch.sum(x, 0, out=y)
458458
self.assertEqual(x.sum(0), y)
459459

460460
# Mean not supported for Int types
461-
for tname in types[:2]:
462-
x = cast(torch.FloatTensor(example).type(tname))
461+
for dtype in types[:2]:
462+
x = cast(torch.tensor(example, dtype=dtype))
463463
self.assertEqual(x.mean().item(), 16.0 / 6)
464464
self.assertEqual(x.mean(0), torch.FloatTensor([2.0, 2.5, 7.0 / 2]))
465465
self.assertEqual(x.mean(1), torch.FloatTensor([2.0 / 3, 14.0 / 3]))
466466

467-
for tname in types:
468-
if tname == 'torch.ByteTensor': # Overflows
467+
for dtype in types:
468+
if dtype == torch.uint8: # Overflows
469469
continue
470-
x = cast(torch.FloatTensor(example).type(tname))
470+
x = cast(torch.tensor(example, dtype=dtype))
471471
self.assertEqual(x.prod().item(), -180)
472472
self.assertEqual(x.prod(0), torch.FloatTensor([-5, 6, 6]))
473473
self.assertEqual(x.prod(1), torch.FloatTensor([-2, 90]))
474474

475-
for tname in types:
476-
if tname == 'torch.ByteTensor': # Doesn't support negative values
475+
for dtype in types:
476+
if dtype == torch.uint8: # Doesn't support negative values
477477
continue
478-
x = cast(torch.FloatTensor(example).type(tname))
478+
x = cast(torch.tensor(example, dtype=dtype))
479479
self.assertEqual(x.max().item(), 6)
480480
self.assertEqual(x.max(0), (torch.FloatTensor([5, 3, 6]), torch.FloatTensor([1, 1, 1])))
481481
self.assertEqual(x.max(1), (torch.FloatTensor([2, 6]), torch.FloatTensor([1, 2])))
482482

483-
for tname in types:
484-
if tname == 'torch.ByteTensor': # Doesn't support negative values
483+
for dtype in types:
484+
if dtype == torch.uint8: # Doesn't support negative values
485485
continue
486-
x = cast(torch.FloatTensor(example).type(tname))
486+
x = cast(torch.tensor(example, dtype=dtype))
487487
self.assertEqual(x.min().item(), -1)
488488
self.assertEqual(x.min(0), (torch.FloatTensor([-1, 2, 1]), torch.FloatTensor([0, 0, 0])))
489489
self.assertEqual(x.min(1), (torch.FloatTensor([-1, 3]), torch.FloatTensor([0, 1])))

0 commit comments

Comments
 (0)