Skip to content

[Inductor] Fix combo kernels by populating constants for equal_to_1 args#168127

Closed
karthickai wants to merge 19 commits intogh/karthickai/16/basefrom
gh/karthickai/16/head
Closed

[Inductor] Fix combo kernels by populating constants for equal_to_1 args#168127
karthickai wants to merge 19 commits intogh/karthickai/16/basefrom
gh/karthickai/16/head

Conversation

@karthickai
Copy link
Collaborator

@karthickai karthickai commented Nov 19, 2025

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 19, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/168127

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 5 Unrelated Failures

As of commit 747aadb with merge base e770c95 (image):

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@karthickai
Copy link
Collaborator Author

Generated triton kernel before and after fix:

before fix: constant is missing for load_seed_offset1

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.pointwise(
    size_hints={'x': 1024}, tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'load_seed_offset': 'i32', 'load_seed_offset1': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'RoundRobinComboKernelGrid', 'combo_grid_meta': {'num_kernels': 2, 'min_blocks': 0, 'default_config': None, 'no_x_dim_0': False, 'xnumel_0': 1024, 'no_x_dim_1': False, 'xnumel_1': 1024}, 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'backend_hash': '07C4B3116EC6B0BD20166279782DB98EA71861B79334EBBC8CCB3D36A1E5D7F2', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
@triton.jit
def triton_poi_fused_0(in_ptr0, out_ptr0, out_ptr1, load_seed_offset, load_seed_offset1, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 2 == 0:
        pid_offset = pid // 2
        xnumel = 1024
        r0_numel = 1
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + load_seed_offset)
        tmp1 = x0
        tmp2 = tl.rand(tmp0, (tmp1).to(tl.uint32))
        tl.store(out_ptr0 + (x0), tmp2, xmask)
    elif pid % 2 == 1:
        pid_offset = pid // 2
        xnumel = 1024
        r0_numel = 1
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x1 = xindex
        tmp3 = tl.load(in_ptr0 + load_seed_offset1)
        tmp4 = x1
        tmp5 = tl.rand(tmp3, (tmp4).to(tl.uint32))
        tl.store(out_ptr1 + (x1), tmp5, xmask)
    else:
        pass

After fix constant populated properly constants': {'load_seed_offset1': 1}

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.pointwise(
    size_hints={'x': 1024}, tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'load_seed_offset': 'i32', 'load_seed_offset1': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'load_seed_offset1': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'RoundRobinComboKernelGrid', 'combo_grid_meta': {'num_kernels': 2, 'min_blocks': 0, 'default_config': None, 'no_x_dim_0': False, 'xnumel_0': 1024, 'no_x_dim_1': False, 'xnumel_1': 1024}, 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'backend_hash': '07C4B3116EC6B0BD20166279782DB98EA71861B79334EBBC8CCB3D36A1E5D7F2', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
@triton.jit
def triton_poi_fused_0(in_ptr0, out_ptr0, out_ptr1, load_seed_offset, load_seed_offset1, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 2 == 0:
        pid_offset = pid // 2
        xnumel = 1024
        r0_numel = 1
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + load_seed_offset)
        tmp1 = x0
        tmp2 = tl.rand(tmp0, (tmp1).to(tl.uint32))
        tl.store(out_ptr0 + (x0), tmp2, xmask)
    elif pid % 2 == 1:
        pid_offset = pid // 2
        xnumel = 1024
        r0_numel = 1
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:]
        xmask = xindex < xnumel
        x1 = xindex
        tmp3 = tl.load(in_ptr0 + load_seed_offset1)
        tmp4 = x1
        tmp5 = tl.rand(tmp3, (tmp4).to(tl.uint32))
        tl.store(out_ptr1 + (x1), tmp5, xmask)
    else:
        pass

@karthickai karthickai requested a review from mlazos November 19, 2025 01:30
"constants": {},
}

for arg_num in equal_1_arg_indices(signature):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems super specialized, are there other values the constants can have?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, constants can have any value, but only equal_to_1 needs special handling because triton autodetects it and requires it in the constant dict to prevent segfaults this exact pattern is used in regular triton codegen https://github.com/pytorch/pytorch/blob/0d7ba9714ac77b2b4a446a9eff913a6ff9dfc782/torch/_inductor/codegen/triton.py#L5169C1-L5177C1 but was missing in combo kernels, so I added it.

Copy link
Contributor

@mlazos mlazos left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, please address my one question though

…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
@karthickai karthickai added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 2, 2025
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78

[ghstack-poisoned]
…qual_to_1 args"


Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78

[ghstack-poisoned]
@karthickai
Copy link
Collaborator Author

@pytorchbot merge -i

umechand-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 8, 2025
…rgs (pytorch#168127)

Fixes: pytorch#168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

Pull Request resolved: pytorch#168127
Approved by: https://github.com/mlazos
ghstack dependencies: pytorch#167781
JacobSzwejbka pushed a commit that referenced this pull request Dec 8, 2025
…rgs (#168127)

Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels.

Pull Request resolved: #168127
Approved by: https://github.com/mlazos
ghstack dependencies: #167781
tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 9, 2025
tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 10, 2025
tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 10, 2025
@github-actions github-actions bot deleted the gh/karthickai/16/head branch January 5, 2026 02:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants