Commit fc5612a

[inductor] don't try to reorder loops for template (#166910)
[inductor] don't try to reorder loops for template (#165601)

fix #165579

Pull Request resolved: #165601
Approved by: https://github.com/yushangdi

(cherry picked from commit a303d6d)

Co-authored-by: Shunting Zhang <shunting@fb.com>
1 parent d29deef commit fc5612a

File tree

2 files changed: +31 -0 lines


test/inductor/test_loop_ordering.py

Lines changed: 25 additions & 0 deletions
@@ -592,6 +592,31 @@ def f(x, y):
             ".run(", 1 + int(inductor_config.benchmark_kernel), exactly=True
         ).run(code[0])
 
+    @inductor_config.patch(
+        {
+            "max_autotune": True,
+            "max_autotune_gemm_backends": "TRITON",
+            "test_configs.max_mm_configs": 4,
+        }
+    )
+    @skipUnless(HAS_GPU and is_big_gpu(), "Need big gpu for max-autotune")
+    def test_interaction_with_multi_template(self):
+        """
+        Skip MultiTemplateBuffer during loop reordering
+        """
+
+        @torch.compile
+        def f(x, y):
+            return (x @ y), x + 1
+
+        N = 2
+        x = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
+        y = torch.randn([N, N], device=GPU_TYPE, dtype=torch.bfloat16)
+
+        out, code = run_and_get_code(f, x, y)
+        # didn't fuse due to small savings
+        FileCheck().check_count("@triton.jit", 2, exactly=True).run(code[0])
+
     def test_fuse_with_scalar_shared_memory(self):
         """
         Make sure if we can fuse two nodes sharing a scalar before,

torch/_inductor/scheduler.py

Lines changed: 6 additions & 0 deletions
@@ -3953,6 +3953,12 @@ def shared_data_after_reordering_loop(
     ):
         return -1
 
+    # in some rare case, a template can be passed in.
+    # Check test_interaction_with_multi_template in test_loop_ordering.py
+    # and https://github.com/pytorch/pytorch/issues/165579
+    if node1.is_template() or node2.is_template():
+        return -1
+
     node1_buffer_names = node1.read_writes.buffer_names()
     node2_buffer_names = node2.read_writes.buffer_names()
     # Fast path: no common buffers.
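
For context, here is a minimal sketch of the pattern this guard protects against, adapted from the new test above. It is illustrative only: the literal "cuda" device string and the standalone config values are assumptions for a self-contained script (the actual test runs under the repo's harness with GPU_TYPE and also caps test_configs.max_mm_configs). Under max-autotune, the matmul lowers to a (multi-)template GEMM kernel, and the loop-reordering pass could previously receive that template node and fail (#165579).

# Minimal sketch (assumptions: a CUDA GPU with Triton available; not the
# exact test from this commit).
import torch
import torch._inductor.config as inductor_config

@torch.compile
def f(x, y):
    # x @ y lowers to a template GEMM kernel under max-autotune, while
    # x + 1 is a plain pointwise kernel; the scheduler considers reordering
    # loops to fuse the two, which is where the new guard applies.
    return (x @ y), x + 1

x = torch.randn(2, 2, device="cuda", dtype=torch.bfloat16)
y = torch.randn(2, 2, device="cuda", dtype=torch.bfloat16)

# Compilation happens on the first call, so patch the config around it.
with inductor_config.patch(
    {"max_autotune": True, "max_autotune_gemm_backends": "TRITON"}
):
    out_mm, out_add = f(x, y)

With the guard in place, shared_data_after_reordering_loop returns -1 when either node is a template, so the scheduler skips loop reordering for templates rather than attempting it and crashing.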
