[inductor] Fix argmin/argmax returning incorrect indices for non-contiguous tensor #165983
karthickai wants to merge 9 commits into gh/karthickai/9/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165983
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 8a2efb5 with merge base 75b8295.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
kept_idx.append(i)
kept_sizes.append(size[i])

# For argmax/argmin compute logical indices when the tensor has non-contiguous layout.
Oh, I have a question: why do we only need to compute logical indices for argmax/argmin?
Because argmin/argmax are the only reductions that return indices rather than values.
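To make the logical-vs-physical distinction concrete, here is a small eager-mode illustration (assumed values; the behavior noted in the comment follows the bug description in the PR summary):

import torch

x = torch.zeros(6, 4)
x[1, 2] = -1.0          # physical storage offset: 1*4 + 2 = 6
xt = x.t()              # the minimum sits at xt[2, 1]
print(xt.argmin())      # logical row-major index into xt: 2*6 + 1 = 13
# Per the bug report, the Inductor kernel used to return the physical
# storage offset (6) here instead of the logical index (13) that eager returns.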
Triton codegen for persistent_reduction and non-persistent reduction after the fix:

@torch.compile(backend="inductor")
def fn(x):
    return x.t().argmin()

x = torch.randn(6, 4, device="cuda")

persistent_reduction:

@triton_heuristics.persistent_reduction(
size_hints={'x': 1, 'r0_': 32},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_per_fused_argmin_t_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': None, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'F6B370E2BDBF74E52F78AF3377EFBDEAC9745509E518CAEC7B275C7E0DFA3C8B', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 96}}
)
@triton.jit
def triton_per_fused_argmin_t_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 24
    R0_BLOCK: tl.constexpr = 32
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_index = tl.arange(0, R0_BLOCK)[None, :]
    r0_offset = 0
    r0_mask = r0_index < r0_numel
    roffset = r0_offset
    rindex = r0_index
    r0_2 = r0_index
    r0_0 = (r0_index % 4)
    r0_1 = r0_index // 4
    tmp0 = tl.load(in_ptr0 + (r0_2), r0_mask, other=0.0)
    tmp1 = r0_1 + 6*r0_0
    tmp2 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
    tmp3 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
    tmp5 = tl.where(r0_mask, tmp2, float("inf"))
    tmp4_val, tmp4_idx = triton_helpers.min_with_index(tmp5, (tmp3).to(tl.int32), 1)
    tmp4 = tmp4_idx[:, None]
    tmp4 = tmp4.to(tl.float32)
    tl.store(out_ptr0 + (tl.full([1, 1], 0, tl.int32).broadcast_to(XBLOCK, 1)), tmp4, None)

non-persistent reduction (input: torch.randn(128, 64, device="cuda")):

@triton_heuristics.reduction(
size_hints={'x': 1, 'r0_': 8192},
reduction_hint=ReductionHint.INNER,
filename=__file__,
triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*i64', 'xnumel': 'constexpr', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=132, cc=90, major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {'xnumel': 1}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_argmin_t_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 1, 'num_store': 1, 'num_reduction': 1, 'backend_hash': 'F6B370E2BDBF74E52F78AF3377EFBDEAC9745509E518CAEC7B275C7E0DFA3C8B', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'r0_': 32768}}
)
@triton.jit
def triton_red_fused_argmin_t_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
    xnumel = 1
    r0_numel = 8192
    rnumel = r0_numel
    RBLOCK: tl.constexpr = R0_BLOCK
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = tl.full([XBLOCK, R0_BLOCK], True, tl.int1)
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    rbase = r0_base
    _tmp4 = tl.full([XBLOCK, R0_BLOCK], float("inf"), tl.float32)
    _tmp4_index = tl.full([XBLOCK, R0_BLOCK], 2147483647, tl.int32)
    for r0_offset in range(0, r0_numel, R0_BLOCK):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        roffset = r0_offset
        rindex = r0_index
        r0_2 = r0_index
        r0_0 = (r0_index % 64)
        r0_1 = r0_index // 64
        tmp0 = tl.load(in_ptr0 + (r0_2), r0_mask, eviction_policy='evict_first', other=0.0)
        tmp1 = r0_1 + 128*r0_0
        tmp2 = tl.broadcast_to(tmp0, [XBLOCK, R0_BLOCK])
        tmp3 = tl.broadcast_to(tmp1, [XBLOCK, R0_BLOCK])
        _tmp4_next, _tmp4_index_next = triton_helpers.minimum_with_index(
            _tmp4, _tmp4_index, tmp2, (tmp3).to(tl.int32)
        )
        _tmp4 = tl.where(r0_mask, _tmp4_next, _tmp4)
        _tmp4_index = tl.where(r0_mask, _tmp4_index_next, _tmp4_index)
    tmp4_val, tmp4_idx = triton_helpers.min_with_index(_tmp4, _tmp4_index, 1)
    tmp4 = tmp4_idx[:, None]
    tl.store(out_ptr0 + (tl.full([1, 1], 0, tl.int32).broadcast_to(XBLOCK, 1)), tmp4, None)
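The fix is visible in both kernels as tmp1 = r0_1 + 6*r0_0 (and r0_1 + 128*r0_0 in the second kernel): the position in the underlying contiguous storage is remapped to the logical row-major index of the transposed view before being fed to min_with_index. A quick plain-Python sketch of that remapping for the (6, 4) example, just to check the arithmetic:

import torch

x = torch.randn(6, 4)                  # contiguous, strides (4, 1)
xt = x.t()                             # shape (4, 6), strides (1, 4)

for r0_index in range(x.numel()):      # walk the storage in memory order
    r0_0 = r0_index % 4                # column of x == row of xt
    r0_1 = r0_index // 4               # row of x    == column of xt
    logical = r0_1 + 6 * r0_0          # row-major index into xt
    assert xt.reshape(-1)[logical] == x.reshape(-1)[r0_index]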
if isinstance(x.data, PermuteView):
    should_compute_logical_index = True
elif isinstance(x.data, ir.ReinterpretView):
    layout = x.get_layout()
    should_compute_logical_index = (
        layout.is_transposed() or not layout.is_contiguous()
    )
This does not cover the case where the function input is already in column-major layout.
Thanks for the catch! I've added an `isinstance(x.data, ir.StorageBox)` check to handle column-major tensors passed as function inputs.
Also added a test case to verify column-major inputs.
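A hypothetical sketch of such a check (not the exact test added in this PR), feeding a column-major graph input straight into the compiled function:

import torch

@torch.compile(backend="inductor")
def fn(x):
    return x.argmin(), x.argmax()

# Column-major layout: shape (6, 4) with strides (1, 6), passed as graph input.
x = torch.randn(4, 6, device="cuda").t()
compiled_min, compiled_max = fn(x)
assert compiled_min == x.argmin() and compiled_max == x.argmax()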
permuted = x.permute(2, 0, 1)
return (permuted.argmin(), permuted.argmax())

self.common(fn, (torch.randn(4, 6, 8, device=GPU_TYPE),))
To be complete, can we also test the case where the layout of the tensor contains a 'gap', e.g.:
torch.randn(10, 20)[:, :10]
I've added a test case for sliced tensors with gaps in memory.
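Again hypothetical (not the exact test in the PR), the same pattern as the column-major sketch above but with the suggested slicing, which leaves a gap between rows in memory:

import torch

@torch.compile(backend="inductor")
def fn(x):
    return x.argmin(), x.argmax()

# Shape (10, 10) with strides (20, 1): each row skips 10 unused elements.
x = torch.randn(10, 20, device="cuda")[:, :10]
compiled_min, compiled_max = fn(x)
assert compiled_min == x.argmin() and compiled_max == x.argmax()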
torch/_inductor/lowering.py (outdated)

isinstance(x.data, ir.StorageBox)
and isinstance(x.data.data, ir.InputBuffer)
Hmm, it works for graph inputs now but still does not work for internally generated buffers (e.g. ComputedBuffer).
Here are a few choices:
- either change `isinstance(x.data.data, ir.InputBuffer)` to `isinstance(x.data.data, ir.Buffer)`
- or always use the logical index regardless of the layout. The logical index should still work even if the input for argmin/max is contiguous, right?
You're correct. I've changed it to `ir.Buffer`. Yes, using the logical index for contiguous tensors would work, but I feel it's less efficient.
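On the efficiency point: for a contiguous row-major tensor the logical row-major index and the physical storage offset coincide, so the extra index arithmetic emitted by the logical-index path buys nothing there. A small sketch checking that coincidence (illustration only, not code from the PR):

import torch

x = torch.randn(128, 64)                    # contiguous: strides (64, 1)
for flat in (0, 100, 8191):
    i, j = divmod(flat, 64)                 # logical 2-D position
    physical = i * x.stride(0) + j * x.stride(1)
    # For contiguous inputs the remap is a no-op, hence keeping the fast path.
    assert physical == flat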
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Fixes #163929
Fixes argmin/argmax operations to return correct logical indices instead of physical memory offsets when applied to transposed/permuted tensors. When
`argmin()` or `argmax()` is called on a transposed tensor, Inductor was returning physical memory indices instead of logical row-major indices. This caused incorrect results that don't match eager mode behavior.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @mlazos