@@ -438,52 +438,52 @@ def test_min(self):
438438 def _test_dim_reduction (self , cast ):
439439 example = [[- 1 , 2 , 1 ], [5 , 3 , 6 ]]
440440
441- types = [' torch.DoubleTensor' ,
442- ' torch.FloatTensor' ,
443- ' torch.LongTensor' ,
444- ' torch.IntTensor' ,
445- ' torch.ShortTensor' ,
446- ' torch.ByteTensor' ]
441+ types = [torch .double ,
442+ torch .float ,
443+ torch .int64 ,
444+ torch .int32 ,
445+ torch .int16 ,
446+ torch .uint8 ]
447447
448448 # This won't test for 256bit instructions, since we usually
449449 # only work on 1 cacheline (1024bit) at a time and these
450450 # examples aren't big enough to trigger that.
451- for tname in types :
452- x = cast (torch .FloatTensor (example ). type ( tname ))
451+ for dtype in types :
452+ x = cast (torch .tensor (example , dtype = dtype ))
453453 self .assertEqual (x .sum ().item (), 16 )
454454 self .assertEqual (x .sum (0 ), torch .FloatTensor ([4 , 5 , 7 ]))
455455 self .assertEqual (x .sum (1 ), torch .FloatTensor ([2 , 14 ]))
456- y = cast (torch .FloatTensor (example ). type ( tname ))
456+ y = cast (torch .tensor (example , dtype = dtype ))
457457 torch .sum (x , 0 , out = y )
458458 self .assertEqual (x .sum (0 ), y )
459459
460460 # Mean not supported for Int types
461- for tname in types [:2 ]:
462- x = cast (torch .FloatTensor (example ). type ( tname ))
461+ for dtype in types [:2 ]:
462+ x = cast (torch .tensor (example , dtype = dtype ))
463463 self .assertEqual (x .mean ().item (), 16.0 / 6 )
464464 self .assertEqual (x .mean (0 ), torch .FloatTensor ([2.0 , 2.5 , 7.0 / 2 ]))
465465 self .assertEqual (x .mean (1 ), torch .FloatTensor ([2.0 / 3 , 14.0 / 3 ]))
466466
467- for tname in types :
468- if tname == ' torch.ByteTensor' : # Overflows
467+ for dtype in types :
468+ if dtype == torch .uint8 : # Overflows
469469 continue
470- x = cast (torch .FloatTensor (example ). type ( tname ))
470+ x = cast (torch .tensor (example , dtype = dtype ))
471471 self .assertEqual (x .prod ().item (), - 180 )
472472 self .assertEqual (x .prod (0 ), torch .FloatTensor ([- 5 , 6 , 6 ]))
473473 self .assertEqual (x .prod (1 ), torch .FloatTensor ([- 2 , 90 ]))
474474
475- for tname in types :
476- if tname == ' torch.ByteTensor' : # Doesn't support negative values
475+ for dtype in types :
476+ if dtype == torch .uint8 : # Doesn't support negative values
477477 continue
478- x = cast (torch .FloatTensor (example ). type ( tname ))
478+ x = cast (torch .tensor (example , dtype = dtype ))
479479 self .assertEqual (x .max ().item (), 6 )
480480 self .assertEqual (x .max (0 ), (torch .FloatTensor ([5 , 3 , 6 ]), torch .FloatTensor ([1 , 1 , 1 ])))
481481 self .assertEqual (x .max (1 ), (torch .FloatTensor ([2 , 6 ]), torch .FloatTensor ([1 , 2 ])))
482482
483- for tname in types :
484- if tname == ' torch.ByteTensor' : # Doesn't support negative values
483+ for dtype in types :
484+ if dtype == torch .uint8 : # Doesn't support negative values
485485 continue
486- x = cast (torch .FloatTensor (example ). type ( tname ))
486+ x = cast (torch .tensor (example , dtype = dtype ))
487487 self .assertEqual (x .min ().item (), - 1 )
488488 self .assertEqual (x .min (0 ), (torch .FloatTensor ([- 1 , 2 , 1 ]), torch .FloatTensor ([0 , 0 , 0 ])))
489489 self .assertEqual (x .min (1 ), (torch .FloatTensor ([- 1 , 3 ]), torch .FloatTensor ([0 , 1 ])))
0 commit comments