23 changes: 12 additions & 11 deletions aten/src/THC/THCBlas.cu
@@ -49,7 +49,7 @@ double THCudaBlas_Ddot(THCState *state, int64_t n, double *x, int64_t incx, double *y, int64_t incy)
}

#ifdef CUDA_HALF_TENSOR
-float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy)
+half THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy)
{
#if CUDA_VERSION >= 8000
if (n == 1) {
@@ -58,22 +58,23 @@ float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy)
}

if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
-    int i_n = (int)n;
-    int i_incx = (int)incx;
-    int i_incy = (int)incy;
-    float result;
+    half result;
cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
cublasSetStream(handle, THCState_getCurrentStream(state));
-    THCublasCheck(cublasDotEx(handle, i_n, x, CUDA_R_16F, i_incx, y, CUDA_R_16F, i_incy, &result, CUDA_R_32F, CUDA_R_32F));
+    THCublasCheck(cublasDotEx(handle, n,
+                              x, CUDA_R_16F, incx,
+                              y, CUDA_R_16F, incy,
+                              &result, CUDA_R_16F,
+                              CUDA_R_32F));
return result;
}
}

THError("Cublas_Hdot only supports n, incx and incy "
"up to signed integer limits: %d", INT_MAX);
-  return 0;
+  return THC_float2half(0);
#else
THError("Cublas_Hdot requires CUDA 8.0+");
-  return 0;
+  return THC_float2half(0);
#endif
}
#endif
@@ -360,7 +361,7 @@ void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount)
{
-  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
-
+  if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
+
  {
THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX);
@@ -420,7 +421,7 @@ void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
cublasSetStream(handle, THCState_getCurrentStream(state));
THCublasCheck(cublasDgemmStridedBatched(handle,
opa, opb, (int)m, (int)n, (int)k,
&alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
&alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
(int)batchCount));
}
#endif
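The noteworthy detail in the new cublasDotEx call is the type configuration: the operands and the result are fp16 (CUDA_R_16F), while the accumulation still runs in fp32 (CUDA_R_32F), so the dot product keeps fp32 precision internally and only the final value is narrowed to half. Below is a minimal standalone sketch of that call pattern, assuming CUDA 8.0+; it illustrates the cuBLAS API, it is not THC code, and error handling is reduced to a single status check.

#include <cstdio>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

int main() {
  const int n = 4;
  __half hx[n], hy[n];
  for (int i = 0; i < n; ++i) {
    hx[i] = __float2half(float(i + 1));  // x = [1, 2, 3, 4]
    hy[i] = __float2half(1.0f);          // y = ones, so x . y = 10
  }

  __half *dx, *dy;
  cudaMalloc(&dx, n * sizeof(__half));
  cudaMalloc(&dy, n * sizeof(__half));
  cudaMemcpy(dx, hx, n * sizeof(__half), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, hy, n * sizeof(__half), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);

  __half result;  // fp16 result, as in the patched THCudaBlas_Hdot
  cublasStatus_t status = cublasDotEx(handle, n,
                                      dx, CUDA_R_16F, 1,
                                      dy, CUDA_R_16F, 1,
                                      &result, CUDA_R_16F,  // store the result as fp16
                                      CUDA_R_32F);          // but accumulate in fp32
  if (status != CUBLAS_STATUS_SUCCESS) {
    printf("cublasDotEx failed: %d\n", (int)status);
  } else {
    printf("dot = %f\n", __half2float(result));  // prints 10.000000
  }

  cublasDestroy(handle);
  cudaFree(dx);
  cudaFree(dy);
  return 0;
}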
4 changes: 2 additions & 2 deletions aten/src/THC/THCBlas.h
@@ -8,7 +8,7 @@
THC_API float THCudaBlas_Sdot(THCState *state, int64_t n, float *x, int64_t incx, float *y, int64_t incy);
THC_API double THCudaBlas_Ddot(THCState *state, int64_t n, double *x, int64_t incx, double *y, int64_t incy);
#ifdef CUDA_HALF_TENSOR
-THC_API float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy);
+THC_API half THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy);
#endif

/* Level 2 */
@@ -36,7 +36,7 @@ THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB,
float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount);
THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
-    double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
+    double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount);
#endif
/* Inverse */
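Since the declaration now returns half instead of float, callers lose the implicit widening and must convert explicitly. A hypothetical call site (the variables state, x, y, and n are assumed to be in scope) using THC's THC_half2float helper:

// Hypothetical caller: widen the fp16 result to fp32 explicitly.
half h = THCudaBlas_Hdot(state, n, x, 1, y, 1);
float f = THC_half2float(h);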
19 changes: 10 additions & 9 deletions aten/src/THC/generic/THCTensorMathBlas.cu
@@ -24,10 +24,11 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
THCTensor_(data)(state, self), 1,
THCTensor_(data)(state, src), 1);
#elif defined(THC_REAL_IS_HALF)
-  accreal result = THCudaBlas_Hdot(state,
+  accreal result = ScalarConvert<half, accreal>::to(
+                     THCudaBlas_Hdot(state,
                        THCTensor_(nElement)(state, self),
                        THCTensor_(data)(state, self), 1,
-                       THCTensor_(data)(state, src), 1);
+                       THCTensor_(data)(state, src), 1));
#endif

THCTensor_(free)(state, src);
@@ -492,14 +493,14 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
ldc = result_->stride[2];
}

if (batch1->stride[transpose_result ? 2 : 1] == 1 &&
if (batch1->stride[transpose_result ? 2 : 1] == 1 &&
batch1->stride[transpose_result ? 1 : 2] != 0)
{
transpose_batch1 = 'n';
batch1_ = batch1;
lda = batch1_->stride[transpose_result ? 1 : 2];
}
-  else if (batch1->stride[transpose_result ? 1 : 2] == 1 &&
+  else if (batch1->stride[transpose_result ? 1 : 2] == 1 &&
batch1->stride[transpose_result ? 2 : 1] != 0)
{
transpose_batch1 = 't';
@@ -513,7 +514,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
lda = batch1_->stride[1];
}

if (batch2->stride[transpose_result ? 2 : 1] == 1 &&
if (batch2->stride[transpose_result ? 2 : 1] == 1 &&
batch2->stride[transpose_result ? 1 : 2] != 0)
{
transpose_batch2 = 'n';
@@ -537,7 +538,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,

#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
// Compute pointers to matrices in each batch.
-#if CUDA_VERSION < 8000
+#if CUDA_VERSION < 8000
size_t matrices_size = num_batches * sizeof(real*);

// Copy pointers to device.
@@ -592,7 +593,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
THCudaFree(state, d_matrices1);
THCudaFree(state, d_matrices2);
THCudaFree(state, d_result_matrices);

#else
#ifdef THC_REAL_IS_FLOAT
THCudaBlas_SgemmStridedBatched(
@@ -606,7 +607,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
beta,
-      THCTensor_(data)(state, result_), ldc, result_->stride[0],
+      THCTensor_(data)(state, result_), ldc, result_->stride[0],
num_batches);
#elif defined(THC_REAL_IS_DOUBLE)
THCudaBlas_DgemmStridedBatched(
@@ -620,7 +621,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
beta,
-      THCTensor_(data)(state, result_), ldc, result_->stride[0],
+      THCTensor_(data)(state, result_), ldc, result_->stride[0],
num_batches);
#endif
#endif
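Here the half returned by THCudaBlas_Hdot is widened through ScalarConvert<half, accreal>::to; when THC_REAL_IS_HALF is defined, accreal is float. A rough sketch of the specialization this relies on, for illustration only (the real template lives in THCNumerics.cuh and differs in detail):

#include <cuda_fp16.h>

// Illustrative sketch, not THC's actual ScalarConvert.
template <typename In, typename Out>
struct ScalarConvertSketch {
  static __host__ __device__ Out to(const In v) { return (Out)v; }
};

// fp16 -> fp32 goes through the intrinsic rather than a plain cast.
template <>
struct ScalarConvertSketch<half, float> {
  static __host__ __device__ float to(const half v) { return __half2float(v); }
};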
22 changes: 18 additions & 4 deletions test/test_cuda.py
@@ -384,12 +384,18 @@ def get_cycles_per_ms():
return _cycles_per_ms


-def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
+def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5, force_gpu_half=False):
     def tmp(self):
         cpu_tensor = tensor_constructor(t)
-        gpu_tensor = to_gpu(cpu_tensor)
+        type_map = {}
+        if force_gpu_half:
+            type_map = {
+                'torch.FloatTensor': 'torch.cuda.HalfTensor',
+                'torch.DoubleTensor': 'torch.cuda.HalfTensor',
+            }
+        gpu_tensor = to_gpu(cpu_tensor, type_map)
         cpu_args = arg_constructor(t)
-        gpu_args = [to_gpu(arg) for arg in cpu_args]
+        gpu_args = [to_gpu(arg, type_map) for arg in cpu_args]
         cpu_result = getattr(cpu_tensor, fn)(*cpu_args)
         try:
             gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
@@ -1099,7 +1105,15 @@ def test_nvtx(self):
        test_name += '_' + desc

    assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name
-   setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
+   setattr(TestCuda,
+           test_name,
+           compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
+   if t == torch.FloatTensor:
+       assert not hasattr(TestCuda, test_name + '_gpu_half'), "Duplicated test name: " + test_name
+       setattr(TestCuda,
+               test_name + '_gpu_half',
+               compare_cpu_gpu(constr, arg_constr, name_inner, t,
+                               precision, force_gpu_half=True))


if __name__ == '__main__':