
Commit ae6bb58

Revert "[cutlass backend] Forward fix for less aligned gemm shapes (#148521)"
This reverts commit ad49cfc. Reverted #148521 on behalf of https://github.com/davidberard98 due to broke lint: [GH job link](https://github.com/pytorch/pytorch/actions/runs/13690720601/job/38283359447) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/ad49cfc9f0a8a4d8881b3734edd8c33a087c8b97) ([comment](#148521 (comment)))
1 parent 4dc956a commit ae6bb58

2 files changed (+42, -102 lines)

test/inductor/test_cutlass_backend.py (0 additions, 75 deletions)
```diff
@@ -9,7 +9,6 @@
 from pathlib import Path
 from typing import Callable, Optional
 
-from torch._inductor.exc import InductorError
 from torch._inductor.utils import clear_inductor_caches
 from torch.export import Dim
 from torch.testing._internal.logging_utils import log_settings
@@ -962,80 +961,6 @@ def select_no_algorithm(*args, **kwargs):
                 cuda_template_count += 1
             assert cuda_template_count > 0, "No CUDATemplateCaller choices"
 
-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
-    def test_cutlass_backend_shape_coverage_mm(
-        self,
-    ):
-        """
-        Checks if cutlass backend produces some ops for a variety of shapes.
-
-        This test doesn't compile and check the correctness of the ops.
-
-        NOTE: K has to be even.
-        """
-
-        inputs = [
-            (torch.randn(128, 500).cuda().half(), torch.randn(500, 576).cuda().half()),
-            (
-                torch.randn(500, 128).cuda().half(),
-                torch.randn(128, 576).cuda().half(),
-            ),
-            (torch.randn(128, 250).cuda().half(), torch.randn(250, 576).cuda().half()),
-            (
-                torch.randn(250, 128).cuda().half(),
-                torch.randn(128, 576).cuda().half(),
-            ),
-            (
-                torch.randn(125, 128).cuda().half(),
-                torch.randn(128, 576).cuda().half(),
-            ),
-        ]
-
-        def select_no_algorithm(*args, **kwargs):
-            raise NoValidChoicesError
-
-        with fresh_inductor_cache(), config.patch(
-            {
-                "max_autotune": True,
-                "max_autotune_gemm_backends": "CUTLASS",
-                "cuda.cutlass_max_profiling_configs": 2,
-                "autotune_fallback_to_aten": False,
-            }
-        ), mock.patch(
-            "torch._inductor.kernel.mm.autotune_select_algorithm",
-            wraps=select_no_algorithm,
-        ) as sa:
-            for input in inputs:
-                A, B = input
-                M, K = A.shape
-                _, N = B.shape
-
-                with self.assertRaises(InductorError, r".*NoValidChoicesError.*"):
-                    torch.compile(torch.mm, dynamic=False)(*input)
-
-                self.assertTrue(
-                    sa.called,
-                    f"autotune_select_algorithm was not called with shape M={M}, N={N}, K={K}",
-                )
-                args, _ = sa.call_args
-                op_name, choices, _, __ = args
-                assert op_name == "mm"
-                cuda_template_count = 0
-                for choice in choices:
-                    if isinstance(choice, CUDATemplateCaller):
-                        choice_info = choice.info_dict()
-                        op_conf_name = choice_info.get("op_conf_name", "")
-                        assert isinstance(op_conf_name, str)
-                        cuda_template_count += 1
-
-                self.assertGreater(
-                    cuda_template_count,
-                    0,
-                    "No CUDATemplateCaller choices found for matmul with shape "
-                    f"M={M}, N={N}, K={K}",
-                )
-
     @unittest.skipIf(not SM80OrLater, "need sm_80")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_get_max_alignment(self):
```
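For context, the deleted test forced Inductor to autotune over CUTLASS GEMM templates only and intercepted `autotune_select_algorithm` to count `CUDATemplateCaller` choices for each shape. A minimal sketch of the same configuration, using only config keys that appear in the diff above (assumes a CUDA build with the CUTLASS backend available and, per the test's skip condition, an sm_90 GPU):

```python
# Sketch only: reproduces the config the deleted test used, without the
# mock-based choice counting. All config keys come from the diff above.
import torch
from torch._inductor import config

a = torch.randn(125, 128, device="cuda", dtype=torch.half)  # a "less aligned" shape
b = torch.randn(128, 576, device="cuda", dtype=torch.half)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",  # CUTLASS only, no ATen/Triton
        "cuda.cutlass_max_profiling_configs": 2,
    }
):
    compiled_mm = torch.compile(torch.mm, dynamic=False)
    out = compiled_mm(a, b)  # autotunes over CUTLASS gemm templates only
```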

torch/_inductor/codegen/cuda/gemm_template.py (42 additions, 27 deletions)
```diff
@@ -841,16 +841,6 @@ def filter_op(
         # Set epilogue.
         # TODO: update epilogue functor according to epilogues.
         op.element_epilogue = op.accumulator_type()
-
-        # Set bias layout and alignment.
-        status = self._set_bias_layout_and_alignment(op)
-        if not status:
-            log.debug(
-                "Skipping due to bias layout and alignment setting failure. op: %s", op
-            )
-            return None
-
-        # Apply regex filters at the end when configuration name doesn't change anymore
         if inductor_cuda_config.cutlass_op_allowlist_regex is not None:
             if not re.search(
                 inductor_cuda_config.cutlass_op_allowlist_regex, op.configuration_name()
@@ -862,6 +852,14 @@ def filter_op(
             ):
                 return None
 
+        # Set bias layout and alignment.
+        status = self._set_bias_layout_and_alignment(op)
+        if not status:
+            log.debug(
+                "Skipping due to bias layout and alignment setting failure. op: %s", op
+            )
+            return None
+
         return op
 
     def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]":  # type: ignore[name-defined]  # noqa: F821
```
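These two hunks restore the original ordering inside `filter_op`: after the revert, the regex allowlist/denylist filters run first and the bias layout/alignment check runs afterwards, undoing the forward fix's "regex filters at the end" reordering. A standalone sketch of the restored control flow; the function and parameter names are illustrative stand-ins, not the real `CUTLASSGemmTemplate` API:

```python
# Illustrative sketch of the post-revert filtering order. `set_bias_layout`
# stands in for CUTLASSGemmTemplate._set_bias_layout_and_alignment.
import re
from typing import Callable, Optional

def filter_op_sketch(
    config_name: str,
    allowlist_regex: Optional[str],
    denylist_regex: Optional[str],
    set_bias_layout: Callable[[str], bool],
) -> Optional[str]:
    # 1. Regex filters on the op's configuration name run first.
    if allowlist_regex is not None and not re.search(allowlist_regex, config_name):
        return None
    if denylist_regex is not None and re.search(denylist_regex, config_name):
        return None
    # 2. Then try to set bias layout/alignment; drop the op on failure.
    if not set_bias_layout(config_name):
        return None
    return config_name

# Example: an op that passes the regex filters but fails alignment is dropped.
assert filter_op_sketch("gemm_f16_128x128", None, None, lambda _: False) is None
```

The diff continues in the same file with `_set_bias_layout_and_alignment` and the re-added `_dtype_match` override: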
```diff
@@ -1214,29 +1212,46 @@ def _set_bias_layout_and_alignment(
         self,
         op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
     ) -> bool:
-        import cutlass_library.library as cutlass_lib
-
         has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
         if has_bias:
-            Bias = self.input_nodes[2]
-            # bias dtype
-            op.C.element = cutlass_utils.torch_dtype_to_cutlass_type(
-                Bias.get_layout().dtype
-            )
-            assert op.C.element == op.D.element, (
-                f"Expect C and D to have the same dtype, found {op.C.element} and {op.D.element}"
-            )
-
-            # Bias layout
-            bias_layout = CUTLASSGemmTemplate.cutlass_layout(Bias.get_layout())
+            bias = self.input_nodes[2]
+            bias_layout = CUTLASSGemmTemplate.cutlass_layout(bias.get_layout())
             op.C.layout = bias_layout
-
-            # Bias alignment
-            status = self.set_alignment(Bias.get_layout(), op.C)
+            status = self.set_alignment(bias.get_layout(), op.C)
             if not status:
                 return False
+        return True
+
+    def _dtype_match(
+        self,
+        op: "cutlass_library.gemm_op.GemmOperation",  # type: ignore[name-defined]  # noqa: F821
+    ) -> bool:
+        """
+        Checking dtypes of C (i.e. bias) here, since that is the one not checked in the base class.
+        """
+
+        if not super()._dtype_match(op):
+            return False
+
+        assert cutlass_utils.try_import_cutlass()
+        from cutlass_library.library import DataType  # type: ignore[import]
+
+        has_bias = len(self.input_nodes) >= 3 and self.input_nodes[2] is not None
+
+        if op.C.element == DataType.void:
+            if has_bias:
+                # op expects no bias, but bias exists
+                return False
         else:
-            op.C.element = cutlass_lib.DataType.void
+            # op expects bias. Needs to check if bias exists and is of the right dtype
+            if not (
+                has_bias
+                and cutlass_utils.dtype_match(
+                    self.input_nodes[2].get_dtype(), op.C.element
+                )
+            ):
+                return False
+
         return True
 
     def _define_gemm_instance(
```
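The re-added `_dtype_match` override encodes a simple matching rule for the C operand (the bias): `DataType.void` means the op takes no bias, so it must only match bias-free calls; any other C element type requires a bias whose dtype matches exactly. A self-contained sketch of that rule, with a stand-in enum instead of the real `cutlass_library` types:

```python
# Hedged sketch of the restored C-operand (bias) matching rule. DataType here
# is a stand-in for cutlass_library.library.DataType.
from enum import Enum
from typing import Optional

class DataType(Enum):
    void = 0  # op takes no bias
    f16 = 1
    f32 = 2

def bias_dtype_matches(op_c_element: DataType, bias_dtype: Optional[DataType]) -> bool:
    if op_c_element is DataType.void:
        # Op expects no bias: reject if the graph supplies one.
        return bias_dtype is None
    # Op expects a bias: it must exist and match C's element type exactly.
    return bias_dtype == op_c_element

assert bias_dtype_matches(DataType.void, None)               # no-bias op, no bias: ok
assert not bias_dtype_matches(DataType.void, DataType.f16)   # no-bias op, bias given: reject
assert bias_dtype_matches(DataType.f16, DataType.f16)        # dtype agrees: ok
assert not bias_dtype_matches(DataType.f16, DataType.f32)    # dtype mismatch: reject
```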
