
[Inductor] Fix unbacked float symbol handling in kernel codegen#166890

Closed
karthickai wants to merge 4 commits into gh/karthickai/10/base from gh/karthickai/10/head

Conversation

@karthickai
Collaborator

@karthickai karthickai commented Nov 3, 2025

Stack from ghstack (oldest at bottom):

When a function compiled with `torch.compile` calls `.item()` on a float tensor argument (e.g., for thresholds in `torch.clamp`), the generated Triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.
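
A minimal repro sketch of the failing pattern (shapes, values, and the config toggle below are illustrative assumptions, not taken verbatim from the linked issues):

import torch

# May be needed so Dynamo traces .item() into an unbacked float symbol
# instead of graph-breaking (depends on your version's default config).
torch._dynamo.config.capture_scalar_outputs = True

def fn(x, max_t):
    # .item() on a float tensor produces an unbacked float (e.g. zuf0)
    return torch.clamp(x, min=0.0, max=max_t.item())

x = torch.randn(6000, device="cuda")
max_t = torch.tensor(0.5, device="cuda")
out = torch.compile(fn)(x, max_t)  # previously raised a Triton compilation error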

Fixes: #166888 #163674

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos

@pytorch-bot

pytorch-bot bot commented Nov 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166890

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 477af36 with merge base ad7a572:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

karthickai added a commit that referenced this pull request Nov 3, 2025
@karthickai karthickai added the release notes: inductor and ciflow/trunk (Trigger trunk jobs on your pull request) labels Nov 3, 2025
@karthickai
Collaborator Author

Generated Triton code
Before the fix (`zuf0` is undefined):

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8192}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clamp_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '1CF1A6CE7D6F88FF171C94282FF6FF5D221A856237DFCA75F009E863091F8BA8', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 72000}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clamp_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.0
    tmp2 = triton_helpers.maximum(tmp0, tmp1)
    tmp3 = zuf0
    tmp4 = tmp3.to(tl.float32)
    tmp5 = triton_helpers.minimum(tmp2, tmp4)
    tl.store(out_ptr0 + (x0), tmp5, xmask)

After the fix (the unbacked float is now passed as the kernel argument `ks0`):

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 8192}, 
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'ks0': 'fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_clamp_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '1CF1A6CE7D6F88FF171C94282FF6FF5D221A856237DFCA75F009E863091F8BA8', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': False, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'x': 72000}},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_clamp_0(in_ptr0, out_ptr0, ks0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6000
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = 0.0
    tmp2 = triton_helpers.maximum(tmp0, tmp1)
    tmp3 = ks0
    tmp4 = tmp3.to(tl.float32)
    tmp5 = triton_helpers.minimum(tmp2, tmp4)
    tl.store(out_ptr0 + (x0), tmp5, xmask)
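
Conceptually, the host side now threads the unbacked float through as a scalar kernel argument. A standalone sketch of the same pattern in plain Triton (not the Inductor-generated wrapper; the kernel name, grid, and block size are illustrative):

import torch
import triton
import triton.language as tl

@triton.jit
def clamp_kernel(in_ptr, out_ptr, ks0, xnumel, XBLOCK: tl.constexpr):
    # ks0 is a runtime fp32 scalar argument, mirroring how the fixed codegen
    # passes the unbacked float instead of referencing an undefined symbol.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)
    xmask = xindex < xnumel
    x = tl.load(in_ptr + xindex, mask=xmask)
    y = tl.minimum(tl.maximum(x, 0.0), ks0)
    tl.store(out_ptr + xindex, y, mask=xmask)

x = torch.randn(6000, device="cuda")
out = torch.empty_like(x)
max_val = torch.tensor(0.5, device="cuda").item()  # the unbacked float, now a host-side Python float
grid = (triton.cdiv(x.numel(), 256),)
clamp_kernel[grid](x, out, max_val, x.numel(), XBLOCK=256)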

@karthickai karthickai requested a review from eellison November 3, 2025 21:02
…degen"


When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.

Fixes: #166888

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Nov 3, 2025
Contributor

@eellison eellison left a comment


Nice!

@karthickai
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@malfet
Contributor

malfet commented Nov 5, 2025

@pytorchbot revert -m "Looks like it broke torchfuzz tests, see https://hud.pytorch.org/hud/pytorch/pytorch/fbd70fb84e347b45db79eb24cc2c53e447a04147/1?per_page=50&name_filter=trunk%20%2F%20linux-jammy-cuda12&mergeEphemeralLF=true and same test on slow" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot
Collaborator

@karthickai your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Nov 5, 2025
@karthickai
Collaborator Author

@malfet thanks for the info! I ran the test locally and it is passing. I'll rebase and submit again.

(unbacked-pytorch) [karthickps@devvm5699.eag0 ~/unbacked-pytorch (e3be39f4)]$ python test/test_torchfuzz_repros.py -k test_fuzzer_issue_163674
Eager Success! ✅
Compile Success!.
----------------------------------------------------------------------
Ran 1 test in 43.806s

OK

…degen"


When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.

Fixes: #166888

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Nov 5, 2025
@karthickai
Collaborator Author

karthickai commented Nov 6, 2025

I figured out why the job failed: test_fuzzer_issue_163674 (issue #163674) is an expected failure because `zuf0` is undefined:

def triton_poi_fused_fill_pow_view_zero_0(out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 238464
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = zuf0
           ^
NameError('zuf0 is not defined')

As I mentioned in #166888, unbacked floats were not handled in Inductor codegen. This PR solves the problem that caused that failure. I already added a test case, test_unbacked_float_item, to cover it, so I am removing test_fuzzer_issue_163674.
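
For reference, a hedged sketch of what such a test can look like (only the name test_unbacked_float_item comes from this PR; the body, device, and helper usage are assumptions):

import torch
from torch._inductor.utils import run_and_get_code

def test_unbacked_float_item():
    # Assumed repro: clamp threshold produced by .item() on a float tensor.
    torch._dynamo.config.capture_scalar_outputs = True

    def fn(x, max_t):
        return torch.clamp(x, min=0.0, max=max_t.item())

    x = torch.randn(6000, device="cuda")
    max_t = torch.tensor(0.5, device="cuda")
    compiled = torch.compile(fn)
    # run_and_get_code returns the result plus the generated kernel sources,
    # so the test can check both numerics and the emitted code.
    result, code_list = run_and_get_code(compiled, x, max_t)
    torch.testing.assert_close(result, fn(x, max_t))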

…degen"


When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error.

Fixes: #166888 #163674

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
@karthickai
Collaborator Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
Pull Request resolved: pytorch#166890
Approved by: https://github.com/eellison, https://github.com/mlazos
@github-actions github-actions bot deleted the gh/karthickai/10/head branch December 7, 2025 02:21