@@ -384,12 +384,18 @@ def get_cycles_per_ms():
     return _cycles_per_ms


-def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
+def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5, force_gpu_half=False):
     def tmp(self):
         cpu_tensor = tensor_constructor(t)
-        gpu_tensor = to_gpu(cpu_tensor)
+        type_map = {}
+        if force_gpu_half:
+            type_map = {
+                'torch.FloatTensor': 'torch.cuda.HalfTensor',
+                'torch.DoubleTensor': 'torch.cuda.HalfTensor',
+            }
+        gpu_tensor = to_gpu(cpu_tensor, type_map)
         cpu_args = arg_constructor(t)
-        gpu_args = [to_gpu(arg) for arg in cpu_args]
+        gpu_args = [to_gpu(arg, type_map) for arg in cpu_args]
         cpu_result = getattr(cpu_tensor, fn)(*cpu_args)
         try:
             gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
@@ -1099,7 +1105,15 @@ def test_nvtx(self):
         test_name += '_' + desc

     assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name
-    setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
+    setattr(TestCuda,
+            test_name,
+            compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
+    if t == torch.FloatTensor:
+        assert not hasattr(TestCuda, test_name + '_gpu_half'), "Duplicated test name: " + test_name
+        setattr(TestCuda,
+                test_name + '_gpu_half',
+                compare_cpu_gpu(constr, arg_constr, name_inner, t,
+                                precision, force_gpu_half=True))


 if __name__ == '__main__':
0 commit comments