
Commit 9cfa66c

eqy authored and aditew01 committed
[cuBLAS][cuBLASLt] Unify cuBLASLt workspaces with cuBLAS workspaces (#145130)
As `cuBLAS` workspaces are already per-stream, there shouldn't be kernel execution overlap with `cuBLASLt` kernels. This PR reuses `cuBLAS` workspaces for `cuBLASLt`, with the following benefits:

+ caching (`cuBLAS` workspaces were already cached, so now we get that for `cuBLASLt`)
+ "free" workspace size bump for `cuBLASLt`: `cuBLASLt` workspace sizes were previously smaller than those for `cuBLAS` by default, which potentially hurts performance, and we encountered difficulty increasing the size due to downstream OOMs; see also #120925
+ fixes broken behavior with the memtracker; #139442 attempted to handle the peaky allocation behavior that broke memtracker equivalence tests but didn't seem to fully work, whereas the cached/reused `cuBLAS` workspace here appears to fix it
+ one environment variable to rule them all: `CUBLAS_WORKSPACE_CONFIG` applies directly to `cuBLASLt`, without a confusing `CUBLASLT_WORKSPACE_SIZE` that users would also need to consider

Pull Request resolved: #145130
Approved by: https://github.com/ngimel
1 parent 0409f1b commit 9cfa66c
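The mechanism is easiest to see in isolation: one workspace is cached per (cuBLAS handle, CUDA stream) pair and handed to both cuBLAS and cuBLASLt calls issued with that pair. The sketch below is a minimal stand-alone illustration of that pattern, not the PyTorch code; `Workspace`, `Key`, and `getWorkspaceFor` are made-up names, and plain `malloc` stands in for the CUDA caching allocator used in the diff below.

// Minimal sketch (illustrative names, host memory instead of device memory):
// cache one workspace per (handle, stream) pair and reuse it for every
// cuBLAS/cuBLASLt call issued with that pair.
#include <cstdlib>
#include <map>
#include <tuple>

struct Workspace {             // hypothetical stand-in for at::DataPtr
  void* ptr = nullptr;
  size_t size = 0;
};

using Key = std::tuple<void* /*handle*/, void* /*stream*/>;

std::map<Key, Workspace>& handle_stream_to_workspace() {
  static auto& instance = *new std::map<Key, Workspace>;  // leaked singleton, same trick as the diff
  return instance;
}

// Return the cached workspace for this (handle, stream) pair, creating it on first use.
Workspace& getWorkspaceFor(void* handle, void* stream, size_t bytes) {
  auto key = std::make_tuple(handle, stream);
  auto it = handle_stream_to_workspace().find(key);
  if (it == handle_stream_to_workspace().end()) {
    Workspace ws;
    ws.ptr = std::malloc(bytes);  // PyTorch allocates through the CUDA caching allocator instead
    ws.size = bytes;
    it = handle_stream_to_workspace().emplace(key, ws).first;
  }
  return it->second;
}

Because repeat calls return the same entry, cuBLASLt inherits both the caching and the (typically larger) size that cuBLAS workspaces already had.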

File tree

4 files changed, +58 -25 lines changed


aten/src/ATen/cuda/CUDABlas.cpp

Lines changed: 41 additions & 20 deletions
@@ -3,6 +3,7 @@
 */

 #include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContextLight.h>
 #include <ATen/cuda/CUDABlas.h>
 #include <ATen/cuda/Exceptions.h>
 #include <ATen/cuda/CUDADataType.h>
@@ -214,6 +215,30 @@ static size_t _getWorkspaceSize() {
   return workspace_size;
 }

+void* _getWorkspaceWithoutHandle() {
+  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  auto stream = c10::cuda::getCurrentCUDAStream();
+  cudaStream_t _stream = stream;
+  auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
+  auto workspace_it = at::cuda::cublas_handle_stream_to_workspace().find(key);
+  TORCH_CHECK(workspace_it != at::cuda::cublas_handle_stream_to_workspace().end());
+  return workspace_it->second.mutable_get();
+}
+
+void* _getWorkspace(size_t& workspaceSize) {
+#if (defined(USE_ROCM) || defined(FBCODE_CAFFE2))
+  workspaceSize = _getWorkspaceSize();
+  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+  auto workspace = allocator.allocate(workspaceSize);
+  auto workspace_ptr = workspace.mutable_get();
+  TORCH_CHECK(workspace_ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
+#else
+  workspaceSize = at::cuda::getChosenWorkspaceSize();
+  auto workspace_ptr = _getWorkspaceWithoutHandle();
+#endif
+  return workspace_ptr;
+}
+
 } // anonymous namespace

 namespace at::cuda::blas {
@@ -395,9 +420,8 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   }

   CuBlasLtMatmulPreference preference;
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  size_t workspaceSize = _getWorkspaceSize();
+  size_t workspaceSize = 0;
+  auto workspace_ptr = _getWorkspace(workspaceSize);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

 #ifndef USE_ROCM
@@ -409,8 +433,6 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment);
 #endif

-  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
-
 cublasLtMatmulHeuristicResult_t heuristicResult = {};
 int returnedResult = 0;
 TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@@ -442,7 +464,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
       c,
       Cdesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_data_ptr(),
+      workspace_ptr,
       workspaceSize,
       at::cuda::getCurrentCUDAStream());
   TORCH_CHECK(
@@ -1328,9 +1350,8 @@ void gemm_and_bias(
   CuBlasLtMatrixLayout Cdesc(abcType, m, n, result_ld);

   CuBlasLtMatmulPreference preference;
-  // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
-  // setting this to 1M.
-  size_t workspaceSize = _getWorkspaceSize();
+  size_t workspaceSize = 0;
+  auto workspace_ptr = _getWorkspace(workspaceSize);
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);

 #ifndef USE_ROCM
@@ -1344,8 +1365,7 @@ void gemm_and_bias(
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, d_alignment);
 #endif

-  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
-
+  auto stream = c10::cuda::getCurrentCUDAStream();
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
   cublasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
@@ -1378,9 +1398,9 @@ void gemm_and_bias(
       result_ptr,
       Cdesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_data_ptr(),
+      workspace_ptr,
       workspaceSize,
-      at::cuda::getCurrentCUDAStream());
+      stream);
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1539,9 +1559,10 @@ void scaled_gemm(
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_EPILOGUE_BIAS);
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
   }
-  size_t workspaceSize = _getWorkspaceSize();
-  auto workspace = at::empty(static_cast<int64_t>(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA));

+  auto stream = c10::cuda::getCurrentCUDAStream();
+  size_t workspaceSize = 0;
+  auto workspace_ptr = _getWorkspace(workspaceSize);
   CuBlasLtMatmulPreference preference;
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
@@ -1624,9 +1645,9 @@ void scaled_gemm(
       result_ptr,
       Ddesc.descriptor(),
       &heuristicResult.algo,
-      workspace.mutable_data_ptr(),
+      workspace_ptr,
       workspaceSize,
-      at::cuda::getCurrentCUDAStream());
+      stream);
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1695,8 +1716,8 @@ void int8_gemm(
   CuBlasLtMatmulPreference preference;
   size_t workspaceSize = _getWorkspaceSize();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
-  auto workspace = at::empty(workspaceSize, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
-
+  auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
+  auto workspace = allocator.allocate(workspaceSize);
   cublasLtMatmulHeuristicResult_t heuristicResult = {};
   int returnedResult = 0;
   TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic(
@@ -1734,7 +1755,7 @@ void int8_gemm(
       nullptr, // Heuristics don't seem to work for int8
 #endif
 #ifdef USE_ROCM
-      workspace.mutable_data_ptr(),
+      workspace.mutable_get(),
 #else
       nullptr, // Non-zero workspace doesn't seem to work.
 #endif
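For context, the non-ROCm branch of `_getWorkspace` above defers to `at::cuda::getChosenWorkspaceSize()`, which is how `CUBLAS_WORKSPACE_CONFIG` ends up governing cuBLASLt as well. The parser below is only a rough sketch of what a `:SIZE:COUNT` config string (SIZE in KiB, per the cuBLAS documentation) amounts to; summing SIZE*COUNT over the pairs and the 32 MiB fallback are assumptions for illustration, not the actual PyTorch logic.

// Illustrative only: turn a string such as ":4096:8" (8 chunks of 4096 KiB)
// into a byte count. Not the at::cuda::getChosenWorkspaceSize() implementation.
#include <cstdio>
#include <cstdlib>

size_t chosenWorkspaceBytes() {
  const char* cfg = std::getenv("CUBLAS_WORKSPACE_CONFIG");
  if (cfg == nullptr) {
    return 4096ull * 1024 * 8;  // assumed 32 MiB fallback, for illustration
  }
  size_t total_kib = 0;
  long size_kib = 0;
  long count = 0;
  int consumed = 0;
  // Accept repeated ":SIZE:COUNT" pairs and sum their products.
  while (std::sscanf(cfg, ":%ld:%ld%n", &size_kib, &count, &consumed) == 2) {
    total_kib += static_cast<size_t>(size_kib) * static_cast<size_t>(count);
    cfg += consumed;
  }
  return total_kib * 1024;
}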

aten/src/ATen/cuda/CUDAContextLight.h

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,7 @@
 // Light-weight version of CUDAContext.h with fewer transitive includes

 #include <cstdint>
+#include <map>

 #include <cuda_runtime_api.h>
 #include <cusparse.h>
@@ -87,6 +88,8 @@ TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
 TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();

 TORCH_CUDA_CPP_API void clearCublasWorkspaces();
+TORCH_CUDA_CPP_API std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace();
+TORCH_CUDA_CPP_API size_t getChosenWorkspaceSize();

 #if defined(CUDART_VERSION) || defined(USE_ROCM)
 TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();

aten/src/ATen/cuda/CublasHandlePool.cpp

Lines changed: 5 additions & 5 deletions
@@ -83,11 +83,6 @@ static hipblasStatus_t hipblasSetWorkspace_replacement(hipblasHandle_t handle, v

 #endif

-std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
-  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
-  return instance;
-}
-
 void createCublasHandle(cublasHandle_t *handle) {
   TORCH_CUDABLAS_CHECK(cublasCreate(handle));
 }
@@ -109,6 +104,11 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle

 } // namespace

+std::map<std::tuple<void *, void *>, at::DataPtr>& cublas_handle_stream_to_workspace() {
+  static auto& instance = *new std::map<std::tuple<void *, void *>, at::DataPtr>;
+  return instance;
+}
+
 void clearCublasWorkspaces() {
   cublas_handle_stream_to_workspace().clear();
 }
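A note on why the definition moved rather than just gaining a declaration: names in an anonymous namespace have internal linkage, so a definition left there could not back the exported `TORCH_CUDA_CPP_API cublas_handle_stream_to_workspace()` declaration added in CUDAContextLight.h, which CUDABlas.cpp now calls from another translation unit. A compilable toy sketch of the distinction (names are made up, not PyTorch code):

#include <map>

namespace {
// Internal linkage: only code in this translation unit can name this function.
std::map<int, int>& internal_only_map() {
  static auto& m = *new std::map<int, int>;
  return m;
}
}  // namespace

// External linkage: this definition can back a declaration in a shared header
// and be called from other .cpp files, which is what the moved accessor needs.
std::map<int, int>& exported_map() {
  static auto& m = *new std::map<int, int>;
  return m;
}

int main() {
  exported_map()[1] = 2;       // other translation units could do the same via the header
  internal_only_map()[3] = 4;  // only reachable from within this file
  return 0;
}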

benchmarks/dynamo/common.py

Lines changed: 9 additions & 0 deletions
@@ -3576,6 +3576,15 @@ def run(runner, args, original_dir=None):
         # some of the models do not support use_deterministic_algorithms
         torch.use_deterministic_algorithms(True)
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+        if args.only is not None and args.only in {
+            "DebertaForQuestionAnswering",
+            "RobertaForQuestionAnswering",
+            "nvidia_deeprecommender",
+            "volo_d1_224",
+        }:
+            # These seem unhappy with numerics of larger cuBLASLt workspace
+            # sizes following #145130 (due to enabling split-k?)
+            torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.allow_tf32 = False
         torch.backends.cudnn.benchmark = False
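The guard above turns off reduced-precision (fp16) reductions for a few models whose accuracy checks drift once the larger workspace lets cuBLASLt pick different kernels, possibly split-K ones that accumulate partial sums, as the comment speculates. As a loose analogy for why accumulator precision matters, the sketch below uses float versus double in place of fp16 versus fp32; it is illustrative only and not tied to the PyTorch code.

// Analogy only: a long reduction carried in a narrow accumulator stalls once
// each increment is no larger than half an ulp of the running sum (ties round
// to even), while a wider accumulator keeps the exact total.
#include <cstdio>

int main() {
  const long n = 1L << 25;     // 33,554,432 additions of 1.0
  float narrow_acc = 0.0f;     // stand-in for a reduced-precision partial-sum accumulator
  double wide_acc = 0.0;       // stand-in for full-precision accumulation
  for (long i = 0; i < n; ++i) {
    narrow_acc += 1.0f;
    wide_acc += 1.0;
  }
  // narrow_acc stops growing at 16777216 (2^24); wide_acc reaches 33554432.
  std::printf("float sum: %.1f  double sum: %.1f\n", narrow_acc, wide_acc);
  return 0;
}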
