Skip to content

Commit dac5e65

Browse files
zou3519ezyang
authored andcommitted
Better error messages for blas ops with cuda.LongTensor (#4160)
* Better error messages for blas ops with cuda.LongTensor Fixes #4157 Test plan Try matrix multiplying with cuda.LongTensors >>> import torch >>> x = torch.randn(4, 4).long().cuda() >>> y = torch.randn(4, 4).long().cuda() >>> x.mm(y) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: addmm for CUDA tensors only supports floating-point types. Try converting the tensors with .flo at() at /private/home/rzou/pytorch/pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:381
1 parent 16b7f3a commit dac5e65

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

aten/src/THC/generic/THCTensorMathBlas.cu

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
#define THC_GENERIC_FILE "generic/THCTensorMathBlas.cu"
33
#else
44

5+
#define ERROR_ONLY_FP_TYPES(func) \
6+
THError("%s for CUDA tensors only supports floating-point types. Try converting the tensors with .float()", func);
7+
58
THC_API accreal
69
THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
710
{
@@ -36,7 +39,7 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
3639
return result;
3740

3841
#else
39-
THError("unimplemented data type");
42+
ERROR_ONLY_FP_TYPES("dot");
4043
return ScalarConvert<int, accreal>::to(0);
4144
#endif
4245
}
@@ -128,7 +131,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
128131
THCTensor_(free)(state, tAsMatrix);
129132
#endif
130133
#else
131-
THError("unimplemented data type");
134+
ERROR_ONLY_FP_TYPES("addmv");
132135
#endif
133136
}
134137

@@ -221,7 +224,7 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a
221224
THCTensor_(free)(state, vec1M);
222225
#endif
223226
#else
224-
THError("unimplemented data type");
227+
ERROR_ONLY_FP_TYPES("addr");
225228
#endif
226229
}
227230

@@ -375,7 +378,7 @@ THCTensor_(addmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
375378
THCTensor_(freeCopyTo)(state, r__, r_);
376379
}
377380
#else
378-
THError("unimplemented data type");
381+
ERROR_ONLY_FP_TYPES("addmm");
379382
#endif
380383
}
381384

@@ -422,7 +425,7 @@ THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
422425
THCTensor_(free)(state, slice1);
423426
THCTensor_(free)(state, slice2);
424427
#else
425-
THError("unimplemented data type");
428+
ERROR_ONLY_FP_TYPES("addbmm");
426429
#endif
427430
}
428431

@@ -657,7 +660,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
657660
}
658661

659662
#else
660-
THError("unimplemented data type");
663+
ERROR_ONLY_FP_TYPES("baddbmm");
661664
#endif
662665
}
663666

@@ -758,7 +761,7 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens
758761
}
759762

760763
#else
761-
THError("unimplemented data type");
764+
THError("btrifact for CUDA tensors is only supported for floats and doubles");
762765
#endif
763766
}
764767

@@ -877,7 +880,7 @@ THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b
877880
}
878881

879882
#else
880-
THError("unimplemented data type");
883+
THError("btrisolve for CUDA tensors is only supported for floats and doubles");
881884
#endif
882885
}
883886

test/test_cuda.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def tmp(self):
402402
gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
403403
except RuntimeError as e:
404404
reason = e.args[0]
405-
if 'unimplemented data type' in reason:
405+
if 'only supports floating-point types' in reason or 'unimplemented data type' in reason:
406406
raise unittest.SkipTest('unimplemented data type')
407407
raise
408408
except AttributeError as e:

0 commit comments

Comments
 (0)