
[Inductor] Naive foreach autotune support #162053

Closed
jataylo wants to merge 7 commits into pytorch:main from jataylo:jack-for-tuning

Conversation

@jataylo
Collaborator

@jataylo jataylo commented Sep 3, 2025

Initial autotuning support for foreach kernels, giving a 4x improvement for some kernels in an internal workload. More improvements can surely be made here in the future. This removes num_warps from the kernel definition to enable autotune support in the generated wrapper code.

Before:
triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 |

After:
triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 |
triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 |
triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 |

The num_warps=8 default comes from https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374
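For readers skimming the diff, the change boils down to something like the sketch below (hypothetical helper name, not the exact code in this PR): keep the single pre-existing num_warps=8 config when autotuning is off, and otherwise emit one config per warp count so the autotuner can pick the fastest.

import triton

def foreach_configs(autotune_enabled):
    # Hypothetical sketch of the idea behind this PR, not the actual Inductor code.
    if not autotune_enabled:
        # Previous behaviour: one fixed config.
        return [triton.Config({}, num_stages=1, num_warps=8)]
    # With autotuning on, try several warp counts and keep the fastest.
    return [triton.Config({}, num_stages=1, num_warps=w) for w in (1, 2, 4, 8)]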

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos

@pytorch-bot

pytorch-bot bot commented Sep 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162053

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 13 Unrelated Failures

As of commit ff32d98 with merge base 2b93d5b:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jataylo jataylo marked this pull request as draft September 3, 2025 10:33
@jataylo jataylo added ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners labels Sep 3, 2025
@pytorch-bot

pytorch-bot bot commented Sep 3, 2025

Warning: Unknown label ciflow/rocm-mi355.
Currently recognized labels are

  • ciflow/binaries
  • ciflow/binaries_libtorch
  • ciflow/binaries_wheel
  • ciflow/triton_binaries
  • ciflow/inductor
  • ciflow/inductor-periodic
  • ciflow/inductor-rocm
  • ciflow/inductor-perf-test-nightly-rocm
  • ciflow/inductor-perf-compare
  • ciflow/inductor-micro-benchmark
  • ciflow/inductor-micro-benchmark-cpu-x86
  • ciflow/inductor-perf-test-nightly-x86-zen
  • ciflow/inductor-cu126
  • ciflow/linux-aarch64
  • ciflow/mps
  • ciflow/nightly
  • ciflow/periodic
  • ciflow/periodic-rocm-mi300
  • ciflow/rocm
  • ciflow/rocm-mi300
  • ciflow/s390
  • ciflow/riscv64
  • ciflow/slow
  • ciflow/trunk
  • ciflow/unstable
  • ciflow/xpu
  • ciflow/vllm
  • ciflow/torchbench
  • ciflow/op-benchmark
  • ciflow/pull
  • ciflow/h100
  • ciflow/h100-distributed
  • ciflow/win-arm64
  • ciflow/h100-symm-mem
  • ciflow/h100-cutlass-backend

Please add the new label to .github/pytorch-probot.yml

@eellison eellison requested a review from mlazos September 3, 2025 14:41
):
    configs.append(triton.Config({}, num_stages=1, num_warps=8))
else:
    for warps in [1, 2, 4, 8]:
Contributor

Do you have any more details on the kernels that were improved ?

Collaborator Author

Let me dig that out

@mlazos
Contributor

mlazos commented Sep 8, 2025

Along w/ Elias's comment, depending on the kernels it might be good to run some optimizer microbenchmarks after this to see if there's any improvement. Thanks for looking at this!

@jataylo
Collaborator Author

jataylo commented Sep 9, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased jack-for-tuning onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout jack-for-tuning && git pull --rebase)

@mlazos mlazos added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 9, 2025
@jataylo
Collaborator Author

jataylo commented Sep 23, 2025

@mlazos, @eellison is there any sort of benchmark we can run to display perf results? The problem is that the workload this is based on is private.

Another annoying issue is that foreach kernels do not support TORCHINDUCTOR_BENCHMARK_KERNEL=1, so it's not enough to copy over some codegen as there is no call harness. We can conditionalise this for ROCm only if that makes it a bit easier to land.

@jataylo jataylo marked this pull request as ready for review September 23, 2025 09:43
@jataylo
Collaborator Author

jataylo commented Sep 23, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased jack-for-tuning onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout jack-for-tuning && git pull --rebase)

@mlazos
Contributor

mlazos commented Sep 24, 2025

Optimizers are a pretty good workload for this; they fuse a bunch of foreach ops together. To make it realistic you can just take weight sizes from a public model and init w/ random data.
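A rough sketch of such a microbenchmark (parameter shapes are placeholders rather than sizes taken from a real model; foreach Adam plus a compiled step is assumed to exercise the generated foreach kernels):

import torch

device = "cuda"
# Placeholder parameter shapes; swap in sizes from a public model for realism.
shapes = [(4096, 4096), (4096,), (50257, 768)]
params = [torch.randn(s, device=device, requires_grad=True) for s in shapes]
for p in params:
    p.grad = torch.randn_like(p)  # fake gradients so step() has work to do

opt = torch.optim.Adam(params, foreach=True)

@torch.compile(fullgraph=False)
def step():
    opt.step()

step()  # first call compiles the fused foreach kernels
torch.cuda.synchronize()

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(10):
    step()
end.record()
torch.cuda.synchronize()
print(f"avg optimizer step: {start.elapsed_time(end) / 10:.3f} ms")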

@jbschlosser jbschlosser added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module topic: performance topic category labels Sep 24, 2025
@mlazos
Contributor

mlazos commented Nov 13, 2025

I think if we just only have one config in the else statement we're good to reland, lmk if that's okay.

@naromero77amd
Collaborator

I think if we just only have one config in the else statement we're good to reland, lmk if that's okay.

The else statement needs to have more configs; you wouldn't know a priori which is the fastest config. The point is to try more configs, which is what is done for reductions, persistent reductions, etc.

There should be no increase in compile time with autotune disabled, since the one config is the same one as before by construction.

More details would be helpful on the slowdown that is being observed.

@mlazos
Contributor

mlazos commented Nov 14, 2025

Yeah sorry for the misunderstanding, it turns out the internal team has max_autotune_pointwise enabled, which explains why they were seeing this, and I missed the negation in the conditional, so my previous comment was incorrect. I asked them to turn it off; if they are unable to because there are too many jobs, we should switch this to only max_autotune. I'm waiting for their response, but I think we should be able to reland after they confirm. Sorry for the churn! I also really want this and want to get it in. The perf looks great!

@naromero77amd
Collaborator

cc: @jataylo

@mlazos Currently, TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 enables autotuning of the following Triton kernels:

  • pointwise
  • reduction
  • persistent reduction
  • foreach (with this PR)

TORCHINDUCTOR_MAX_AUTOTUNE=1 would include GEMMs and convolutions that are much more time consuming to tune.

The logic is that TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE has low tuning overhead and high reward.

So, I would recommend against only activating foreach tuning with TORCHINDUCTOR_MAX_AUTOTUNE=1.
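For context, both flags are ordinary Inductor config knobs; a minimal illustration of how they are toggled (environment variable or Python config, nothing specific to this PR):

# From the shell:
#   TORCHINDUCTOR_MAX_AUTOTUNE_POINTWISE=1 python train.py
# or from Python before compiling:
import torch._inductor.config as inductor_config

inductor_config.max_autotune_pointwise = True  # pointwise/reduction/persistent-reduction/foreach tuning
# inductor_config.max_autotune = True          # also autotunes GEMMs/convs, much more time consuming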

@mlazos
Contributor

mlazos commented Nov 18, 2025

@naromero77amd Yeah I think adding an additional flag TORCHINDUCTOR_MAX_AUTOTUNE_COMBOKERNEL is perhaps a better solution. The main reasons are that 1) it takes a long time because these kernels can be large w/ optimizers, and 2) it can cause OOMs when autotuning the optimizer. This is because the optimizer kernels we generate are in-place mutations, which induce a clone when we autotune (we can't mutate the actual params). What do you think?

Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
This reverts commit 6c5db82.

Reverted pytorch#162053 on behalf of https://github.com/mlazos due to Sorry, there's an internal slowdown due to the extra triton configs you added ([comment](pytorch#162053 (comment)))
@mlazos
Contributor

mlazos commented Nov 19, 2025

@pytorchbot merge

@mlazos
Contributor

mlazos commented Nov 19, 2025

I talked to the internal team, they were able to remove the flag. Relanding! Thank you for your patience.

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job

Failing merge rule: Core Maintainers

@naromero77amd
Collaborator

@mlazos Due to some unrelated failures, you will need to do something like:

@pytorchbot merge -i "unrelated failures <some message>"

@pytorch-bot

pytorch-bot bot commented Nov 19, 2025

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: unrelated failures <some message>

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci,cherry-pick} ...

Try @pytorchbot --help for more info.

@mlazos
Contributor

mlazos commented Nov 20, 2025

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 15 checks: pull / linux-noble-xpu-n-py3.10 / build, pull / linux-jammy-py3.13-clang12 / test (default, 2, 5, lf.linux.4xlarge), pull / linux-jammy-py3.10-gcc11 / test (default, 2, 5, lf.linux.2xlarge), pull / linux-jammy-py3.10-clang18-asan / test (default, 3, 7, lf.linux.4xlarge), pull / linux-jammy-py3.10-clang12 / test (default, 2, 5, lf.linux.4xlarge), trunk / linux-jammy-py3-clang12-executorch / test (executorch, 1, 1, linux.2xlarge, unstable), trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 3, 5, linux.g6.4xlarge.experimental.nvidia.gpu), trunk / macos-py3-arm64 / test (default, 3, 3, macos-m1-stable), rocm / linux-jammy-rocm-py3.10 / test (default, 2, 6, linux.rocm.gpu.2), inductor / unit-test / inductor-test / test (inductor, 1, 2, linux.g5.4xlarge.nvidia.gpu), rocm-mi355 / linux-noble-rocm-py3.12-mi355 / test (default, 1, 6, linux.rocm.gpu.mi355.1), inductor-rocm / rocm-py3.10-inductor / test (inductor, 1, 2, linux.rocm.gpu.2), inductor-rocm-mi300 / rocm-py3.10-inductor-mi300 / test (inductor, 2, 2, linux.rocm.gpu.gfx942.1), rocm-mi300 / linux-noble-rocm-py3.12-mi300 / test (default, 3, 6, linux.rocm.gpu.gfx942.1), inductor-perf-nightly-rocm-mi300 / rocm-py3_10-inductor-benchmark-test / test (inductor_torchbench_perf_rocm_mi300, 4, 9, linux.rocm.gpu.gfx942.1)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@mlazos
Contributor

mlazos commented Nov 20, 2025

@pytorchbot merge -f "Unrelated failures, relanding a reverted PR"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

pruthvistony pushed a commit to ROCm/pytorch that referenced this pull request Nov 26, 2025
…ports (#2807)

These are backports based on these upstream PRs. Cherry-picks were performed where possible.

pytorch#163908 (persistent reduction
autotune)
pytorch#161280 (reduction)
pytorch#162053 (foreach)
pytorch#163197 (pointwise)
pytorch#166470 (pointwise config for
atomic add)

Also included are some additional customer-specific configs which were
not upstreamed but are in this backport to 2.9
#2723

Did not backport filter functions such as `_maybe_filter_configs_for_tma_restrictions`

https://github.com/ROCm/pytorch/blob/release/2.9/torch/_inductor/runtime/triton_heuristics.py#L2614

---------

Co-authored-by: Jack Taylor <jack.taylor@amd.com>
Co-authored-by: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Co-authored-by: Sampsa Riikonen <sriikone@amd.com>
Co-authored-by: AmdSampsa <sampsa.riikonen@amd.com>
JacobSzwejbka pushed a commit that referenced this pull request Dec 8, 2025

Pull Request resolved: #162053
Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily

Co-authored-by: Nichols A. Romero <nick.romero@amd.com>

Labels

  • ci-no-td Do not run TD on this PR
  • ciflow/inductor
  • ciflow/inductor-perf-test-nightly Trigger nightly inductor perf tests
  • ciflow/inductor-perf-test-nightly-rocm-mi300 Trigger inductor perf tests on ROCm MI300
  • ciflow/inductor-rocm Trigger "inductor" config CI on ROCm
  • ciflow/rocm Trigger "default" config CI on ROCm
  • ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300
  • ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners
  • ciflow/trunk Trigger trunk jobs on your pull request
  • Merged
  • module: inductor
  • open source
  • release notes: inductor
  • Reverted
  • topic: performance topic category
  • triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants