
[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows.#162330

Closed
jammm wants to merge 4 commits into pytorch:main from jammm:jam/hip_platform_aotriton_windows

Conversation


@jammm jammm commented Sep 6, 2025

Enables flash attention and/or memory-efficient attention with scaled_dot_product_attention on Windows via aotriton.
Already tested and working on Windows with TheRock.

Steps to enable: simply set USE_FLASH_ATTENTION=1 and USE_MEM_EFF_ATTENTION=1 as usual. See https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py#L578-L604
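For illustration, the two flags can be exported before kicking off the build (a minimal sketch; the actual Windows build is driven by TheRock's build_prod_wheels.py, and the commented-out wheel-build command is an assumption, not the exact invocation):

```shell
# Enable both attention backends for a ROCm PyTorch build.
export USE_FLASH_ATTENTION=1
export USE_MEM_EFF_ATTENTION=1
echo "flash=$USE_FLASH_ATTENTION mem_eff=$USE_MEM_EFF_ATTENTION"
# python setup.py bdist_wheel   # illustrative; TheRock's script drives the real build
```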

cc @peterjc123 @mszhanyi @skyline75489 @nbcsm @iremyux @Blackhex @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd


pytorch-bot bot commented Sep 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162330

Note: Links to docs will display an error until the docs builds have been completed.

❌ 19 New Failures, 2 Cancelled Jobs, 2 Unrelated Failures

As of commit 4f46a52 with merge base 5b9114b:

NEW FAILURES - The following jobs have failed:

CANCELLED JOBS - The following jobs were cancelled. Please retry:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: rocm label Sep 6, 2025

jammm commented Sep 6, 2025

cc @ScottTodd


jammm commented Sep 6, 2025

cc @xinyazhang


jammm commented Sep 6, 2025

@pytorchbot label "release notes: rocm


pytorch-bot bot commented Sep 6, 2025

❌ 🤖 pytorchbot command failed:

Got EOF while in a quoted string
Try `@pytorchbot --help` for more info.
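The parse failure is easy to reproduce with Python's stdlib shlex, which tokenizes shell-style quoted strings (whether the bot actually uses shlex is an assumption; the error message suggests a similar parser):

```python
import shlex

# A properly closed quote tokenizes cleanly.
print(shlex.split('@pytorchbot label "release notes: rocm"'))
# → ['@pytorchbot', 'label', 'release notes: rocm']

# Dropping the closing quote hits end-of-input inside the quoted
# string, mirroring the bot's "Got EOF while in a quoted string".
try:
    shlex.split('@pytorchbot label "release notes: rocm')
except ValueError as err:
    print(err)  # → No closing quotation
```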


jammm commented Sep 6, 2025

@pytorchbot label "release notes: rocm"

@pytorch-bot pytorch-bot bot added the release notes: rocm label Sep 6, 2025

jammm commented Sep 6, 2025

@pytorchbot label "topic: performance"

@pytorch-bot pytorch-bot bot added the topic: performance label Sep 6, 2025

jammm commented Sep 6, 2025

@pytorchbot label "module: windows"

@pytorch-bot pytorch-bot bot added the module: windows label Sep 6, 2025
@jammm jammm force-pushed the jam/hip_platform_aotriton_windows branch from a44f41f to f7ebef2 on September 7, 2025

Nem404 commented Sep 8, 2025

Wait a minute, so this is actually TheRock's external-builds/pytorch/patches/pytorch/main/pytorch/hipified/0001-Support-FLASH_ATTENTION-MEM_EFF_ATTENTION-via.-aotri.patch now upstreamed as a PR, making this patch unnecessary?

WoW

xinyazhang
xinyazhang previously approved these changes Sep 9, 2025

@xinyazhang xinyazhang left a comment


LGTM


Nem404 commented Sep 9, 2025

Great to see Xinya has approved :D

Who else do we need here as a reviewer with merge privileges? Jeff?

ScottTodd
ScottTodd previously approved these changes Sep 9, 2025
Comment on lines 910 to 913
```diff
 if(USE_ROCM)
-  if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
+  if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
     include(cmake/External/aotriton.cmake)
```

Thanks, I tested this using https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py, with and without --enable-pytorch-flash-attention-windows.

  • Both builds succeeded.

  • Running PyTorch succeeded with aotriton enabled, and ComfyUI generated images on my gfx1100 GPU using the memory-efficient attention implementation (after setting the TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 env var).

  • With the option enabled, I see 45MB more logs (15MB -> 60MB), including 5337 instances of the warning below. It seems to be just a warning, possibly fixable by forcing Python into UTF-8 mode (will verify):

    Message: '%s %s -> %s'
    Arguments: ('copying', 'torch\\lib\\aotriton.images\\amd-gfx11xx\\flash\\bwd_kernel_dq\\FONLY__\uff0afp32@16_48_0_T_T_1___gfx11xx.aks2', 'build\\lib.win-amd64-cpython-312\\torch\\lib\\aotriton.images\\amd-gfx11xx\\flash\\bwd_kernel_dq')
    --- Logging error ---
    Traceback (most recent call last):
      File "C:\Users\Nod-Shark16\AppData\Local\Programs\Python\Python312\Lib\logging\__init__.py", line 1163, in emit
        stream.write(msg + self.terminator)
      File "C:\Users\Nod-Shark16\AppData\Local\Programs\Python\Python312\Lib\encodings\cp1252.py", line 19, in encode
        return codecs.charmap_encode(input,self.errors,encoding_table)[0]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    UnicodeEncodeError: 'charmap' codec can't encode character '\uff0a' in position 73: character maps to <undefined>
    Call stack:
      File "D:\b\pytorch_main\setup.py", line 1785, in <module>
        main()
      File "D:\b\pytorch_main\setup.py", line 1766, in main
        setup(
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\__init__.py", line 117, in setup
        return distutils.core.setup(**attrs)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\core.py", line 186, in setup
        return run_commands(dist)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\core.py", line 202, in run_commands
        dist.run_commands()
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\dist.py", line 1002, in run_commands
        self.run_command(cmd)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\dist.py", line 1104, in run_command
        super().run_command(command)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\dist.py", line 1021, in run_command
        cmd_obj.run()
      File "D:\b\pytorch_main\setup.py", line 1353, in run
        super().run()
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\command\bdist_wheel.py", line 370, in run
        self.run_command("build")
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\cmd.py", line 357, in run_command
        self.distribution.run_command(command)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\dist.py", line 1104, in run_command
        super().run_command(command)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\dist.py", line 1021, in run_command
        cmd_obj.run()
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\command\build.py", line 135, in run
        self.run_command(cmd_name)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\cmd.py", line 357, in run_command
        self.distribution.run_command(command)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\dist.py", line 1104, in run_command
        super().run_command(command)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\dist.py", line 1021, in run_command
        cmd_obj.run()
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\command\build_py.py", line 78, in run
        self.build_package_data()
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\command\build_py.py", line 171, in build_package_data
        _outf, _copied = self.copy_file(srcfile, target)
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\command\build_py.py", line 64, in copy_file
        return super().copy_file(  # pyright: ignore[reportReturnType] # pypa/distutils#309
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\cmd.py", line 421, in copy_file
        return file_util.copy_file(
      File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\_distutils\file_util.py", line 130, in copy_file
        log.info("%s %s -> %s", action, src, dir)
    


I rebuilt (without fully cleaning my build/source dirs) with the PYTHONUTF8=1 environment variable and didn't see the warnings. Hopefully a clean rebuild (including deleting torch/lib/aotriton.images/ in the source dir) is also warning-free. We can add that env var to our downstream build script and any upstream build scripts we contribute (see #160776)
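The failure is reproducible with the stdlib alone: the aotriton kernel filenames contain U+FF0A (FULLWIDTH ASTERISK), which the legacy cp1252 codec used by default on Windows streams cannot represent, while UTF-8 (which PYTHONUTF8=1 switches Python's streams to) encodes it fine:

```python
# Filename taken from the build log above.
name = "FONLY__\uff0afp32@16_48_0_T_T_1___gfx11xx.aks2"

try:
    name.encode("cp1252")           # what a cp1252-wrapped log stream attempts
except UnicodeEncodeError as err:
    print(err.reason)               # → character maps to <undefined>

data = name.encode("utf-8")         # always succeeds
print(b"\xef\xbc\x8a" in data)      # → True (the UTF-8 bytes for U+FF0A)
```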

Contributor Author

jammm commented Sep 9, 2025

@jeffdaily PTAL. Received approval from @xinyazhang and @ScottTodd.
We can proceed with queuing it for merge.

@bdhirsh bdhirsh requested a review from jeffdaily September 9, 2025 17:18
@bdhirsh bdhirsh added the triaged label Sep 9, 2025
ScottTodd added a commit to ROCm/TheRock that referenced this pull request Sep 9, 2025
## Motivation

Progress on #1040, getting closer
to enabling aotriton in PyTorch on Windows.

## Technical Details

This will supersede #1409 and is
dependent on pytorch/pytorch#162330.

I believe the UTF-8 change fixes the warnings logged when copying
files with Unicode characters in their names:
```
Message: '%s %s -> %s'
Arguments: ('copying', 'torch\\lib\\aotriton.images\\amd-gfx11xx\\flash\\bwd_kernel_dq\\FONLY__\uff0afp32@16_48_0_T_T_1___gfx11xx.aks2', 'build\\lib.win-amd64-cpython-312\\torch\\lib\\aotriton.images\\amd-gfx11xx\\flash\\bwd_kernel_dq')
--- Logging error ---
Traceback (most recent call last):
  File "C:\Users\Nod-Shark16\AppData\Local\Programs\Python\Python312\Lib\logging\__init__.py", line 1163, in emit
    stream.write(msg + self.terminator)
  File "C:\Users\Nod-Shark16\AppData\Local\Programs\Python\Python312\Lib\encodings\cp1252.py", line 19, in encode
    return codecs.charmap_encode(input,self.errors,encoding_table)[0]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
UnicodeEncodeError: 'charmap' codec can't encode character '\uff0a' in position 73: character maps to <undefined>
Call stack:
  File "D:\b\pytorch_main\setup.py", line 1785, in <module>
    main()
  File "D:\b\pytorch_main\setup.py", line 1766, in main
    setup(
  File "D:\projects\TheRock\external-builds\pytorch\3.12.venv\Lib\site-packages\setuptools\__init__.py", line 117, in setup
    return distutils.core.setup(**attrs)
```

## Test Plan

Tested with local builds on Windows with and without
`--enable-pytorch-flash-attention-windows`.

## Test Result

Builds succeeded, ComfyUI generated images on my gfx1100 GPU (needed
`TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1` for aotriton on that GPU).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

jammm commented Sep 10, 2025

Lint test fails because:

 Error: Failed due to ValueError:
/pytorch/pytorch/cmake/External/aotriton.cmake:83: DEPENDEES ==> DEPENDENCIES
/pytorch/pytorch/cmake/External/aotriton.cmake:113: DEPENDEES ==> DEPENDENCIES

Please either fix the error or add the word(s) to the dictionary file.
HINT: all-lowercase words in the dictionary can cover all case variations.

But DEPENDEES is a valid keyword for ExternalProject_Add_Step (https://cmake.org/cmake/help/latest/module/ExternalProject.html#command:externalproject_add_step), so I think we can ignore this.
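For context, DEPENDEES in ExternalProject_Add_Step names the steps that must run before the custom step (DEPENDERS is the reverse direction), per the CMake docs linked above. An illustrative fragment — the target and step names here are made up, not the ones in aotriton.cmake:

```cmake
# Illustrative only: run a custom step after the external project's
# "install" step. DEPENDEES lists predecessor steps, which is why the
# spell-checking linter mistakes it for a typo of "DEPENDENCIES".
ExternalProject_Add_Step(aotriton_external copy_images
  COMMAND ${CMAKE_COMMAND} -E echo "copying kernel images"
  DEPENDEES install
)
```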


jammm commented Sep 14, 2025

So did 4f46a52 cause the merge to fail? But why

No that's the fix to the regression that broke the CUDA builds. The merge failures are unrelated and should be fixed once they're fixed elsewhere


Nem404 commented Sep 15, 2025

No that's the fix to the regression that broke the CUDA builds. The merge failures are unrelated and should be fixed once they're fixed elsewhere

Kinda curious where and when they should be fixed 🤔

Oh, #162881 (comment) 👀

@jeffdaily
Collaborator

@pytorchbot merge -f "the cuda build OOM that caused a revert of this PR has been fixed, all other failures are unrelated"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (pytorch#162330)

Enables flash attention and/or memory efficient attention on Windows with scaled_dot_product_attention via. aotriton.
Already tested to be working on Windows with TheRock.

Steps to enable: simply set `USE_FLASH_ATTENTION=1` and `USE_MEM_EFF_ATTENTION=1` as usual. See https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py#L578-L604

Pull Request resolved: pytorch#162330
Approved by: https://github.com/xinyazhang, https://github.com/ScottTodd, https://github.com/jeffdaily

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
Revert "[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (pytorch#162330)"

This reverts commit 62843c1.

Reverted pytorch#162330 on behalf of https://github.com/atalman due to Sorry reverting looks like broke windows nightlies see pytorch#162881 ([comment](pytorch#162330 (comment)))
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (pytorch#162330)

Enables flash attention and/or memory efficient attention on Windows with scaled_dot_product_attention via. aotriton.
Already tested to be working on Windows with TheRock.

Steps to enable: simply set `USE_FLASH_ATTENTION=1` and `USE_MEM_EFF_ATTENTION=1` as usual. See https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py#L578-L604

Pull Request resolved: pytorch#162330
Approved by: https://github.com/jeffdaily

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
mansiag05 pushed 3 commits to mansiag05/pytorch that referenced this pull request Sep 22, 2025 (the same apply, revert, and reapply commit messages as above)
cleonard530 pushed 3 commits to cleonard530/pytorch that referenced this pull request Sep 22, 2025 (the same apply, revert, and reapply commit messages as above)
dsashidh pushed 3 commits to dsashidh/pytorch that referenced this pull request Sep 26, 2025 (the same apply, revert, and reapply commit messages as above)
xinyazhang pushed a commit to ROCm/pytorch that referenced this pull request Sep 29, 2025 (same commit message as above)
mstankov-amd pushed a commit to ROCm/pytorch that referenced this pull request Oct 4, 2025 (same commit message as above)
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Oct 7, 2025
Fixes: pytorch#163958

Cherry-pick pytorch#161754
Cherry-pick pytorch#162330
Cherry-pick pytorch#163373
Cherry-pick pytorch#163745

Note TF32 support is still being plagued by `HIPBLASLT_ALLOW_TF32`,
which should be handled by another PR due to its complexity.

---------

Co-authored-by: Aaryaman Vasishta <aaryaman.vasishta@amd.com>
Co-authored-by: Scott Todd <scott.todd0@gmail.com>
ScottTodd added a commit to ScottTodd/pytorch that referenced this pull request Oct 15, 2025 (same commit message as above)
slojosic-amd pushed a commit to ROCm/pytorch that referenced this pull request Oct 15, 2025 (same commit message as above)
ScottTodd added a commit to ROCm/pytorch that referenced this pull request Oct 15, 2025 (same commit message as above)
jithunnair-amd pushed a commit to ROCm/pytorch that referenced this pull request Oct 22, 2025 (same commit message as above)
jeffdaily pushed a commit to ROCm/pytorch that referenced this pull request Nov 17, 2025 (same commit message as above)

Labels

ci-no-td, ciflow/binaries, ciflow/trunk, Merged, module: rocm, module: windows, open source, release notes: rocm, Reverted, topic: performance, triaged

Projects

None yet

Development

Successfully merging this pull request may close these issues.

9 participants