Skip to content

Commit f17f2bf

Browse files
committed
Enable FloatTensor <-> CUDA HalfTensor checks in test_cuda.py
1 parent 2bd779c commit f17f2bf

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

test/test_cuda.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,12 +384,18 @@ def get_cycles_per_ms():
384384
return _cycles_per_ms
385385

386386

387-
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
387+
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5, force_gpu_half=False):
388388
def tmp(self):
389389
cpu_tensor = tensor_constructor(t)
390-
gpu_tensor = to_gpu(cpu_tensor)
390+
type_map = {}
391+
if force_gpu_half:
392+
type_map = {
393+
'torch.FloatTensor': 'torch.cuda.HalfTensor',
394+
'torch.DoubleTensor': 'torch.cuda.HalfTensor',
395+
}
396+
gpu_tensor = to_gpu(cpu_tensor, type_map)
391397
cpu_args = arg_constructor(t)
392-
gpu_args = [to_gpu(arg) for arg in cpu_args]
398+
gpu_args = [to_gpu(arg, type_map) for arg in cpu_args]
393399
cpu_result = getattr(cpu_tensor, fn)(*cpu_args)
394400
try:
395401
gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
@@ -1099,7 +1105,15 @@ def test_nvtx(self):
10991105
test_name += '_' + desc
11001106

11011107
assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name
1102-
setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
1108+
setattr(TestCuda,
1109+
test_name,
1110+
compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
1111+
if t == torch.FloatTensor:
1112+
assert not hasattr(TestCuda, test_name + '_gpu_half'), "Duplicated test name: " + test_name
1113+
setattr(TestCuda,
1114+
test_name + '_gpu_half',
1115+
compare_cpu_gpu(constr, arg_constr, name_inner, t,
1116+
precision, force_gpu_half=True))
11031117

11041118

11051119
if __name__ == '__main__':

0 commit comments

Comments
 (0)