Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions aten/src/THC/THCBlas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <ATen/cuda/CUDABlas.h>

#include <algorithm>
#include <mutex>

float THCudaBlas_Sdot(THCState *state, int64_t n, float *x, int64_t incx, float *y, int64_t incy)
{
Expand Down Expand Up @@ -189,6 +190,26 @@ void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, i

}

// Emits a one-time warning about the matrix-multiplication bug tracked in
// https://github.com/pytorch/pytorch/issues/22078.
//
// The exact trigger conditions are not known; running Sgemm or Hgemm on a
// Maxwell or Pascal device appears to be a necessary condition. The check is
// only compiled in for CUDA versions in [9.0, 9.2); on any other toolkit this
// function is a no-op.
//
// i_m, i_n, i_k: the GEMM problem dimensions. A warning is raised (at most
// once per process) when any of them exceeds 2^21 and the current device
// reports compute capability major version 5 or 6.
static void checkCuda90Bug(int i_m, int i_n, int i_k)
{
#if CUDA_VERSION < 9200 && CUDA_VERSION >= 9000
  constexpr int kMaxSafeDim = 1 << 21;

  // Nothing to report while every dimension stays within the limit.
  const bool allWithinLimit =
      i_m <= kMaxSafeDim && i_n <= kMaxSafeDim && i_k <= kMaxSafeDim;
  if (allWithinLimit) {
    return;
  }

  // Only query the device once we know a dimension is oversized.
  const cudaDeviceProp* deviceProps = at::cuda::getCurrentDeviceProperties();
  const int computeMajor = deviceProps->major;
  if (computeMajor != 5 && computeMajor != 6) {
    return;
  }

  // Guarantees the warning is emitted a single time, even with concurrent callers.
  static std::once_flag warnedOnce;
  std::call_once(warnedOnce, []() {
    TORCH_WARN("Matrix multiplication for dimensions larger than 2^21 has known bugs on your combination of CUDA version and device type. Please consider upgrading to CUDA 9.2 or later.");
  });
#endif
}

/* Level 3 */
void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc)
{
Expand Down