Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions aten/src/THC/generic/THCTensorMathBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#define THC_GENERIC_FILE "generic/THCTensorMathBlas.cu"
#else

#define ERROR_ONLY_FP_TYPES(func) \
THError("%s for CUDA tensors only supports floating-point types. Try converting the tensors with .float()", func);

THC_API accreal
THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
{
Expand Down Expand Up @@ -36,7 +39,7 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
return result;

#else
THError("unimplemented data type");
ERROR_ONLY_FP_TYPES("dot");
return ScalarConvert<int, accreal>::to(0);
#endif
}
Expand Down Expand Up @@ -128,7 +131,7 @@ THCTensor_(addmv)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
THCTensor_(free)(state, tAsMatrix);
#endif
#else
THError("unimplemented data type");
ERROR_ONLY_FP_TYPES("addmv");
#endif
}

Expand Down Expand Up @@ -221,7 +224,7 @@ THCTensor_(addr)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real a
THCTensor_(free)(state, vec1M);
#endif
#else
THError("unimplemented data type");
ERROR_ONLY_FP_TYPES("addr");
#endif
}

Expand Down Expand Up @@ -375,7 +378,7 @@ THCTensor_(addmm)(THCState *state, THCTensor *r_, real beta, THCTensor *t, real
THCTensor_(freeCopyTo)(state, r__, r_);
}
#else
THError("unimplemented data type");
ERROR_ONLY_FP_TYPES("addmm");
#endif
}

Expand Down Expand Up @@ -422,7 +425,7 @@ THCTensor_(addbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
THCTensor_(free)(state, slice1);
THCTensor_(free)(state, slice2);
#else
THError("unimplemented data type");
ERROR_ONLY_FP_TYPES("addbmm");
#endif
}

Expand Down Expand Up @@ -657,7 +660,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
}

#else
THError("unimplemented data type");
ERROR_ONLY_FP_TYPES("baddbmm");
#endif
}

Expand Down Expand Up @@ -758,7 +761,7 @@ THC_API void THCTensor_(btrifact)(THCState *state, THCTensor *ra_, THCudaIntTens
}

#else
THError("unimplemented data type");
THError("btrifact for CUDA tensors is only supported for floats and doubles");
#endif
}

Expand Down Expand Up @@ -877,7 +880,7 @@ THC_API void THCTensor_(btrisolve)(THCState *state, THCTensor *rb_, THCTensor *b
}

#else
THError("unimplemented data type");
THError("btrisolve for CUDA tensors is only supported for floats and doubles");
#endif
}

Expand Down
2 changes: 1 addition & 1 deletion test/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def tmp(self):
gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
except RuntimeError as e:
reason = e.args[0]
if 'unimplemented data type' in reason:
if 'only supports floating-point types' in reason or 'unimplemented data type' in reason:
raise unittest.SkipTest('unimplemented data type')
raise
except AttributeError as e:
Expand Down