Skip to content

Commit 85e22b5

Browse files
authored
Reverts force_gpu_half changes from #3660 (#5000)
The test_cuda.py setup purports to test half tensors, but actually just re-tests FloatTensors because the keys in type_map were str instead of type. Testing HalfTensors is more complicated, requiring changes to precision and requires excluding some unimplemented methods. We should fully test half CUDA tensors. This change just deletes the duplicate tests of FloatTensor.
1 parent 3e85613 commit 85e22b5

File tree

1 file changed

+3
-15
lines changed

1 file changed

+3
-15
lines changed

test/test_cuda.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -406,18 +406,12 @@ def get_cycles_per_ms():
406406
return _cycles_per_ms
407407

408408

409-
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5, force_gpu_half=False):
409+
def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
410410
def tmp(self):
411411
cpu_tensor = tensor_constructor(t)
412-
type_map = {}
413-
if force_gpu_half:
414-
type_map = {
415-
'torch.FloatTensor': 'torch.cuda.HalfTensor',
416-
'torch.DoubleTensor': 'torch.cuda.HalfTensor',
417-
}
418-
gpu_tensor = to_gpu(cpu_tensor, type_map)
412+
gpu_tensor = to_gpu(cpu_tensor)
419413
cpu_args = arg_constructor(t)
420-
gpu_args = [to_gpu(arg, type_map) for arg in cpu_args]
414+
gpu_args = [to_gpu(arg) for arg in cpu_args]
421415
cpu_result = getattr(cpu_tensor, fn)(*cpu_args)
422416
try:
423417
gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
@@ -1407,12 +1401,6 @@ def test_nvtx(self):
14071401
setattr(TestCuda,
14081402
test_name,
14091403
compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
1410-
if t == torch.FloatTensor and not IS_WINDOWS: # CUDA HalfTensor currently doesn't work on Windows
1411-
assert not hasattr(TestCuda, test_name + '_gpu_half'), "Duplicated test name: " + test_name
1412-
setattr(TestCuda,
1413-
test_name + '_gpu_half',
1414-
compare_cpu_gpu(constr, arg_constr, name_inner, t,
1415-
precision, force_gpu_half=True))
14161404

14171405

14181406
if __name__ == '__main__':

0 commit comments

Comments
 (0)