
[inductor] Fix argmin/argmax returning incorrect indices for non-contiguous tensor #165983

Closed
karthickai wants to merge 9 commits into gh/karthickai/9/base from gh/karthickai/9/head

Conversation

karthickai (Collaborator) commented Oct 21, 2025

Stack from ghstack (oldest at bottom):

Fixes #163929

Fixes argmin/argmax operations to return correct logical indices instead of physical memory offsets when applied to transposed/permuted tensors. When argmin() or argmax() is called on a transposed tensor, Inductor was returning physical memory indices instead of logical row-major indices. This caused incorrect results that don't match eager mode behavior.
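
A minimal repro of the mismatch (my own sketch, assuming a CUDA device is available; the exact reproducer in #163929 may differ):

    import torch

    def fn(x):
        return x.t().argmin()

    x = torch.randn(6, 4, device="cuda")
    eager = fn(x)                                         # logical row-major index into x.t()
    compiled = torch.compile(fn, backend="inductor")(x)
    # Before this fix, `compiled` could hold the physical storage offset instead,
    # so the two results could disagree.
    assert torch.equal(eager, compiled)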

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos

pytorch-bot bot commented Oct 21, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165983

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8a2efb5 with merge base 75b8295:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

karthickai added a commit that referenced this pull request Oct 21, 2025
…iguous tensor

ghstack-source-id: 228f710
Pull Request resolved: #165983
karthickai added the release notes: inductor and ciflow/trunk labels Oct 21, 2025
kept_idx.append(i)
kept_sizes.append(size[i])

# For argmax/argmin compute logical indices when the tensor has non-contiguous layout.
Contributor:

oh, I have a question: why do we only need to compute logical indices for argmax/argmin?

Collaborator Author:

because argmin/argmax are the only reductions that return indices rather than values
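
A quick illustration of the distinction (my own sketch, not from the PR):

    import torch

    x = torch.randn(6, 4)
    xt = x.t()                                # non-contiguous view, shape (4, 6)

    # Value reductions are layout-independent:
    assert torch.equal(xt.min(), xt.contiguous().min())

    # Index reductions must report the logical (row-major) position in xt,
    # which is generally not the physical offset into x's storage:
    logical = xt.argmin().item()              # what eager returns
    physical = x.reshape(-1).argmin().item()  # offset into the shared storage
    # logical and physical usually differ for this transposed view.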

@karthickai karthickai requested a review from eellison October 21, 2025 09:35
@karthickai karthickai marked this pull request as draft October 21, 2025 09:45
karthickai (Collaborator Author) commented:

Triton codegen for the persistent_reduction and non-persistent reduction kernels after the fix:

    import torch

    @torch.compile(backend="inductor")
    def fn(x):
        return x.t().argmin()

    x = torch.randn(6, 4, device="cuda")
    fn(x)
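
To double-check the compiled result against eager (my addition, not part of the original comment):

    assert torch.equal(fn(x), x.t().argmin())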

persistent_reduction

@triton_heuristics.persistent_reduction(
    size_hints={'x': 1, 'r0_': 32},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_argmin_t_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'F6B370E2BDBF74E52F78AF3377EFBDEAC9745509E518CAEC7B275C7E0DFA3C8B', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 96}}
)
@triton.jit
def triton_per_fused_argmin_t_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 24
    R0_BLOCK: tl.constexpr = 32
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    roffset = r0_offset
    rindex = r0_index
    r0_2 = r0_index
    r0_0 = (r0_index % 4)
    r0_1 = r0_index // 4
    tmp0 = tl.load(in_ptr0 + (r0_2), r0_mask, other=0.0)
    tmp1 = r0_1 + 6*r0_0
    tmp2 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
    tmp5 = tl.where(r0_mask, tmp2, float("inf"))
    tmp4_val, tmp4_idx = triton_helpers.min_with_index(tmp5, (tmp3).to(tl.int32), 1)
    tmp4 = tmp4_idx[:, None]
    tmp4 = tmp4.to(tl.float32)
    tl.store(out_ptr0 + (tl.full([1, 1], 0, tl.int32).broadcast_to(XBLOCK, 1)), tmp4, None)
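
The key change is tmp1 = r0_1 + 6*r0_0: the kernel still reduces over the physical storage offset r0_2, but carries the logical row-major index of the transposed (4, 6) view as the argmin payload. A quick sanity check of that mapping (my own sketch, not part of the PR):

    import torch

    x = torch.randn(6, 4)            # contiguous, strides (4, 1)
    xt = x.t()                       # shape (4, 6), strides (1, 4)

    for p in range(x.numel()):       # p = physical storage offset (r0_2 in the kernel)
        r0_0, r0_1 = p % 4, p // 4
        logical = r0_1 + 6 * r0_0    # same expression the kernel emits as tmp1
        assert xt.reshape(-1)[logical] == x.reshape(-1)[p]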

non persistent_reduction (with input x = torch.randn(128, 64, device="cuda")):

@triton_heuristics.reduction(
    size_hints={'x': 1, 'r0_': 8192},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmin_t_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'F6B370E2BDBF74E52F78AF3377EFBDEAC9745509E518CAEC7B275C7E0DFA3C8B', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 32768}}
)
@triton.jit
def triton_red_fused_argmin_t_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 8192
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], float("inf"), tl.float32)
    _tmp4_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        r0_0 = (r0_index % 64)
        r0_1 = r0_index // 64
        tmp0 = tl.load(in_ptr0 + (r0_2), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp1 = r0_1 + 128*r0_0
        tmp2 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
        _tmp4_next, _tmp4_index_next = triton_helpers.minimum_with_index(
            _tmp4, _tmp4_index, tmp2, (tmp3).to(tl.int32)
        )
        _tmp4 = tl.where(r0_mask, _tmp4_next, _tmp4)
        _tmp4_index = tl.where(r0_mask, _tmp4_index_next, _tmp4_index)
    tmp4_val, tmp4_idx = triton_helpers.min_with_index(_tmp4, _tmp4_index, 1)
    tmp4 = tmp4_idx[:, None]
    tl.store(out_ptr0 + (tl.full([1, 1], 0, tl.int32).broadcast_to(XBLOCK, 1)), tmp4, None)
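
The looped (non-persistent) variant applies the same remapping, tmp1 = r0_1 + 128*r0_0, i.e. the logical row-major index into the transposed (64, 128) view of the 128x64 input, and carries the running value/index pair across R0_BLOCK chunks via triton_helpers.minimum_with_index before the final min_with_index.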

…or non-contiguous tensor"


Fixes #163929

Fixes argmin/argmax operations to return correct logical indices instead of physical memory offsets when applied to transposed/permuted tensors.  When `argmin()` or `argmax()` is called on a transposed tensor, Inductor was returning physical memory indices instead of logical row-major indices. This caused incorrect results that don't match eager mode behavior.  


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Oct 21, 2025
…iguous tensor

ghstack-source-id: 19f6416
Pull Request resolved: #165983
…or non-contiguous tensor"


Fixes #163929

Fixes argmin/argmax operations to return correct logical indices instead of physical memory offsets when applied to transposed/permuted tensors.  When `argmin()` or `argmax()` is called on a transposed tensor, Inductor was returning physical memory indices instead of logical row-major indices. This caused incorrect results that don't match eager mode behavior.  


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Oct 23, 2025
…iguous tensor

ghstack-source-id: c26d384
Pull Request resolved: #165983
…or non-contiguous tensor"


Fixes #163929

Fixes argmin/argmax operations to return correct logical indices instead of physical memory offsets when applied to transposed/permuted tensors.  When `argmin()` or `argmax()` is called on a transposed tensor, Inductor was returning physical memory indices instead of logical row-major indices. This caused incorrect results that don't match eager mode behavior.  


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Oct 23, 2025
…iguous tensor

ghstack-source-id: 22a965b
Pull Request resolved: #165983
…or non-contiguous tensor"


Fixes #163929

Fixes argmin/argmax operations to return correct logical indices instead of physical memory offsets when applied to transposed/permuted tensors.  When `argmin()` or `argmax()` is called on a transposed tensor, Inductor was returning physical memory indices instead of logical row-major indices. This caused incorrect results that don't match eager mode behavior.  


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben mlazos

[ghstack-poisoned]
karthickai added a commit that referenced this pull request Oct 26, 2025
…iguous tensor

ghstack-source-id: 9a0677d
Pull Request resolved: #165983
@karthickai karthickai marked this pull request as ready for review October 26, 2025 00:40
@eellison eellison requested review from shunting314 and removed request for eellison October 27, 2025 17:39
Comment on lines 6101 to 6107
if isinstance(x.data, PermuteView):
    should_compute_logical_index = True
elif isinstance(x.data, ir.ReinterpretView):
    layout = x.get_layout()
    should_compute_logical_index = (
        layout.is_transposed() or not layout.is_contiguous()
    )
Contributor:

This does not cover the case when the function input is in column major layout already.

Collaborator Author:

thanks for the catch! I've added isinstance(x.data, ir.StorageBox) to handle column major tensors passed as function input.
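
For reference, a column-major graph input can be constructed like this (illustrative sketch assuming a CUDA device; the actual test added in the PR may differ):

    import torch

    # Column-major (Fortran-order) layout: shape (6, 4), strides (1, 6)
    x = torch.randn(6, 4, device="cuda").t().contiguous().t()

    def fn(t):
        return (t.argmin(), t.argmax())

    compiled = torch.compile(fn, backend="inductor")
    assert torch.equal(compiled(x)[0], fn(x)[0])
    assert torch.equal(compiled(x)[1], fn(x)[1])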


also added a test case to verify column-major inputs

    permuted = x.permute(2, 0, 1)
    return (permuted.argmin(), permuted.argmax())

self.common(fn, (torch.randn(4, 6, 8, device=GPU_TYPE),))
Contributor:

To be complete, can we also test the case where the layout of the tensor contains a 'gap', e.g.:

torch.randn(10, 20)[:, :10]

Collaborator Author:

I've added a test case for sliced tensors with gaps in memory.
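
Such a check might look roughly like this (illustrative sketch based on the example above; the actual test may differ):

    import torch

    def fn(x):
        sliced = x[:, :10]           # non-contiguous: rows are 20 elements apart in storage
        return (sliced.argmin(), sliced.argmax())

    x = torch.randn(10, 20, device="cuda")
    eager = fn(x)
    compiled = torch.compile(fn, backend="inductor")(x)
    assert torch.equal(compiled[0], eager[0]) and torch.equal(compiled[1], eager[1])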

@eellison eellison self-requested a review October 27, 2025 18:54
karthickai added a commit that referenced this pull request Oct 27, 2025
…iguous tensor

ghstack-source-id: 52b5dd6
Pull Request resolved: #165983
karthickai added a commit that referenced this pull request Oct 27, 2025
…iguous tensor

ghstack-source-id: 0fe19cc
Pull Request resolved: #165983
karthickai added a commit that referenced this pull request Oct 27, 2025
…iguous tensor

ghstack-source-id: fb1ebd9
Pull Request resolved: #165983
Comment on lines 6104 to 6105
isinstance(x.data, ir.StorageBox)
and isinstance(x.data.data, ir.InputBuffer)
Contributor:

hmm, it works for graph input now but still does not work for internally generated buffers (e.g. ComputedBuffer).

Here are a few choices:

  1. either change isinstance(x.data.data, ir.InputBuffer) to isinstance(x.data.data, ir.Buffer)
  2. or always use the logical index regardless of the layout. The logical index should still work even if the input to argmin/argmax is contiguous, right?

Collaborator Author:

You're correct. I've changed it to ir.Buffer. Yes, using the logical index for contiguous tensors will work, but I feel it's less efficient.
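
For context (my own note, not from the thread): for a contiguous tensor the physical storage offset already equals the logical row-major index, so the remapping would be pure overhead there:

    import torch

    x = torch.randn(128, 64)                        # contiguous
    assert x.reshape(-1).argmin() == x.argmin()     # physical offset == logical index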

karthickai added a commit that referenced this pull request Oct 27, 2025
…iguous tensor

ghstack-source-id: ab77ba9
Pull Request resolved: #165983
@shunting314 shunting314 self-requested a review October 28, 2025 00:43
karthickai (Collaborator Author) commented:

@pytorchbot merge

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
