@@ -3,6 +3,7 @@
  */
 
 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContextLight.h>
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cuda/CUDADataType.h>
@@ -214,6 +215,30 @@ static size_t _getWorkspaceSize() {
   return workspace_size;
 }
 
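+// Looks up the workspace already registered for the current
+// (cuBLAS handle, CUDA stream) pair and returns a raw pointer to it.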
+void* _getWorkspaceWithoutHandle() {
+  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t _stream = stream;
+  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
+  auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
+  TORCH_CHECK(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
+  return workspace_it->second.mutable_get();
+}
+
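+// Returns a workspace pointer for cuBLASLt calls and reports the workspace
+// size through the workspaceSize out-parameter.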
+void* _getWorkspace(size_t& workspaceSize) {
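+  // ROCm and fbcode builds allocate the workspace from the CUDA caching
+  // allocator; other builds reuse the workspace owned by the current
+  // handle/stream pair.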
+#if (defined(USE_ROCM) || defined(FBCODE_CAFFE2))
+  workspaceSize = _getWorkspaceSize();
+  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+  auto workspace = allocator.allocate(workspaceSize);
+  auto workspace_ptr = workspace.mutable_get();
+  TORCH_CHECK(workspace_ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
+#else
+  workspaceSize = at::cuda::getChosenWorkspaceSize();
+  auto workspace_ptr = _getWorkspaceWithoutHandle();
+#endif
+  return workspace_ptr;
+}
+
 } // anonymous namespace
 
 namespace at::cuda::blas {
@@ -395,9 +420,8 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   }
 
   CuBlasLtMatmulPreference preference;
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  size_t workspaceSize = _getWorkspaceSize();
+  size_t workspaceSize = 0;
+  auto workspace_ptr = _getWorkspace(workspaceSize);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
 
 #ifndef USE_ROCM
@@ -409,8 +433,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
 #endif
 
-  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
-
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@@ -442,7 +464,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
       c,
       Cdesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_data_ptr(),
+      workspace_ptr,
       workspaceSize,
       at::cuda::getCurrentCUDAStream());
   TORCH_CHECK(
@@ -1328,9 +1350,8 @@ void gemm_and_bias(
   CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);
 
   CuBlasLtMatmulPreference preference;
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  size_t workspaceSize = _getWorkspaceSize();
+  size_t workspaceSize = 0;
+  auto workspace_ptr = _getWorkspace(workspaceSize);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
 
 #ifndef USE_ROCM
@@ -1344,8 +1365,7 @@ void gemm_and_bias(
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
 #endif
 
-  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
-
+  auto stream = c10::cuda::getCurrentCUDAStream();
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
   cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
@@ -1378,9 +1398,9 @@ void gemm_and_bias(
       result_ptr,
       Cdesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_data_ptr(),
+      workspace_ptr,
       workspaceSize,
-      at::cuda::getCurrentCUDAStream());
+      stream);
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1539,9 +1559,10 @@ void scaled_gemm(
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
   }
-  size_t workspaceSize = _getWorkspaceSize();
-  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
 
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  size_t workspaceSize = 0;
+  auto workspace_ptr = _getWorkspace(workspaceSize);
   CuBlasLtMatmulPreference preference;
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
@@ -1624,9 +1645,9 @@ void scaled_gemm(
       result_ptr,
       Ddesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_data_ptr(),
+      workspace_ptr,
       workspaceSize,
-      at::cuda::getCurrentCUDAStream());
+      stream);
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1695,8 +1716,8 @@ void int8_gemm(
   CuBlasLtMatmulPreference preference;
   size_t workspaceSize = _getWorkspaceSize();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
-  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
-
+  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+  auto workspace = allocator.allocate(workspaceSize);
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@@ -1734,7 +1755,7 @@ void int8_gemm(
       nullptr, // Heuristics don't seem to work for int8
 #endif
 #ifdef USE_ROCM
-      workspace.mutable_data_ptr(),
+      workspace.mutable_get(),
 #else
       nullptr, // Non-zero workspace doesn't seem to work.
 #endif