Conversation

@zhaojuanmao
Contributor

addmm_cuda_lt failed for some corner cases. So far we cannot reproduce these corner cases in the unit tests; it seems the failures do not depend only on the matrices' shapes and strides. For now, add an environment variable that allows users to disable this kernel for such corner cases.

See the case one with more error logs:

RuntimeError: 0CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 80 n 1024 k 160 mat1_ld 160 mat2_ld 160 result_ld 80 abcType 14 computeType 68 scaleType 0 result_shape 1024 80 result_stride 80 1 self_shape 80 self_stride 1 mat1_shape 1024 160 mat1_stride 160 1 mat2_shape 160 80 mat2_stride 1 160
Exception raised from gemm_and_bias at fbcode/caffe2/aten/src/ATen/cuda/CUDABlas.cpp:1071 (most recent call first):

another case with more error logs:

RuntimeError: 0CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 16 n 16384 k 48 mat1_ld 48 mat2_ld 48 result_ld 16 abcType 14 computeType 68 scaleType 0 result_shape 16384 16 result_stride 16 1 self_shape 16 self_stride 1 mat1_shape 16384 48 mat1_stride 48 1 mat2_shape 48 16 mat2_stride 1 48
Exception raised from gemm_and_bias at fbcode/caffe2/aten/src/ATen/cuda/CUDABlas.cpp:1071 (most recent call first):
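
For illustration, a minimal sketch of the kind of environment-variable gate this PR adds; the variable name DISABLE_ADDMM_CUDA_LT and the helper below are illustrative, not the exact code in the diff:

#include <cstdlib>
#include <cstring>

// Returns true when the user has asked to skip the cublasLt addmm path.
static bool disable_addmm_cuda_lt_from_env() {
  const char* env = std::getenv("DISABLE_ADDMM_CUDA_LT");
  return env != nullptr && std::strcmp(env, "1") == 0;
}

// In the addmm CUDA dispatch the flag would be consulted alongside the existing
// shape, stride, and dtype checks; when it is set, addmm falls back to the
// regular gemm-plus-bias path instead of calling cublasLtMatmul.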

@pytorch-bot

pytorch-bot bot commented Dec 28, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91436

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d9e2714:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zhaojuanmao force-pushed the disableaddmmcudalt branch 2 times, most recently from 8b46bdc to d54aa3e on December 28, 2022 19:07
@zhaojuanmao requested a review from ptrblck on December 28, 2022 19:08
@ptrblck
Collaborator

ptrblck commented Dec 28, 2022

addmm_cuda_lt failed for some corner cases. So far we cannot reproduce these corner cases in the unit tests; it seems the failures do not depend only on the matrices' shapes and strides.

This might also mean that cublas is running into a sticky error / corrupt CUDA context and is just the victim.
What are these workloads and did you try to launch them via CUDA_LAUNCH_BLOCKING=1 or in a compute-sanitizer run?

@zhaojuanmao
Contributor Author

addmm_cuda_lt failed for some corner cases. So far we cannot reproduce these corner cases in the unit tests; it seems the failures do not depend only on the matrices' shapes and strides.

This might also mean that cublas is running into a sticky error / corrupt CUDA context and is just the victim. What are these workloads and did you try to launch them via CUDA_LAUNCH_BLOCKING=1 or in a compute-sanitizer run?

Thank you @ptrblck. I hard-coded disabling the 'addmm_cuda_lt' kernel and the training went through, so I think it should be related to the 'addmm_cuda_lt' kernel?

I can run with CUDA_LAUNCH_BLOCKING=1 later on. What does a compute-sanitizer run mean?

@zhaojuanmao
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased disableaddmmcudalt onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout disableaddmmcudalt && git pull --rebase)

@facebook-github-bot
Contributor

@zhaojuanmao has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@ptrblck
Collaborator

ptrblck commented Dec 29, 2022

I can run with CUDA_LAUNCH_BLOCKING=1 later on. What does a compute-sanitizer run mean?

compute-sanitizer would allow you to check for memory violations, race conditions etc.
Could you post the workload which creates the issue, so that I could also try to reproduce and debug it?
cublas could certainly fail, but I would like to get more information about the failure case while this workaround is used.

@zhaojuanmao
Contributor Author

The dynamo unit test failures are not related.

@zhaojuanmao
Contributor Author

I can run with CUDA_LAUNCH_BLOCKING=1 later on. What does a compute-sanitizer run mean?

compute-sanitizer would allow you to check for memory violations, race conditions etc. Could you post the workload which creates the issue, so that I could also try to reproduce and debug it? cublas could certainly fail, but I would like to get more information about the failure case while this workaround is used.

It is an internal workload that is hard to rewrite in OSS, though.

What I can do is get extra error logs. So far we can reproduce 3 failure cases for this workload; I added an error log inside CUDABlas.cpp that is printed when the error is hit:

Failure case 1

RuntimeError: 0CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 80 n 1024 k 160 mat1_ld 160 mat2_ld 160 result_ld 80 abcType 14 computeType 68 scaleType 0 result_shape 1024 80 result_stride 80 1 self_shape 80 self_stride 1 mat1_shape 1024 160 mat1_stride 160 1 mat2_shape 160 80 mat2_stride 1 160
Exception raised from gemm_and_bias at fbcode/caffe2/aten/src/ATen/cuda/CUDABlas.cpp:1071 (most recent call first):

Failure case 2:

RuntimeError: 0CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 16 n 16384 k 48 mat1_ld 48 mat2_ld 48 result_ld 16 abcType 14 computeType 68 scaleType 0 result_shape 16384 16 result_stride 16 1 self_shape 16 self_stride 1 mat1_shape 16384 48 mat1_stride 48 1 mat2_shape 48 16 mat2_stride 1 48
Exception raised from gemm_and_bias at fbcode/caffe2/aten/src/ATen/cuda/CUDABlas.cpp:1071 (most recent call first):

Failure case 3:

RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul with transpose_mat1 1 transpose_mat2 0 m 16 n 327680 k 16 mat1_ld 16 mat2_ld 16 result_ld 16 abcType 14 computeType 68 scaleType 0 result_shape 327680 16 result_stride 16 1 result continuguous 1 self_shape 16 self_stride 1 self continuguous 1 mat1_shape 327680 16 mat1_stride 16 1 mat1 continuguous 1 mat2_shape 16 16 mat2_stride 1 16 mat2 continuguous 0

Exception raised from gemm_and_bias at fbcode/caffe2/aten/src/ATen/cuda/CUDABlas.cpp:849

@ptrblck do you see anything in common across the above 3 failure cases, or any other extra debugging info I can add to help root-cause it further?

Thanks for your help!

@zhaojuanmao
Contributor Author

@ptrblck by the way, the extra error log I added looks like this, placed right after the 'cublasStatus_t cublasStatus = cublasLtMatmul(...)' call:

// Fail with the full cublasLtMatmul problem configuration when the call did not succeed.
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
" when calling cublasLtMatmul with transpose_mat1 ",
transpose_mat1,
" transpose_mat2 ",
transpose_mat2,
" m ",
m,
" n ",
n,
" k ",
k,
" mat1_ld ",
mat1_ld,
" mat2_ld ",
mat2_ld,
" result_ld ",
result_ld,
" abcType ",
abcType,
" computeType ",
computeType,
" scaleType ",
scaleType,
" result_shape ",
result_shape.str(),
" result_stride ",
result_stride.str(),
" result continuguous ",
result.is_contiguous(),
" self_shape ",
self_shape.str(),
" self_stride ",
self_stride.str(),
" self continuguous ",
self.is_contiguous(),
" mat1_shape ",
mat1_shape.str(),
" mat1_stride ",
mat1_stride.str(),
" mat1 continuguous ",
mat1.is_contiguous(),
" mat2_shape ",
mat2_shape.str(),
" mat2_stride ",
mat2_stride.str(),
" mat2 continuguous ",
mat2.is_contiguous());

@facebook-github-bot
Contributor

@zhaojuanmao has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@zhaojuanmao
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased disableaddmmcudalt onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout disableaddmmcudalt && git pull --rebase)

@facebook-github-bot
Contributor

@zhaojuanmao has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Collaborator

you should also avoid strcmp on the hot path, do it once.
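
A minimal sketch of the suggested "do it once" pattern, caching the result of the environment lookup in a function-local static; the names here are illustrative, not the PR's actual code:

#include <cstdlib>
#include <cstring>

// The lambda runs once, when the function-local static is initialized;
// later calls only read the cached bool, keeping getenv/strcmp off the
// per-addmm hot path.
static bool disable_addmm_cuda_lt_cached() {
  static const bool disabled = []() {
    const char* env = std::getenv("DISABLE_ADDMM_CUDA_LT");
    return env != nullptr && std::strcmp(env, "1") == 0;
  }();
  return disabled;
}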

@ptrblck
Collaborator

ptrblck commented Dec 29, 2022

@zhaojuanmao Thanks for the information about your debugging steps!
I understand it might not be trivial to write a minimal code snippet which reproduces the issue, but any information you could share might help me try to reproduce it (also feel free to ping me on Slack in case you would prefer it).
So far, I would add a sync and a CUDA error check at gemm_and_bias before any cublas call is invoked to check if a sticky error is already reported.
I don't want to hijack this PR for the debugging discussion so let's follow up on Slack or in a new issue (in case you can share any more information about the use case and error).
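
For reference, a minimal sketch of such a pre-call check, written as a hypothetical helper using the existing AT_CUDA_CHECK macro; it is not part of this PR:

#include <cuda_runtime.h>
#include <ATen/cuda/Exceptions.h>  // AT_CUDA_CHECK

// Hypothetical debugging aid: call at the top of gemm_and_bias to surface a
// sticky error left by an earlier kernel before the cublasLt call gets blamed
// for it.
static void assert_no_pending_cuda_error() {
  // Force completion of earlier asynchronous work so its failures show up here.
  AT_CUDA_CHECK(cudaDeviceSynchronize());
  // Throw immediately if the CUDA context already carries an error.
  AT_CUDA_CHECK(cudaGetLastError());
}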

@facebook-github-bot
Contributor

@zhaojuanmao has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@zhaojuanmao
Contributor Author

@zhaojuanmao Thanks for the information about your debugging steps! I understand it might not be trivial to write a minimal code snippet which reproduces the issue, but any information you could share might help me try to reproduce it (also feel free to ping me on Slack in case you would prefer it). So far, I would add a sync and a CUDA error check at gemm_and_bias before any cublas call is invoked to check if a sticky error is already reported. I don't want to hijack this PR for the debugging discussion so let's follow up on Slack or in a new issue (in case you can share any more information about the use case and error).

sounds great, thanks!

@zhaojuanmao
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot successfully started a rebase job. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased disableaddmmcudalt onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout disableaddmmcudalt && git pull --rebase)

@facebook-github-bot
Contributor

@zhaojuanmao has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@zhaojuanmao
Contributor Author

@pytorchbot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Jan 3, 2023
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@github-actions deleted the disableaddmmcudalt branch July 5, 2024 01:53
Labels

ciflow/trunk · Merged · release notes: distributed (fsdp)
