Skip to content

Commit a90b4f0

Browse files
authored
use 4 warps for small block config in mm (#95383)
* use 4 warps for small block config in mm * Update test/inductor/test_select_algorithm.py * Update test/inductor/test_select_algorithm.py
1 parent 1211cee commit a90b4f0

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

test/inductor/test_select_algorithm.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,30 @@ def test_addmm(self):
6262
def foo(input, weight, bias):
6363
return torch.addmm(bias, input, weight)
6464

65-
foo(
65+
inps = (
6666
torch.randn(20, 33, device="cuda"),
6767
torch.randn(33, 16, device="cuda"),
6868
torch.randn(20, 16, device="cuda"),
6969
)
70+
71+
foo(*inps)
72+
# Autotuning checks correctness of each version
73+
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
74+
75+
@patch.object(select_algorithm, "VERIFY", dict(atol=5e-2, rtol=5e-2))
76+
@patches
77+
def test_addmm_fp16(self):
78+
@torch.compile
79+
def foo(input, weight, bias):
80+
return torch.addmm(bias, input, weight)
81+
82+
inps = (
83+
torch.randn(2, 320, device="cuda", dtype=torch.half),
84+
torch.randn(320, 320, device="cuda", dtype=torch.half).t(),
85+
torch.empty(320, device="cuda", dtype=torch.half),
86+
)
87+
88+
foo(*inps)
7089
# Autotuning checks correctness of each version
7190
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
7291

torch/_inductor/kernel/mm_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def mm_configs():
4545
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 64}, num_stages=3, num_warps=8
4646
),
4747
triton.Config(
48-
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=8
48+
{"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 128}, num_stages=2, num_warps=4
4949
),
5050
triton.Config(
5151
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 16}, num_stages=2, num_warps=4

0 commit comments

Comments
 (0)