
[Inductor] Fix combo kernels variable collision with ND tiled reductions #168946

Closed
karthickai wants to merge 17 commits into gh/karthickai/21/base from gh/karthickai/21/head

Conversation

@karthickai (Collaborator) commented Nov 24, 2025

Stack from ghstack (oldest at bottom):

Fixes: #168945

Fixes a combo kernels crash when fusing an op with multi-dimensional reductions (ND tiling). When a sub-kernel had more than one reduction dimension, the generated Triton code reused the same block-size constexpr name and failed to compile:

```python
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        RBLOCK_0: tl.constexpr = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        RBLOCK_0: tl.constexpr = 8
        ^
ValueError('RBLOCK_0 is already defined. constexpr cannot be reassigned.')
```
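
A minimal repro along these lines exercises the failing path (a hypothetical sketch: the shapes and the `var_mean` op are inferred from the generated kernels in this description, and `combo_kernels` / `triton.tile_reductions` are assumed to be the relevant Inductor config switches; it is not the exact reproducer from #168945):

```python
import torch
import torch._inductor.config as inductor_config

# Assumed flags: enable combo kernels (horizontal fusion) and ND tiling of reductions.
inductor_config.combo_kernels = True
inductor_config.triton.tile_reductions = True

def f(x):
    # var_mean over two dims yields a multi-dimensional (ND tiled) reduction,
    # matching the xnumel=4 / r0_numel=2 / r1_numel=8 kernel in this PR.
    return torch.var_mean(x, dim=(0, 2))

x = torch.randn(2, 4, 8, device="cuda")
var, mean = torch.compile(f)(x)  # previously died with "RBLOCK_0 is already defined"
```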

Root cause: the block-size constexpr name was derived from the sub-kernel index alone rather than from the reduction dimension prefix, so every reduction dim in the same sub-kernel produced the same name (`RBLOCK_0` above).
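
The shape of the fix, as a small sketch (hypothetical helper names; the actual change lives in Inductor's combo kernel codegen, not in these functions):

```python
# Before: the name depended only on the sub-kernel index, so every reduction
# dim of sub-kernel 0 mapped to the same constexpr "RBLOCK_0".
def block_size_name_before(sub_kernel_idx: int) -> str:
    return f"RBLOCK_{sub_kernel_idx}"

# After: the reduction-dim prefix ("r0_", "r1_", ...) is part of the name,
# so each dim gets its own constexpr ("R0_BLOCK_0", "R1_BLOCK_0", ...).
def block_size_name_after(dim_prefix: str, sub_kernel_idx: int) -> str:
    return f"{dim_prefix.rstrip('_').upper()}_BLOCK_{sub_kernel_idx}"

assert {block_size_name_before(0) for _ in ("r0_", "r1_")} == {"RBLOCK_0"}
assert {block_size_name_after(p, 0) for p in ("r0_", "r1_")} == {"R0_BLOCK_0", "R1_BLOCK_0"}
```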

Generated Triton code after the fix:

```python
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.persistent_reduction(
    size_hints={'x': 4, 'r0_': 2, 'r1_': 8},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'RoundRobinComboKernelGrid', 'combo_grid_meta': {'num_kernels': 1, 'min_blocks': 0, 'default_config': None, 'no_x_dim_0': None, 'xnumel_0': 4}, 'kernel_name': 'triton_per_fused_1', 'mutated_arg_names': ['in_out_ptr0'], 'backend_hash': '07C4B3116EC6B0BD20166279782DB98EA71861B79334EBBC8CCB3D36A1E5D7F2', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
@triton.jit
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        R1_BLOCK_0: tl.constexpr = 8
        rnumel = r0_numel * r1_numel
        RBLOCK: tl.constexpr = R0_BLOCK_0*R1_BLOCK_0
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
        xmask = xindex < xnumel
        r0_index = tl.arange(0, R0_BLOCK_0)[None, :, None]
        r0_offset = 0
        r0_mask = r0_index < r0_numel
        r1_index = tl.arange(0, R1_BLOCK_0)[None, None, :]
        r1_offset = 0
        r1_mask = r1_index < r1_numel
        roffset = r1_offset + r0_offset*r1_numel
        rindex = r1_index + r0_index*r1_numel
        r0_1 = r0_index
        r1_2 = r1_index
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (r1_2 + 8*x0 + 32*r0_1), r0_mask & r1_mask & xmask, other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp3 = tl.where(r0_mask & r1_mask & xmask, tmp1, 0)
        tmp4 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp6 = tl.where(r0_mask & r1_mask & xmask, tmp4, 0)
        tmp7 = tl.reshape(tmp6, [XBLOCK, RBLOCK])
        tmp8 = tl.sum(tmp7, 1)[:, None, None].to(tl.float32)
        tmp9 = tl.full([1, 1, 1], 16, tl.int32)
        tmp10 = tmp9.to(tl.float32)
        tmp11 = (tmp8 / tmp10)
        tmp12 = tmp1 - tmp11
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp16 = tl.where(r0_mask & r1_mask & xmask, tmp14, 0)
        tmp17 = tl.reshape(tmp16, [XBLOCK, RBLOCK])
        tmp18 = tl.sum(tmp17, 1)[:, None, None].to(tl.float32)
        tmp19 = 15.0
        tmp20 = (tmp18 / tmp19)
        tl.debug_barrier()
        tl.store(in_out_ptr0 + (x0), tmp20, xmask)
        tl.store(out_ptr0 + (x0), tmp11, xmask)
    else:
        pass
```

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @mlazos @chenyang78

pytorch-bot bot commented Nov 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/168946

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit e2067b3 with merge base a208ed2:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

)
triton.jit
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        R1_BLOCK_0: tl.constexpr = 8
        rnumel = r0_numel * r1_numel
        RBLOCK: tl.constexpr = R0_BLOCK_0*R1_BLOCK_0
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
        xmask = xindex < xnumel
        r0_index = tl.arange(0, R0_BLOCK_0)[None, :, None]
        r0_offset = 0
        r0_mask = r0_index < r0_numel
        r1_index = tl.arange(0, R1_BLOCK_0)[None, None, :]
        r1_offset = 0
        r1_mask = r1_index < r1_numel
        roffset = r1_offset + r0_offset*r1_numel
        rindex = r1_index + r0_index*r1_numel
        r0_1 = r0_index
        r1_2 = r1_index
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (r1_2 + 8*x0 + 32*r0_1), r0_mask & r1_mask & xmask, other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp3 = tl.where(r0_mask & r1_mask & xmask, tmp1, 0)
        tmp4 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp6 = tl.where(r0_mask & r1_mask & xmask, tmp4, 0)
        tmp7 = tl.reshape(tmp6, [XBLOCK, RBLOCK])
        tmp8 = tl.sum(tmp7, 1)[:, None, None].to(tl.float32)
        tmp9 = tl.full([1, 1, 1], 16, tl.int32)
        tmp10 = tmp9.to(tl.float32)
        tmp11 = (tmp8 / tmp10)
        tmp12 = tmp1 - tmp11
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp16 = tl.where(r0_mask & r1_mask & xmask, tmp14, 0)
        tmp17 = tl.reshape(tmp16, [XBLOCK, RBLOCK])
        tmp18 = tl.sum(tmp17, 1)[:, None, None].to(tl.float32)
        tmp19 = 15.0
        tmp20 = (tmp18 / tmp19)
        tl.debug_barrier()
        tl.store(in_out_ptr0 + (x0), tmp20, xmask)
        tl.store(out_ptr0 + (x0), tmp11, xmask)
    else:
        pass
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…led reductions"


Fixes: #168945

Fix combo kernels crash when fusing op with multi-dim reductions (ND tiling). It caused variable name collisions in generated triton code when multiple reduction dimensions existed.

```python
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        RBLOCK_0: tl.constexpr = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        RBLOCK_0: tl.constexpr = 8
        ^
ValueError('RBLOCK_0 is already defined. constexpr cannot be reassigned.')
```

root cause: block size variable generation used sub-kernel index instead of reduction dim prefix, causing collisions when multiple reduction dims existed in the same sub-kernel.

after fix generated triton code:
```python
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

triton_heuristics.persistent_reduction(
    size_hints={'x': 4, 'r0_': 2, 'r1_': 8},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'RoundRobinComboKernelGrid', 'combo_grid_meta': {'num_kernels': 1, 'min_blocks': 0, 'default_config': None, 'no_x_dim_0': None, 'xnumel_0': 4}, 'kernel_name': 'triton_per_fused_1', 'mutated_arg_names': ['in_out_ptr0'], 'backend_hash': '07C4B3116EC6B0BD20166279782DB98EA71861B79334EBBC8CCB3D36A1E5D7F2', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
triton.jit
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        R1_BLOCK_0: tl.constexpr = 8
        rnumel = r0_numel * r1_numel
        RBLOCK: tl.constexpr = R0_BLOCK_0*R1_BLOCK_0
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
        xmask = xindex < xnumel
        r0_index = tl.arange(0, R0_BLOCK_0)[None, :, None]
        r0_offset = 0
        r0_mask = r0_index < r0_numel
        r1_index = tl.arange(0, R1_BLOCK_0)[None, None, :]
        r1_offset = 0
        r1_mask = r1_index < r1_numel
        roffset = r1_offset + r0_offset*r1_numel
        rindex = r1_index + r0_index*r1_numel
        r0_1 = r0_index
        r1_2 = r1_index
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (r1_2 + 8*x0 + 32*r0_1), r0_mask & r1_mask & xmask, other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp3 = tl.where(r0_mask & r1_mask & xmask, tmp1, 0)
        tmp4 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp6 = tl.where(r0_mask & r1_mask & xmask, tmp4, 0)
        tmp7 = tl.reshape(tmp6, [XBLOCK, RBLOCK])
        tmp8 = tl.sum(tmp7, 1)[:, None, None].to(tl.float32)
        tmp9 = tl.full([1, 1, 1], 16, tl.int32)
        tmp10 = tmp9.to(tl.float32)
        tmp11 = (tmp8 / tmp10)
        tmp12 = tmp1 - tmp11
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp16 = tl.where(r0_mask & r1_mask & xmask, tmp14, 0)
        tmp17 = tl.reshape(tmp16, [XBLOCK, RBLOCK])
        tmp18 = tl.sum(tmp17, 1)[:, None, None].to(tl.float32)
        tmp19 = 15.0
        tmp20 = (tmp18 / tmp19)
        tl.debug_barrier()
        tl.store(in_out_ptr0 + (x0), tmp20, xmask)
        tl.store(out_ptr0 + (x0), tmp11, xmask)
    else:
        pass
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…led reductions"


Fixes: #168945

Fix combo kernels crash when fusing op with multi-dim reductions (ND tiling). It caused variable name collisions in generated triton code when multiple reduction dimensions existed.

```python
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        RBLOCK_0: tl.constexpr = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        RBLOCK_0: tl.constexpr = 8
        ^
ValueError('RBLOCK_0 is already defined. constexpr cannot be reassigned.')
```

root cause: block size variable generation used sub-kernel index instead of reduction dim prefix, causing collisions when multiple reduction dims existed in the same sub-kernel.

after fix generated triton code:
```python
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

triton_heuristics.persistent_reduction(
    size_hints={'x': 4, 'r0_': 2, 'r1_': 8},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'RoundRobinComboKernelGrid', 'combo_grid_meta': {'num_kernels': 1, 'min_blocks': 0, 'default_config': None, 'no_x_dim_0': None, 'xnumel_0': 4}, 'kernel_name': 'triton_per_fused_1', 'mutated_arg_names': ['in_out_ptr0'], 'backend_hash': '07C4B3116EC6B0BD20166279782DB98EA71861B79334EBBC8CCB3D36A1E5D7F2', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
triton.jit
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        R1_BLOCK_0: tl.constexpr = 8
        rnumel = r0_numel * r1_numel
        RBLOCK: tl.constexpr = R0_BLOCK_0*R1_BLOCK_0
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
        xmask = xindex < xnumel
        r0_index = tl.arange(0, R0_BLOCK_0)[None, :, None]
        r0_offset = 0
        r0_mask = r0_index < r0_numel
        r1_index = tl.arange(0, R1_BLOCK_0)[None, None, :]
        r1_offset = 0
        r1_mask = r1_index < r1_numel
        roffset = r1_offset + r0_offset*r1_numel
        rindex = r1_index + r0_index*r1_numel
        r0_1 = r0_index
        r1_2 = r1_index
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (r1_2 + 8*x0 + 32*r0_1), r0_mask & r1_mask & xmask, other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp3 = tl.where(r0_mask & r1_mask & xmask, tmp1, 0)
        tmp4 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp6 = tl.where(r0_mask & r1_mask & xmask, tmp4, 0)
        tmp7 = tl.reshape(tmp6, [XBLOCK, RBLOCK])
        tmp8 = tl.sum(tmp7, 1)[:, None, None].to(tl.float32)
        tmp9 = tl.full([1, 1, 1], 16, tl.int32)
        tmp10 = tmp9.to(tl.float32)
        tmp11 = (tmp8 / tmp10)
        tmp12 = tmp1 - tmp11
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp16 = tl.where(r0_mask & r1_mask & xmask, tmp14, 0)
        tmp17 = tl.reshape(tmp16, [XBLOCK, RBLOCK])
        tmp18 = tl.sum(tmp17, 1)[:, None, None].to(tl.float32)
        tmp19 = 15.0
        tmp20 = (tmp18 / tmp19)
        tl.debug_barrier()
        tl.store(in_out_ptr0 + (x0), tmp20, xmask)
        tl.store(out_ptr0 + (x0), tmp11, xmask)
    else:
        pass
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
…led reductions"


Fixes: #168945

Fix combo kernels crash when fusing op with multi-dim reductions (ND tiling). It caused variable name collisions in generated triton code when multiple reduction dimensions existed.

```python
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        RBLOCK_0: tl.constexpr = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        RBLOCK_0: tl.constexpr = 8
        ^
ValueError('RBLOCK_0 is already defined. constexpr cannot be reassigned.')
```

root cause: block size variable generation used sub-kernel index instead of reduction dim prefix, causing collisions when multiple reduction dims existed in the same sub-kernel.

after fix generated triton code:
```python
import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

triton_heuristics.persistent_reduction(
    size_hints={'x': 4, 'r0_': 2, 'r1_': 8},
    reduction_hint=ReductionHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'RoundRobinComboKernelGrid', 'combo_grid_meta': {'num_kernels': 1, 'min_blocks': 0, 'default_config': None, 'no_x_dim_0': None, 'xnumel_0': 4}, 'kernel_name': 'triton_per_fused_1', 'mutated_arg_names': ['in_out_ptr0'], 'backend_hash': '07C4B3116EC6B0BD20166279782DB98EA71861B79334EBBC8CCB3D36A1E5D7F2', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
triton.jit
def triton_per_fused_1(in_out_ptr0, in_ptr0, out_ptr0, XBLOCK : tl.constexpr):
    pid = tl.program_id(0)
    if pid % 1 == 0:
        pid_offset = pid // 1
        xnumel = 4
        r0_numel = 2
        R0_BLOCK_0: tl.constexpr = 2
        r1_numel = 8
        R1_BLOCK_0: tl.constexpr = 8
        rnumel = r0_numel * r1_numel
        RBLOCK: tl.constexpr = R0_BLOCK_0*R1_BLOCK_0
        xoffset = pid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
        xmask = xindex < xnumel
        r0_index = tl.arange(0, R0_BLOCK_0)[None, :, None]
        r0_offset = 0
        r0_mask = r0_index < r0_numel
        r1_index = tl.arange(0, R1_BLOCK_0)[None, None, :]
        r1_offset = 0
        r1_mask = r1_index < r1_numel
        roffset = r1_offset + r0_offset*r1_numel
        rindex = r1_index + r0_index*r1_numel
        r0_1 = r0_index
        r1_2 = r1_index
        x0 = xindex
        tmp0 = tl.load(in_ptr0 + (r1_2 + 8*x0 + 32*r0_1), r0_mask & r1_mask & xmask, other=0.0)
        tmp1 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp3 = tl.where(r0_mask & r1_mask & xmask, tmp1, 0)
        tmp4 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp6 = tl.where(r0_mask & r1_mask & xmask, tmp4, 0)
        tmp7 = tl.reshape(tmp6, [XBLOCK, RBLOCK])
        tmp8 = tl.sum(tmp7, 1)[:, None, None].to(tl.float32)
        tmp9 = tl.full([1, 1, 1], 16, tl.int32)
        tmp10 = tmp9.to(tl.float32)
        tmp11 = (tmp8 / tmp10)
        tmp12 = tmp1 - tmp11
        tmp13 = tmp12 * tmp12
        tmp14 = tl.broadcast_to(tmp13, [XBLOCK, R0_BLOCK_0, R1_BLOCK_0])
        tmp16 = tl.where(r0_mask & r1_mask & xmask, tmp14, 0)
        tmp17 = tl.reshape(tmp16, [XBLOCK, RBLOCK])
        tmp18 = tl.sum(tmp17, 1)[:, None, None].to(tl.float32)
        tmp19 = 15.0
        tmp20 = (tmp18 / tmp19)
        tl.debug_barrier()
        tl.store(in_out_ptr0 + (x0), tmp20, xmask)
        tl.store(out_ptr0 + (x0), tmp11, xmask)
    else:
        pass
```

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
@karthickai
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Dec 9, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 10, 2025
skpark-rh pushed a commit to skpark-rh/pytorch that referenced this pull request Dec 10, 2025
Pull Request resolved: pytorch#168946
Approved by: https://github.com/eellison
@github-actions github-actions bot deleted the gh/karthickai/21/head branch January 9, 2026 02:20