Skip to content

[ROCm] roll kernel as grid stride loop#169474

Closed
PaulMullowney wants to merge 7 commits intopytorch:mainfrom
PaulMullowney:roll_cuda_grid_stride_impl
Closed

[ROCm] roll kernel as grid stride loop#169474
PaulMullowney wants to merge 7 commits intopytorch:mainfrom
PaulMullowney:roll_cuda_grid_stride_impl

Conversation

@PaulMullowney
Copy link
Contributor

@PaulMullowney PaulMullowney commented Dec 3, 2025

Reimplement the roll kernel as a grid stride loop. On AMD devices, we see launch failures in the original version when gridDim.x*blockDim.x exceeds 4294967295. This implementation should work and be performant on both AMD and Nvidia devices. The issue can be seen on AMD devices with the following small repro:

import torch
N = 21913096
input_tensor_torch = torch.randn(1, 2, N, 98, device='cuda')
output = input_tensor_torch.roll(-1, dims=1)
input_tensor_torch_cpu = input_tensor_torch.cpu()
output_cpu = input_tensor_torch_cpu.roll(-1, dims=1)
assert torch.equal(output.cpu(), output_cpu)

Gives:
torch.AcceleratorError: HIP error: invalid configuration argument

If you set N=21913095, the original version of the kernel runs successfully.

Performance (averaged across 20 invocations) on an MI325x:
N Original (us) Grid Stride (us)
21913 12.5 12.4
219130 128.9 99.4
2191309 1286 1068
21913095 12381 10168

Fixes ROCm#2631

cc @ptrblck @msaroufim @eqy @jerryzh168 @tinglvv @nWEIdia @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @jataylo @hongxiayang @naromero77amd @pragupta @jerrymannil @xinyazhang

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Dec 3, 2025
@pytorch-bot
Copy link

pytorch-bot bot commented Dec 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169474

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 8ed9462 with merge base 7c593b9 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla
Copy link

linux-foundation-easycla bot commented Dec 3, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@PaulMullowney PaulMullowney marked this pull request as draft December 3, 2025 17:08
@jeffdaily jeffdaily added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 3, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Dec 4, 2025
jerrymannil added a commit to ROCm/pytorch that referenced this pull request Dec 4, 2025
@jerrymannil jerrymannil added module: cuda Related to torch.cuda, and CUDA support in general module: rocm AMD GPU support for Pytorch ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Dec 4, 2025
@pytorch-bot
Copy link

pytorch-bot bot commented Dec 4, 2025

To add the ciflow label ciflow/rocm-mi300 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot
Copy link

pytorch-bot bot commented Dec 4, 2025

To add the ciflow label ciflow/rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm Trigger "default" config CI on ROCm labels Dec 4, 2025
@jerrymannil jerrymannil added ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Dec 4, 2025
@pytorch-bot
Copy link

pytorch-bot bot commented Dec 4, 2025

To add the ciflow label ciflow/rocm please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot
Copy link

pytorch-bot bot commented Dec 4, 2025

To add the ciflow label ciflow/rocm-mi300 please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Dec 4, 2025
@jerrymannil jerrymannil added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 4, 2025
@PaulMullowney PaulMullowney force-pushed the roll_cuda_grid_stride_impl branch from 69dff67 to 5c51fe1 Compare December 8, 2025 17:45
@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Dec 8, 2025
@jerrymannil jerrymannil changed the title roll kernel as grid stride loop [ROCm] roll kernel as grid stride loop Dec 9, 2025
@PaulMullowney PaulMullowney force-pushed the roll_cuda_grid_stride_impl branch from 5317c40 to 8ed9462 Compare December 9, 2025 00:45
@jerrymannil jerrymannil marked this pull request as ready for review December 9, 2025 19:14
@jerrymannil jerrymannil added ciflow/trunk Trigger trunk jobs on your pull request ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Dec 9, 2025
@jerrymannil jerrymannil requested a review from jeffdaily December 9, 2025 23:48
@jerrymannil
Copy link
Collaborator

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

skpark-rh pushed a commit to skpark-rh/pytorch that referenced this pull request Dec 10, 2025
Reimplement the roll kernel as a grid stride loop. On AMD devices, we see launch failures in the original version when gridDim.x*blockDim.x exceeds 4294967295. This implementation should work and be performant on both AMD and Nvidia devices. The issue can be seen on AMD devices with the following small repro:

import torch
N = 21913096
input_tensor_torch = torch.randn(1, 2, N, 98, device='cuda')
output = input_tensor_torch.roll(-1, dims=1)
input_tensor_torch_cpu = input_tensor_torch.cpu()
output_cpu = input_tensor_torch_cpu.roll(-1, dims=1)
assert torch.equal(output.cpu(), output_cpu)

Gives:
torch.AcceleratorError: HIP error: invalid configuration argument

If you set N=21913095, the original version of the kernel runs successfully.

Performance (averaged across 20 invocations) on an MI325x:
N                 Original (us)   Grid Stride (us)
21913          12.5                12.4
219130        128.9              99.4
2191309      1286               1068
21913095    12381             10168

Fixes ROCm#2631

Pull Request resolved: pytorch#169474
Approved by: https://github.com/jerrymannil, https://github.com/jeffdaily
@jeffdaily
Copy link
Collaborator

@pytorchbot merge -f "unrelated rocm failure, all other CI passing on trunk"

@jeffdaily
Copy link
Collaborator

stale browser window, sorry

@pytorchmergebot
Copy link
Collaborator

Can't merge closed PR #169474

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/trunk Trigger trunk jobs on your pull request Merged module: cuda Related to torch.cuda, and CUDA support in general module: rocm AMD GPU support for Pytorch open source release notes: cuda release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

roll_cuda_kernel broken in main rocm 6.4 for large input.

5 participants