
[ROCm][inductor] More configs for pointwise kernels.#166470

Closed
naromero77amd wants to merge 3 commits intopytorch:mainfrom
ROCm:rocm_autotune_pointwise_more_configs

Conversation

@naromero77amd
Collaborator

@naromero77amd naromero77amd commented Oct 28, 2025

This config improves performance by 250% on some kernels that contain `tl.atomic_add(...)`. Again, we conditionalize for ROCm/HIP, so there is no impact to NV.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos
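The description above can be sketched in code. This is a minimal, hypothetical illustration of gating an extra autotune config on ROCm/HIP so the NVIDIA config list is untouched; the function and config names are illustrative, not the exact helpers edited in `torch/_inductor/runtime/triton_heuristics.py`:

```python
# Hypothetical sketch: offer one extra pointwise autotune config only on
# ROCm/HIP. waves_per_eu is an AMD-specific launch hint; on NVIDIA the
# candidate list is unchanged, so there is no impact to NV autotuning.
def pointwise_configs(is_hip: bool):
    configs = [
        {"XBLOCK": 1024, "num_warps": 4, "num_stages": 1},
        {"XBLOCK": 2048, "num_warps": 8, "num_stages": 1},
    ]
    if is_hip:
        # ROCm-only candidate aimed at atomic-add-heavy kernels.
        configs.append(
            {"XBLOCK": 512, "num_warps": 1, "num_stages": 2, "waves_per_eu": 1}
        )
    return configs
```

The autotuner would then benchmark all candidates and keep the fastest, so the extra config can only help (at some compile-time cost).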

@pytorch-bot

pytorch-bot bot commented Oct 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166470

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit ac96ea4 with merge base 1e836bc:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor ciflow/rocm Trigger "default" config CI on ROCm module: inductor module: rocm AMD GPU support for Pytorch labels Oct 28, 2025
@naromero77amd naromero77amd marked this pull request as draft October 28, 2025 23:49
@naromero77amd naromero77amd marked this pull request as ready for review October 29, 2025 14:45
Contributor

@PaulZhang12 PaulZhang12 left a comment


LGTM! Seems like another one of those layout issues with num_warps=1....

Contributor

@eellison eellison left a comment


Hmm, we really need a repo where we can check in kernels used in any of our benchmarks, so we make sure we capture them, don't regress, and can optimize them in the future. tritonbench sort of covers this but might be too general.

```python
num_stages=2,
waves_per_eu=1,  # 20% improvement
),
triton_config_with_settings(
```
Contributor


Can we conditionalize this on whether an atomic add is actually present?

Collaborator Author


I don't see how we can do this.

@jataylo any ideas?

Contributor


See `num_store` in the heuristics:

```python
"num_store": self.num_store,
```

Collaborator Author


At the moment, `tl.atomic_add` calls are counted as stores, so I could not distinguish a kernel with one atomic_add from a kernel with one plain store. I had to add another field to the `inductor_meta` structure.
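The idea in the comment above can be sketched as follows. This is a hypothetical illustration, assuming kernel source is available as a string; the real change tracks the count during code generation and records it in `inductor_meta`, and `count_ops`/`wants_atomic_config` are invented names:

```python
import re

# Hypothetical sketch: count tl.atomic_add separately from tl.store so a
# config can be gated on "kernel actually contains an atomic add". Today
# atomic adds are lumped in with stores, which is why a new field is needed.
def count_ops(kernel_src: str) -> dict:
    num_atomic_add = len(re.findall(r"\btl\.atomic_add\(", kernel_src))
    num_store = len(re.findall(r"\btl\.store\(", kernel_src))
    return {"num_store": num_store, "num_atomic_add": num_atomic_add}

def wants_atomic_config(inductor_meta: dict) -> bool:
    # Only offer the extra ROCm config when an atomic add is present.
    return inductor_meta.get("num_atomic_add", 0) > 0
```

With a distinct counter, a kernel with one `tl.atomic_add` and no stores is no longer indistinguishable from a kernel with a single `tl.store`.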

Contributor


Sounds good, I was just pointing to where we do similar analysis. Makes sense that you need a new field.

@naromero77amd naromero77amd added the ciflow/inductor-rocm Trigger "inductor" config CI on ROCm label Oct 30, 2025
@pruthvistony
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 30, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
This config improves performance by 250% on some kernels that contain `tl.atomic_add(...)`. Again, we conditionalize for ROCm/HIP, so there is no impact to NV.

Pull Request resolved: #166470
Approved by: https://github.com/PaulZhang12, https://github.com/mlazos, https://github.com/eellison, https://github.com/jansel
etaf pushed a commit to etaf/pytorch-inductor-xpu that referenced this pull request Nov 4, 2025
This config improves performance by 250% on some kernels that contain `tl.atomic_add(...)`. Again, we conditionalize for ROCm/HIP, so there is no impact to NV.

Pull Request resolved: pytorch#166470
Approved by: https://github.com/PaulZhang12, https://github.com/mlazos, https://github.com/eellison, https://github.com/jansel
pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Nov 26, 2025
…ports (#2807)

These are backports based on these upstream PRs. Cherry-picks were performed where possible.

pytorch#163908 (persistent reduction
autotune)
pytorch#161280 (reduction)
pytorch#162053 (foreach)
pytorch#163197 (pointwise)
pytorch#166470 (pointwise config for
atomic add)

Also included are some additional customer-specific configs which were not upstreamed but are included in this backport to 2.9: #2723

Did not backport filter functions such as `_maybe_filter_configs_for_tma_restrictions`:

https://github.com/ROCm/pytorch/blob/release/2.9/torch/_inductor/runtime/triton_heuristics.py#L2614

---------

Co-authored-by: Jack Taylor <jack.taylor@amd.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Sampsa Riikonen <sriikone@amd.com>
Co-authored-by: AmdSampsa <sampsa.riikonen@amd.com>

Labels

ciflow/inductor ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm Trigger "default" config CI on ROCm ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor module: rocm AMD GPU support for Pytorch open source release notes: inductor


8 participants