
Commit 474dec4

Brennan Vincent authored and facebook-github-bot committed
Warn on conditions that can trigger cuBLAS sgemm bug (#22034)
Summary: The sgemm in cuBLAS 9.0 has some issues with sizes above 2M on Maxwell and Pascal architectures. Warn in this case.

Pull Request resolved: #22034
Differential Revision: D15949930
Pulled By: zhangguanheng66
fbshipit-source-id: 0af977ec7900c76328d23898071de9c23778ff8b
1 parent f5b3f9e
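For context on the condition described in the summary: the patch below treats any gemm dimension above 2^21 on a Maxwell (compute capability 5.x) or Pascal (6.x) device under CUDA 9.0/9.1 as suspect. Here is a minimal standalone sketch of that predicate, using the raw CUDA runtime API (cudaGetDeviceProperties) in place of ATen's at::cuda::getCurrentDeviceProperties wrapper; this is an illustration of the check, not code from the commit:

```cpp
#include <cuda_runtime.h>

// Hypothetical helper mirroring the patch's trigger check:
// any gemm dimension above 2^21 on a Maxwell (sm_5x) or
// Pascal (sm_6x) device is the suspected condition.
static bool mayHitCuda90SgemmBug(int m, int n, int k) {
  const int LIMIT = 1 << 21;  // same threshold the patch uses
  if (m <= LIMIT && n <= LIMIT && k <= LIMIT) {
    return false;
  }
  int device = 0;
  cudaDeviceProp prop;
  if (cudaGetDevice(&device) != cudaSuccess ||
      cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return false;  // cannot query the device; skip the warning
  }
  // Compute-capability major version: 5 = Maxwell, 6 = Pascal.
  return prop.major == 5 || prop.major == 6;
}
```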

File tree
1 file changed: +21 −0 lines changed


aten/src/THC/THCBlas.cu

Lines changed: 21 additions & 0 deletions
```diff
@@ -5,6 +5,7 @@
 #include <ATen/cuda/CUDABlas.h>
 
 #include <algorithm>
+#include <mutex>
 
 float THCudaBlas_Sdot(THCState *state, int64_t n, float *x, int64_t incx, float *y, int64_t incy)
 {
@@ -189,6 +190,26 @@ void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k, i
 
 }
 
+// Check https://github.com/pytorch/pytorch/issues/22078
+// for information about the bug. We don't know the exact conditions that trigger it,
+// but using Sgemm or Hgemm on Maxwell or Pascal seems to be a
+// necessary condition.
+static void checkCuda90Bug(int i_m, int i_n, int i_k)
+{
+#if CUDA_VERSION < 9200 && CUDA_VERSION >= 9000
+  static std::once_flag alreadyWarned;
+  const int LIMIT = 1 << 21;
+  if (i_m > LIMIT || i_n > LIMIT || i_k > LIMIT) {
+    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+    if (prop->major == 5 || prop->major == 6) {
+      std::call_once(alreadyWarned, []() {
+        TORCH_WARN("Matrix multiplication for dimensions larger than 2^21 has known bugs on your combination of CUDA version and device type. Please consider upgrading to CUDA 9.2 or later.");
+      });
+    }
+  }
+#endif
+}
+
 /* Level 3 */
 void THCudaBlas_Sgemm(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k, float alpha, float *a, int64_t lda, float *b, int64_t ldb, float beta, float *c, int64_t ldc)
 {
```
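A note on the once-only mechanism: the function-local static std::once_flag plus std::call_once guarantees the TORCH_WARN fires at most once per process, even if many oversized gemm calls race from different threads. (The call sites for checkCuda90Bug are not visible in this hunk; presumably the gemm entry points invoke it.) Below is a minimal standalone sketch of the same pattern, with std::cerr standing in for TORCH_WARN:

```cpp
#include <iostream>
#include <mutex>

static void warnOnce() {
  // Static storage duration: all calls (and all threads) share one flag.
  static std::once_flag alreadyWarned;
  std::call_once(alreadyWarned, []() {
    std::cerr << "warning: emitted once, no matter how often warnOnce() runs\n";
  });
}

int main() {
  for (int i = 0; i < 3; ++i) {
    warnOnce();  // the message prints only on the first call
  }
  return 0;
}
```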
