
Conversation

@eqy (Collaborator) commented Jan 18, 2025

As cuBLAS workspaces are already per-stream, there shouldn't be kernel execution overlap with cuBLASLt kernels.

This PR reuses cuBLAS workspaces for cuBLASLt for the following benefits:

+ caching (cuBLAS workspaces were already cached, so now we get that for cuBLASLt)
+ a "free" workspace size bump for cuBLASLt: cuBLASLt workspace sizes were previously smaller than those for cuBLAS by default, which potentially hurts performance, and we encountered difficulty in increasing the size due to downstream OOMs; see also #120925
+ fixes broken behavior with the memtracker; #139442 attempted to handle the peaky allocation behavior that broke memtracker equivalence tests but didn't seem to fully work, whereas the cached/reused cuBLAS workspace here seems to fix it
+ one environment variable to rule them all: CUBLAS_WORKSPACE_CONFIG applies directly to cuBLASLt, without a confusing CUBLASLT_WORKSPACE_SIZE that users would also need to consider

Edit: for now, CUBLASLT_WORKSPACE_SIZE still exists to preserve the previous behavior (we noticed some accuracy differences when automatically enabling a larger workspace for cuBLASLt).
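For context, a minimal sketch (not part of this PR; the example values and their units are assumptions, not recommendations) of how the two workspace knobs mentioned above can be set today, before the first CUDA BLAS call:

```python
# Hedged sketch: illustrative values only; consult the cuBLAS/PyTorch docs for real guidance.
import os

# cuBLAS workspace configuration (":<size>:<count>" chunks, e.g. ":4096:8");
# with this PR the same setting would also govern cuBLASLt workspaces.
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

# Separate cuBLASLt workspace size, kept for now to preserve previous behavior.
os.environ.setdefault("CUBLASLT_WORKSPACE_SIZE", "1024")

import torch  # the variables must be set before the first CUDA BLAS call

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    c = a @ b  # may dispatch to cuBLAS or cuBLASLt depending on heuristics
```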

cc @ptrblck @msaroufim @csarofeen @xwang233 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

@eqy eqy added the module: cuda, module: cublas, and matrix multiplication labels Jan 18, 2025
@eqy eqy requested a review from syed-ahmed as a code owner January 18, 2025 00:33
@pytorch-bot (bot) commented Jan 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145130

Note: Links to docs will display an error until the docs builds have been completed.

❌ 13 New Failures, 5 Unrelated Failures

As of commit 519681e with merge base d072254:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy eqy added the ciflow/trunk, ciflow/periodic, ciflow/inductor, and rocm labels Jan 18, 2025
@eqy eqy changed the title [cuBLAS][cuBLASLt] Unify cuBLASLt workspace with cuBLAS [cuBLAS][cuBLASLt] Unify cuBLASLt workspaces with cuBLAS workspaces Jan 18, 2025
@eqy eqy added the open source, ciflow/rocm, and topic: not user facing labels and removed the rocm label Jan 18, 2025
@pytorch-bot pytorch-bot bot temporarily deployed to upload-benchmark-results January 18, 2025 01:53 Inactive
@eqy (Collaborator, Author) commented Jan 18, 2025

@pytorchmergebot rebase

@pytorchmergebot (Collaborator):
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator):
Merge failed

Reason: 1 job has failed: s390x-periodic / linux-manylinux-2_28-py3-cpu-s390x / build

Details for Dev Infra team (raised by workflow job)

@eqy (Collaborator, Author) commented Mar 22, 2025

@pytorchmergebot merge -i

facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Mar 24, 2025
Summary:
As `cuBLAS` workspaces are already per-stream, there shouldn't be kernel execution overlap with `cuBLASLt` kernels.

This PR reuses `cuBLAS` workspaces for `cuBLASLt` for the following benefits:

+ caching (`cuBLAS` workspaces were already cached, so now we get that for `cuBLASLt`)
+ a "free" workspace size bump for `cuBLASLt`: `cuBLASLt` workspace sizes were previously smaller than those for `cuBLAS` by default, which potentially hurts performance, and we encountered difficulty in increasing the size due to downstream OOMs; see also #120925
+ fixes broken behavior with the memtracker; pytorch/pytorch#139442 attempted to handle the peaky allocation behavior that broke memtracker equivalence tests but didn't seem to fully work, whereas the cached/reused `cuBLAS` workspace here seems to fix it
+ one environment variable to rule them all: `CUBLAS_WORKSPACE_CONFIG` applies directly to `cuBLASLt`, without a confusing `CUBLASLT_WORKSPACE_SIZE` that users would also need to consider

X-link: pytorch/pytorch#145130
Approved by: https://github.com/ngimel

Reviewed By: izaitsevfb

Differential Revision: D71711852

fbshipit-source-id: 4f57539b8f37f1f4c92a57c19276e84f81bffa23
jeffdaily pushed a commit to ROCm/pytorch that referenced this pull request Mar 28, 2025
Follow-up to pytorch#145130. That PR caused a warning to be emitted on ROCm the first time
hipblaslt was called, regardless of workload.
pytorchmergebot pushed a commit to ROCm/pytorch that referenced this pull request Mar 31, 2025
Follow-up to pytorch#145130. That PR caused a warning to be emitted on ROCm the first time
hipblaslt was called, regardless of workload.
pytorchmergebot pushed a commit that referenced this pull request Mar 31, 2025
Follow-up to #145130. That PR caused a warning to be emitted on ROCm the first time hipblaslt was called, regardless of workload.

Fixes #ISSUE_NUMBER

Pull Request resolved: #150227
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
@clee2000 (Contributor) commented Apr 1, 2025

@pytorchbot revert -m "reverted internally by D72140190" -c ghfirst
cc @izaitsevfb for context, @ngimel as reviewer

@pytorchmergebot (Collaborator):
@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Apr 1, 2025
…workspaces (#145130)"

This reverts commit 8f7fbe3.

Reverted #145130 on behalf of https://github.com/clee2000 due to reverted internally by D72140190 ([comment](#145130 (comment)))
@pytorchmergebot (Collaborator):
@eqy your PR has been successfully reverted.

@izaitsevfb (Contributor):
pytorchbot revert -m "reverted internally by D72140190" -c ghfirst cc @izaitsevfb for context, @ngimel as reviewer

This PR caused a performance regression internally; I'll ask the relevant PoCs to provide more context.

@jeffdaily (Collaborator):
If this PR is reworked, please consider the changes from #150227. That PR was a necessary follow-up to this one to fix behavior and the new warning on ROCm.

@izaitsevfb (Contributor):
The only info that we have:

the module's forward function was taking 70% extra time with this diff; we don't know the root cause or which operation was taking this time
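For anyone trying to narrow this down, a minimal, hypothetical timing sketch (the module, shapes, and iteration counts are made up; only the CUDA-event measurement pattern is the point) that could be run with and without the change:

```python
# Hypothetical repro sketch: average forward time of a GEMM-heavy module.
import torch

def time_forward(module, x, iters=100, warmup=10):
    for _ in range(warmup):
        module(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        module(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per forward

if torch.cuda.is_available():
    m = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.float16)
    x = torch.randn(64, 4096, device="cuda", dtype=torch.float16)
    print(f"{time_forward(m, x):.3f} ms / forward")
```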

@eqy eqy closed this Apr 12, 2025
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…es (pytorch#145130)

As `cuBLAS` workspaces are already per-stream, there shouldn't be kernel execution overlap with `cuBLASLt` kernels.

This PR reuses `cuBLAS` workspaces for `cuBLASLt` for the following benefits:

+ caching (`cuBLAS` workspaces were already cached, so now we get that for `cuBLASLt`)
+ a "free" workspace size bump for `cuBLASLt`: `cuBLASLt` workspace sizes were previously smaller than those for `cuBLAS` by default, which potentially hurts performance, and we encountered difficulty in increasing the size due to downstream OOMs; see also pytorch#120925
+ fixes broken behavior with the memtracker; pytorch#139442 attempted to handle the peaky allocation behavior that broke memtracker equivalence tests but didn't seem to fully work, whereas the cached/reused `cuBLAS` workspace here seems to fix it
+ one environment variable to rule them all: `CUBLAS_WORKSPACE_CONFIG` applies directly to `cuBLASLt`, without a confusing `CUBLASLT_WORKSPACE_SIZE` that users would also need to consider

Pull Request resolved: pytorch#145130
Approved by: https://github.com/ngimel
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…150227)

Follow-up to pytorch#145130. That PR caused a warning to be emitted on ROCm the first time hipblaslt was called, regardless of workload.

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#150227
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…workspaces (pytorch#145130)"

This reverts commit 8f7fbe3.

Reverted pytorch#145130 on behalf of https://github.com/clee2000 due to reverted internally by D72140190 ([comment](pytorch#145130 (comment)))
pytorchmergebot pushed a commit that referenced this pull request Apr 23, 2025
…151163)

Opt-in version of #145130, as there was a lack of repro for the 70% forward-time issue:
`TORCH_CUBLASLT_UNIFIED_WORKSPACE=1`

@izaitsevfb could you comment on whether it was repeatable on every forward pass, only at startup, or something else?

Pull Request resolved: #151163
Approved by: https://github.com/ngimel
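For completeness, a hedged sketch of how the opt-in flag from #151163 would be enabled; the flag name comes from the commit above, while the surrounding code is illustrative:

```python
# Opt in to the unified cuBLAS/cuBLASLt workspace introduced by #151163.
# The flag must be set before torch initializes its CUDA BLAS handles.
import os
os.environ["TORCH_CUBLASLT_UNIFIED_WORKSPACE"] = "1"

import torch

if torch.cuda.is_available():
    a = torch.randn(512, 512, device="cuda")
    b = torch.randn(512, 512, device="cuda")
    _ = a @ b  # cuBLAS and cuBLASLt now share the per-stream workspace (per the PR)
```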
facebook-github-bot pushed a commit to pytorch/benchmark that referenced this pull request Apr 24, 2025
Summary:
Opt-in version of pytorch/pytorch#145130, as there was a lack of repro for the 70% forward-time issue:
`TORCH_CUBLASLT_UNIFIED_WORKSPACE=1`

izaitsevfb could you comment on whether it was repeatable on every forward pass, only at startup, or something else?

X-link: pytorch/pytorch#151163
Approved by: https://github.com/ngimel

Reviewed By: ZainRizvi

Differential Revision: D73519132

fbshipit-source-id: cf3f786893a3b64e742411a1c027b375ed9103c9
DanaBCooper pushed a commit to graphcore/torchbench-fork that referenced this pull request Apr 24, 2025
Summary:
There are a few issues I'm solving:
1. It's too hard to measure total pt2 overhead using the dynamo_compile table because users need to know the columns representing all the top-level events (dynamo_cumulative_compile_time_us, etc.). Instead, let's populate the existing duration_us field for all top-level events. The complication is that runtime events in particular (Triton autotuning, cudagraphify) can be collapsed into a single row, with gaps in between, so we can't simply use `end_time - start_time` in all cases. Instead, we'll sum durations for all outer events when updating the compile-time or runtime metrics context. Introduce a 'depth' counter in TLS to track the nesting of CompilationMetrics events.
2. The existing implementation relies on callers of dynamo_timed to specify whether the event is a runtime or compile-time event. That doesn't work because some methods can be called in both situations, e.g., `CachingAutotuner.benchmark_all_configs`. For example, `TORCHINDUCTOR_BENCHMARK_FUSION=1` enables benchmarking during compile time. Instead, we can figure out automatically whether we're measuring a compile-time or runtime event and log accordingly.
3. If `log_compilation_events` were to throw an exception, we'd fail to clear the aggregated counters for runtime logs and they could be attributed to the wrong compile ID. I didn't actually find evidence of this in practice, but I added exception handling for extra safety.

X-link: pytorch/pytorch#151749
Approved by: https://github.com/Skylion007

Reviewed By: wdvr

Differential Revision: D73440137

fbshipit-source-id: 7f176a9ffb4a87bc7176cf737f4bed04a5879a34

Put "everything" WaitCounters in dynamo_timed (#151757)

Summary:
The main motivation is to capture the cudagraphs overhead in a WaitCounter. We'll combine that with Triton autotuning, and therefore rename to "compile_runtime_overheads". Since we have a couple WaitCounters where we want to capture all runtime and compile overheads, let's put the accounting in dynamo_timed so we'll automatically capture any toplevel timed regions that get added in the future. Also, dynamo_timed already has to figure out if we're timing a runtime vs. compile-time event, so we can reuse some of that logic.

X-link: pytorch/pytorch#151757
Approved by: https://github.com/ppanchalia
ghstack dependencies: #151749

Reviewed By: wdvr

Differential Revision: D73440149

fbshipit-source-id: 1b9074bef52b902da09001b4c006661c7d537477

context manager/decorator for dynamo config patching during tracing (#150586)

Summary:
Implement traceable config patching for Dynamo: enables restricted patching of Dynamo config where user can use a context manager/decorator to change tracing behavior for parts of the code.

The new `dont_skip_tracing` decorator/context manager for ignoring most trace rules is easily implemented with this more generic traceable config patching feature.

Implementation:
- Create a new specialized context manager class representing a wrapper around torch._dynamo.config.patch
- Dynamo doesn't trace into the context manager but updates config at compile time
- Correctness is based on our correctness for handling supported context managers
- Implementation is inspired by how `GradModeVariable` is implemented.

Previous attempts: pytorch/pytorch#148736 (decorator-only global approach) and pytorch/pytorch#149439 (decorator-only traceback approach)

See https://docs.google.com/document/d/1vWNwKL_jpg-PLopifcaSa338wks3GqSVF4GHRguybGg/edit?tab=t.0 for more details on implementation - including previous approaches.

NOTE: this PR fixes a bug where skipped code objects were not tracked by convert_frame.py, leading to cases where code objects would be automatically skipped even after `torch._dynamo.reset()`. This exposed some latent dynamo-wrapped test failures in CI that previously passed in CI but not locally.

X-link: pytorch/pytorch#150586
Approved by: https://github.com/jansel, https://github.com/zou3519, https://github.com/anijain2305

Reviewed By: ZainRizvi

Differential Revision: D73519157

fbshipit-source-id: 37a42c1aedbe27c27b3eda8514cdb67a6fc54793

Opt-in unified cuBLAS + cuBLASLt workspaces (#151163)

Summary:
Opt-in version of pytorch/pytorch#145130, as there was a lack of repro for the 70% forward-time issue:
`TORCH_CUBLASLT_UNIFIED_WORKSPACE=1`

izaitsevfb could you comment on whether it was repeatable on every forward pass, only at startup, or something else?

X-link: pytorch/pytorch#151163
Approved by: https://github.com/ngimel

Reviewed By: ZainRizvi

Differential Revision: D73519132

fbshipit-source-id: cf3f786893a3b64e742411a1c027b375ed9103c9

Improve metadata skip log message
jeffdaily pushed a commit to ROCm/pytorch that referenced this pull request Jun 20, 2025
…ytorch#151163)

Opt-in version of pytorch#145130, as there was a lack of repro for the 70% forward-time issue:
`TORCH_CUBLASLT_UNIFIED_WORKSPACE=1`

@izaitsevfb could you comment on whether it was repeatable on every forward pass, only at startup, or something else?

Pull Request resolved: pytorch#151163
Approved by: https://github.com/ngimel
Labels

ci-no-td, ciflow/inductor, ciflow/periodic, ciflow/rocm, ciflow/trunk, matrix multiplication, Merged, module: cublas, module: cuda, module: dynamo, open source, Reverted, topic: not user facing, triaged
