[SDPA-CUDNN] Make CuDNN Attention Opt in#138587
Merged
kit1980 merged 1 commit intorelease/2.5from Oct 22, 2024
Merged
Conversation
# Summary Currently we have a `cudnn_order` that says on H100 w/ new enough CuDNN backend (we ship a 9.1 version in OSS) try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5: 1. #138529 2. huggingface/diffusers#9704 3. #138354 In light of the above we are going to make the CuDNN backend Opt-in by default. This can be done easily with the context manager for choosing backends I.e.: ``` Python from torch.nn.attention import sdpa_kernel, SDPBackend with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): out = F.scaled_dot_product_attention(q, k, v) ``` This PR puts the CuDNN backend as the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless disabled (which is done via the context manager). Cc @atalman Pull Request resolved: #138522 Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet (cherry picked from commit 9a9a0ab)
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/138587
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 61a03e4 with merge base b7eb725 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
malfet
approved these changes
Oct 22, 2024
aostrowski-hbn
pushed a commit
to HabanaAI/pytorch-fork
that referenced
this pull request
Jan 7, 2025
pytorch/pytorch@v2.5.0...v2.5.1 Squashed new commits are as follows: * update getting started xpu (pytorch#138090) * [Cherry-Pick] Use cuda 12.4 pytorch_extra_install_requirements as default (pytorch#138526) * Don't try to load cufile (pytorch#138539) * Add link to torch.compile the missing manual in troubleshooting (pytorch#137369) * Update cpuinfo submodule (pytorch#138600) * Update doc copyrights to 2024 (pytorch#138650) * [SDPA-CUDNN] Make CuDNN Attention Opt in (pytorch#138587) * [MPS] Fix sliced cast (pytorch#138535) * Disabling amp context when invoking compiler (pytorch#138659) Change-Id: I3e282e8b4809b97b38605420c64d1bd1b0b938aa
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stack from ghstack (oldest at bottom):
Summary
Currently we have a
cudnn_orderthat says on H100 w/ new enough CuDNN backend (we ship a 9.1 version in OSS) try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:query's memory layout ordering foroutputin cuDNN SDPA #138354In light of the above we are going to make the CuDNN backend Opt-in by default.
This can be done easily with the context manager for choosing backends I.e.:
This PR puts the CuDNN backend as the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless disabled (which is done via the context manager).
Cc @atalman
cc @mikaylagawarecki