Skip to content

Commit 2e1f834

Browse files
committed
Enable FloatTensor <-> CUDA HalfTensor checks in test_cuda.py
1 parent 2bd779c commit 2e1f834

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

test/test_cuda.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,18 @@ def get_cycles_per_ms():
384384
return _cycles_per_ms
385385

386386

387-
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
387+
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5, force_gpu_half=False):
388388
def tmp(self):
389389
cpu_tensor = tensor_constructor(t)
390-
gpu_tensor = to_gpu(cpu_tensor)
390+
type_map = {}
391+
if force_gpu_half:
392+
type_map = {
393+
'torch.FloatTensor': 'torch.cuda.HalfTensor',
394+
'torch.DoubleTensor': 'torch.cuda.HalfTensor',
395+
}
396+
gpu_tensor = to_gpu(cpu_tensor, type_map)
391397
cpu_args = arg_constructor(t)
392-
gpu_args = [to_gpu(arg) for arg in cpu_args]
398+
gpu_args = [to_gpu(arg, type_map) for arg in cpu_args]
393399
cpu_result = getattr(cpu_tensor, fn)(*cpu_args)
394400
try:
395401
gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
@@ -1100,6 +1106,8 @@ def test_nvtx(self):
11001106

11011107
assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name
11021108
setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
1109+
if t == torch.FloatTensor:
1110+
setattr(TestCuda, test_name + '_gpu_half', compare_cpu_gpu(constr, arg_constr, name_inner, t, precision, force_gpu_half=True))
11031111

11041112

11051113
if __name__ == '__main__':

0 commit comments

Comments
 (0)