Conversation

@eqy
Collaborator

@eqy eqy commented Nov 1, 2024

CC @zdevito @janeyx99

This isn't ideal, but cuBLASLt workspaces are not currently cached, so this additional untracked allocation causes `test_cuda_tracker_equivalence` to fail with a large enough workspace size, e.g. `CUBLASLT_WORKSPACE_SIZE=32768`. One solution is to simply use byte tensors for the workspace instead of going directly to the caching allocator.
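
For illustration, a minimal sketch of the two allocation paths (not the exact diff in this PR; `workspace_size` is a hypothetical placeholder for the configured workspace size in bytes):

#include <ATen/ATen.h>
#include <c10/cuda/CUDACachingAllocator.h>

// Previous approach (sketch): raw memory straight from the caching allocator.
// This block is invisible to tensor-level memory trackers.
c10::DataPtr workspace_via_allocator(size_t workspace_size) {
  return c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
}

// Approach in this PR, roughly: allocate a byte tensor instead, so the
// workspace shows up as an ordinary tracked allocation.
at::Tensor workspace_via_byte_tensor(int64_t workspace_size) {
  return at::empty({workspace_size},
                   at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
}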

cc @ptrblck @msaroufim @csarofeen @xwang233

@eqy eqy added the module: cuda, module: cublas, and topic: not user facing labels Nov 1, 2024
@eqy eqy requested a review from syed-ahmed as a code owner November 1, 2024 00:01
@pytorch-bot

pytorch-bot bot commented Nov 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/139442

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 104eb92 with merge base d0fd42e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy
Collaborator Author

eqy commented Nov 1, 2024

CC @nWEIdia @tinglvv as we discussed this

@Aidyn-A, as the 32 MiB default workspace size for H100 is relevant here

Collaborator

@Aidyn-A Aidyn-A left a comment

Looks good to me. QQ: does at::empty have a nullptr check? If yes, this:

TORCH_CHECK(workspace.data_ptr() != nullptr, "OOM trying to allocate workspace for cublaslt");

would be redundant.
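
For reference, a small hypothetical sketch of the point above, assuming the workspace is created via at::empty as in this PR: to my understanding, the CUDA caching allocator behind at::empty throws c10::OutOfMemoryError on failure rather than handing back a null data pointer, so the explicit check should be unreachable in practice.

#include <ATen/ATen.h>

// Hypothetical sketch (not the PR diff): if the allocation inside at::empty
// fails, an exception is raised before control ever reaches the TORCH_CHECK,
// which is why the nullptr check looks redundant.
at::Tensor allocate_lt_workspace(int64_t workspace_size) {
  auto workspace = at::empty({workspace_size},
                             at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
  TORCH_CHECK(workspace.data_ptr() != nullptr,
              "OOM trying to allocate workspace for cublaslt");
  return workspace;
}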

@janeyx99
Contributor

janeyx99 commented Nov 1, 2024

Does this have performance implications? That is, will the at::empty overhead make calling cuBLAS APIs slower than it was before?

@eqy
Collaborator Author

eqy commented Nov 15, 2024

Does this have performance implications? That is, will the at::empty overhead make calling cuBLAS APIs slower than it was before?

Finally did some microbenchmarking; the overhead seems to be in the range of 0.5 µs per allocation (roughly 0.79 µs for at::empty vs. 0.24 µs for a direct caching-allocator call):

#include <torch/torch.h>
#include <torch/cuda.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <iostream>
#include <chrono>

#define ITER 1000000

int main() {
  auto caching_allocator = c10::cuda::CUDACachingAllocator::get();
  // Warm up the at::empty path (1024 float32 elements = 4 KiB).
  std::cout << "start empty" << std::endl;
  auto t = at::empty({1024}, at::device(at::kCUDA));
  torch::cuda::synchronize();
  std::cout << "empty warmup finished" << std::endl;
  // Time ITER allocations through at::empty; reassigning `t` also frees the
  // previous tensor each iteration.
  auto t0 = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < ITER; i++) {
    t = at::empty({1024}, at::device(at::kCUDA));
  }
  torch::cuda::synchronize();
  auto t1 = std::chrono::high_resolution_clock::now();
  // Warm up the direct caching-allocator path with a matching 4 KiB request.
  std::cout << "start allocate" << std::endl;
  auto ptr = caching_allocator->allocate(1024*4);
  torch::cuda::synchronize();
  std::cout << "allocate warmup finished" << std::endl;
  // Time ITER raw allocations; reassigning the DataPtr frees the previous block.
  auto t2 = std::chrono::high_resolution_clock::now();
  for (int i = 0; i < ITER; i++) {
    ptr = caching_allocator->allocate(1024*4);
  }
  torch::cuda::synchronize();
  auto t3 = std::chrono::high_resolution_clock::now();
  auto empty_time = std::chrono::duration_cast<std::chrono::duration<double>>(t1 - t0).count();
  auto allocate_time = std::chrono::duration_cast<std::chrono::duration<double>>(t3 - t2).count();
  std::cout << "empty time per iter: " << empty_time/ITER << std::endl;
  std::cout << "allocate time per iter: " << allocate_time/ITER << std::endl;
}

Output:

start empty
empty warmup finished
start allocate
allocate warmup finished
empty time per iter: 7.88432e-07
allocate time per iter: 2.3696e-07

@eqy eqy added the ciflow/trunk label Nov 18, 2024
@eqy
Collaborator Author

eqy commented Nov 18, 2024

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased memtrackerlt onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout memtrackerlt && git pull --rebase)

@eqy
Collaborator Author

eqy commented Nov 25, 2024

@pytorchmergebot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased memtrackerlt onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout memtrackerlt && git pull --rebase)

@eqy
Collaborator Author

eqy commented Nov 26, 2024

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge failed

Reason: Approvers from one of the following sets are needed:

  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10, ...)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet, ...)
Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

Collaborator

@albanD albanD left a comment

SGTM

@janeyx99
Contributor

janeyx99 commented Dec 6, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

pytorch-bot bot pushed a commit that referenced this pull request Dec 9, 2024
…tensors rather than going to the caching allocator directly (#139442)

CC @zdevito @janeyx99

This isn't ideal, but cuBLASLt workspaces are not currently cached, so this additional untracked allocation causes `test_cuda_tracker_equivalence` to fail with a large enough workspace size, e.g. `CUBLASLT_WORKSPACE_SIZE=32768`. One solution is to simply use byte tensors for the workspace instead of going directly to the caching allocator.

Pull Request resolved: #139442
Approved by: https://github.com/Aidyn-A, https://github.com/albanD, https://github.com/janeyx99
pytorchmergebot pushed a commit that referenced this pull request Feb 6, 2025
…es (#145130)

As `cuBLAS` workspaces are already per-stream, there shouldn't be kernel execution overlap with `cuBLASLt` kernels.

This PR reuses `cuBLAS` workspaces for `cuBLASLt` for the following benefits:

+ caching (`cuBLAS` workspaces were already cached, so now we get that for `cuBLASLt`)
+ "free" workspace size bump for `cuBLASLt` `cuBLASLt` workspace sizes were previously smaller than those for `cuBLAS` by default which potentially hurts performance, and we encountered difficulty in increasing the size due to downstream OOMs , see also #120925
+ fixes behavior broken behavior with the memtracker; #139442 attempted to handle peaky allocation behavior that broke memtracker equivalence tests but it didn't seem to fully work, here the cached/reused `cuBLAS` workspace seems to fix it
+ one environment variable to rule them all: `CUBLAS_WORKSPACE_CONFIG` applies directly to `cuBLASLt` without a confusing `CUBLASLT_WORKSPACE_SIZE` that users would also need to consider

Pull Request resolved: #145130
Approved by: https://github.com/ngimel
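
As a purely illustrative sketch of the consolidated configuration, assuming the documented `:SIZE_IN_KiB:COUNT` format of `CUBLAS_WORKSPACE_CONFIG` (the matmul shapes are arbitrary):

#include <cstdlib>
#include <torch/torch.h>

int main() {
  // Illustrative only: ":4096:8" requests eight 4096 KiB workspace chunks.
  // The variable must be set before the first cuBLAS/cuBLASLt call in the
  // process; after #145130 the same setting is expected to cover the
  // cuBLASLt workspace as well.
  setenv("CUBLAS_WORKSPACE_CONFIG", ":4096:8", /*overwrite=*/1);

  auto a = torch::randn({1024, 1024}, torch::kCUDA);
  auto b = torch::randn({1024, 1024}, torch::kCUDA);
  auto c = torch::mm(a, b);  // the first GEMM picks up the workspace config
  torch::cuda::synchronize();
  return 0;
}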
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Feb 7, 2025
X-link: pytorch/pytorch#145130
Approved by: https://github.com/ngimel

Reviewed By: atalman

Differential Revision: D69257102

fbshipit-source-id: 4a2e6391fa899829758596ab2e2f4b16003e5197
pytorchmergebot pushed a commit that referenced this pull request Feb 23, 2025
…es (#145130)

Pull Request resolved: #145130
Approved by: https://github.com/ngimel
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Feb 24, 2025
X-link: pytorch/pytorch#145130
Approved by: https://github.com/ngimel

Reviewed By: jeanschmidt

Differential Revision: D70075331

fbshipit-source-id: cf4d0d687b299c942793a758c6fec4b064c44227
aditew01 pushed a commit that referenced this pull request Feb 28, 2025
…es (#145130)

Pull Request resolved: #145130
Approved by: https://github.com/ngimel
pytorchmergebot pushed a commit that referenced this pull request Mar 22, 2025
…es (#145130)

Pull Request resolved: #145130
Approved by: https://github.com/ngimel
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Mar 24, 2025
X-link: pytorch/pytorch#145130
Approved by: https://github.com/ngimel

Reviewed By: izaitsevfb

Differential Revision: D71711852

fbshipit-source-id: 4f57539b8f37f1f4c92a57c19276e84f81bffa23
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…es (pytorch#145130)

Pull Request resolved: pytorch#145130
Approved by: https://github.com/ngimel