We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent fea819e commit b002365Copy full SHA for b002365
torch/_inductor/kernel/bmm.py
@@ -239,9 +239,8 @@ def _to_dtype(x):
239
templates_to_use.append(aten_handler)
240
kwarg_overrides[aten_handler.uid] = aten_extra_kwargs
241
242
- if use_triton_template(layout, check_max_autotune=False):
+ if use_triton_template(layout, check_max_autotune=False) and (out_dtype is None or out_dtype == mat1.get_dtype()):
243
# TODO: add out_dtype support for Triton Template
244
- assert out_dtype is None, "out_dtype is not supported for Triton"
245
templates_to_use.append(bmm_template)
246
247
# Single unified call for all templates
0 commit comments