[Inductor] Fix combo kernels by populating constants for equal_to_1 args#168127
[Inductor] Fix combo kernels by populating constants for equal_to_1 args#168127karthickai wants to merge 19 commits intogh/karthickai/16/basefrom
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/168127
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure, 5 Unrelated FailuresAs of commit 747aadb with merge base e770c95 ( NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but was present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Generated triton kernel before and after fix: before fix: constant is missing for import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.pointwise(
size_hints={'x': 1024}, tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'load_seed_offset': 'i32', 'load_seed_offset1': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'RoundRobinComboKernelGrid', 'combo_grid_meta': {'num_kernels': 2, 'min_blocks': 0, 'default_config': None, 'no_x_dim_0': False, 'xnumel_0': 1024, 'no_x_dim_1': False, 'xnumel_1': 1024}, 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'backend_hash': '07C4B3116EC6B0BD20166279782DB98EA71861B79334EBBC8CCB3D36A1E5D7F2', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
@triton.jit
def triton_poi_fused_0(in_ptr0, out_ptr0, out_ptr1, load_seed_offset, load_seed_offset1, XBLOCK : tl.constexpr):
pid = tl.program_id(0)
if pid % 2 == 0:
pid_offset = pid // 2
xnumel = 1024
r0_numel = 1
xoffset = pid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + load_seed_offset)
tmp1 = x0
tmp2 = tl.rand(tmp0, (tmp1).to(tl.uint32))
tl.store(out_ptr0 + (x0), tmp2, xmask)
elif pid % 2 == 1:
pid_offset = pid // 2
xnumel = 1024
r0_numel = 1
xoffset = pid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = xindex
tmp3 = tl.load(in_ptr0 + load_seed_offset1)
tmp4 = x1
tmp5 = tl.rand(tmp3, (tmp4).to(tl.uint32))
tl.store(out_ptr1 + (x1), tmp5, xmask)
else:
passAfter fix constant populated properly import triton
import triton.language as tl
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.pointwise(
size_hints={'x': 1024}, tile_hint=TileHint.DEFAULT,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*i64', 'out_ptr0': '*fp32', 'out_ptr1': '*fp32', 'load_seed_offset': 'i32', 'load_seed_offset1': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'load_seed_offset1': 1}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'RoundRobinComboKernelGrid', 'combo_grid_meta': {'num_kernels': 2, 'min_blocks': 0, 'default_config': None, 'no_x_dim_0': False, 'xnumel_0': 1024, 'no_x_dim_1': False, 'xnumel_1': 1024}, 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'backend_hash': '07C4B3116EC6B0BD20166279782DB98EA71861B79334EBBC8CCB3D36A1E5D7F2', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False}
)
@triton.jit
def triton_poi_fused_0(in_ptr0, out_ptr0, out_ptr1, load_seed_offset, load_seed_offset1, XBLOCK : tl.constexpr):
pid = tl.program_id(0)
if pid % 2 == 0:
pid_offset = pid // 2
xnumel = 1024
r0_numel = 1
xoffset = pid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + load_seed_offset)
tmp1 = x0
tmp2 = tl.rand(tmp0, (tmp1).to(tl.uint32))
tl.store(out_ptr0 + (x0), tmp2, xmask)
elif pid % 2 == 1:
pid_offset = pid // 2
xnumel = 1024
r0_numel = 1
xoffset = pid_offset * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
x1 = xindex
tmp3 = tl.load(in_ptr0 + load_seed_offset1)
tmp4 = x1
tmp5 = tl.rand(tmp3, (tmp4).to(tl.uint32))
tl.store(out_ptr1 + (x1), tmp5, xmask)
else:
pass |
| "constants": {}, | ||
| } | ||
|
|
||
| for arg_num in equal_1_arg_indices(signature): |
There was a problem hiding this comment.
this seems super specialized, are there other values the constants can have?
There was a problem hiding this comment.
yes, constants can have any value, but only equal_to_1 needs special handling because triton autodetects it and requires it in the constant dict to prevent segfaults this exact pattern is used in regular triton codegen https://github.com/pytorch/pytorch/blob/0d7ba9714ac77b2b4a446a9eff913a6ff9dfc782/torch/_inductor/codegen/triton.py#L5169C1-L5177C1 but was missing in combo kernels, so I added it.
mlazos
left a comment
There was a problem hiding this comment.
Looks good, please address my one question though
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78 [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78 [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78 [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78 [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78 [ghstack-poisoned]
…qual_to_1 args" Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy kadeng muchulee8 amjames chauhang aakhundov coconutruben jataylo mlazos chenyang78 [ghstack-poisoned]
|
@pytorchbot merge -i |
Merge startedYour change will be merged while ignoring the following 6 checks: pull / linux-jammy-py3.14-clang12 / test (default, 2, 5, linux.4xlarge), pull / linux-jammy-py3.14-clang12 / test (default, 4, 5, linux.4xlarge), pull / linux-jammy-py3.14-clang12 / test (default, 3, 5, linux.4xlarge), pull / linux-jammy-py3.14-clang12 / test (default, 5, 5, linux.4xlarge), inductor / inductor-test / test (inductor_torchbench, 2, 2, linux.g5.4xlarge.nvidia.gpu), trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 5, 5, linux.g6.4xlarge.experimental.nvidia.gpu) Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…rgs (pytorch#168127) Fixes: pytorch#168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. Pull Request resolved: pytorch#168127 Approved by: https://github.com/mlazos ghstack dependencies: pytorch#167781
…rgs (#168127) Fixes: #168124 This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the `constants` for args marked as compile-time constants, matching the behavior of regular Triton kernels. Pull Request resolved: #168127 Approved by: https://github.com/mlazos ghstack dependencies: #167781
ghstack-source-id: 5c05640 Pull Request resolved: pytorch/pytorch#168127
ghstack-source-id: 262716f Pull Request resolved: pytorch/pytorch#168127
ghstack-source-id: 60a60ce Pull Request resolved: pytorch/pytorch#168127
Stack from ghstack (oldest at bottom):
Fixes: #168124
This PR fixes triton compilation failures in combo kernels when combining multiple kernels with random ops (or any ops that creates args with value equal to 1). The fix adds the missing logic to populate the
constantsfor args marked as compile-time constants, matching the behavior of regular Triton kernels.cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @mlazos @chenyang78