Skip to content

Commit b002365

Browse files
committed
Silently bypass bmm autotuning with out_dtype argument
ghstack-source-id: b37202c Pull Request resolved: #166457
1 parent fea819e commit b002365

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

torch/_inductor/kernel/bmm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,8 @@ def _to_dtype(x):
239239
templates_to_use.append(aten_handler)
240240
kwarg_overrides[aten_handler.uid] = aten_extra_kwargs
241241

242-
if use_triton_template(layout, check_max_autotune=False):
242+
if use_triton_template(layout, check_max_autotune=False) and (out_dtype is None or out_dtype == mat1.get_dtype()):
243243
# TODO: add out_dtype support for Triton Template
244-
assert out_dtype is None, "out_dtype is not supported for Triton"
245244
templates_to_use.append(bmm_template)
246245

247246
# Single unified call for all templates

0 commit comments

Comments
 (0)