
Commit 7c8ec84

henrylhtsang authored and pytorchmergebot committed
[cutlass backend] fix bug for accumulator dtype (#146356)
Will add unit tests for accuracy.
Pull Request resolved: #146356
Approved by: https://github.com/Chillee
1 parent 13e17aa commit 7c8ec84

2 files changed: +6, -73 lines


test/inductor/test_cutlass_backend.py

Lines changed: 6 additions & 68 deletions
@@ -92,8 +92,6 @@ def test_max_autotune_cutlass_threshold(self):
         if torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -141,8 +139,6 @@ def test_max_autotune_precompile(self):
         if torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -170,7 +166,6 @@ def test_aoti_rerun_with_different_shapes(self):
         Compile with one shape, then re-run with different input shapes
         """
         max_autotune_gemm_backends = "CUTLASS"
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

         class MyModel(torch.nn.Module):
             def forward(self, a, b):
@@ -216,7 +211,6 @@ def forward(self, a, b):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_diff_matmul_share_same_kernel(self, dynamic):
         max_autotune_gemm_backends = "CUTLASS"
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

         class MyModel(torch.nn.Module):
             def __init__(self):
@@ -267,8 +261,6 @@ def test_max_autotune_cutlass_backend_regular_mm(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -312,8 +304,6 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -356,16 +346,11 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion(
         self,
         dynamic: bool = False,
         max_autotune_gemm_backends: str = "CUTLASS",
-        mixed_precision=False,
         fp16=True,
         expected_fuse_count=0,
         mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
         batch_size: Optional[int] = None,
     ):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
-            mixed_precision
-        )
-
         # Note: The ops that are available
         # also depend on the alignment of the shapes
         # so if these shapes don't all align to at least 8 elements
@@ -400,36 +385,14 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion(
         ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
         torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
-        def mm(a, b):
-            return (a @ b) * 3.0
-
-        # The pointwise ops seem to be pre-fused into a single Pointwise
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(torch.version.hip, "HIP not supported")
     def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
         def mm(a, b):
             return (a @ b) * 3.0

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
-        def mm(a, b):
-            return (a @ b) * 3.3 - 1.234
-
-        # The pointwise ops seem to be pre-fused into a single Pointwise
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -439,17 +402,7 @@ def mm(a, b):
             return (a @ b) * 3.3 - 1.234

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
-        def mm(a, b):
-            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)
-
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -460,7 +413,7 @@ def mm(a, b):

         # The pointwise ops seem to be pre-fused into a single Pointwise
         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -471,7 +424,7 @@ def mm(a, b):

         # The pointwise ops seem to be pre-fused into a single Pointwise
         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -482,15 +435,14 @@ def mm(a, b):
             return (a @ b).to(torch.float32) * 0.00001

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     def test_max_autotune_cutlass_backend_simple_bmm(self):
         def bmm(a, b):
             return torch.bmm(a, b)

         self._test_max_autotune_cutlass_backend_epilogue_fusion(  # test bmm
-            mixed_precision=False,
             fp16=True,
             expected_fuse_count=0,
             mm=bmm,
@@ -504,7 +456,7 @@ def mm(a, b):
             return (a @ b) / b.size(1)

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     # TODO: Enable dynamic test cases when dynamic support is added.
@@ -522,8 +474,6 @@ def test_max_autotune_cutlass_backend_mm_bias(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b, bias):
             return torch.nn.functional.linear(a, b, bias)

@@ -558,8 +508,6 @@ def test_max_autotune_cutlass_backend_addmm(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def addmm(x, a, b, alpha, beta):
             return torch.addmm(x, a, b, alpha=alpha, beta=beta)

@@ -597,8 +545,6 @@ def compare_results(

     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_addmm_with_expanded_bias(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 bias = torch.zeros(
@@ -671,8 +617,6 @@ def mm(a, b):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_force_cutlass_backend_aoti_dynamic(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 return x @ w
@@ -709,8 +653,6 @@ def forward(self, x, w):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_force_cutlass_backend_aoti_cexpr_codegen(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 x0, x1 = x.shape
@@ -752,8 +694,6 @@ def forward(self, x, w):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_aoti_workspace_ptr(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 return x @ w
@@ -798,8 +738,6 @@ def test_max_autotune_cutlass_backend_mixed_mm(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return torch.mm(a, b.to(torch.half))

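The test-side changes above only remove the allow_fp16_reduced_precision_reduction toggles and the mixed_precision knob of the fusion helper. For orientation, here is a minimal sketch of the compile-and-compare pattern those tests follow. It is not code from this commit: it assumes an sm_90 GPU with CUTLASS available, and it assumes the Inductor config keys max_autotune and max_autotune_gemm_backends are the ones the test file patches elsewhere (the diff only shows the backend name as a local variable).

# Hedged sketch: compile a small fp16 matmul with the CUTLASS backend and
# compare it against eager mode, mirroring the tests touched above.
import torch
from torch._inductor import config


def mm(a, b):
    return a @ b


a = torch.randn(128, 64, dtype=torch.half, device="cuda")
b = torch.randn(64, 128, dtype=torch.half, device="cuda")

# Assumption: these config keys match what the surrounding test file patches.
with config.patch({"max_autotune": True, "max_autotune_gemm_backends": "CUTLASS"}):
    Y_compiled = torch.compile(mm)(a, b)

# Same tolerances as the assertion kept in the diff above.
torch.testing.assert_close(Y_compiled, mm(a, b), atol=1e-2, rtol=1e-2)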
torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 0 additions & 5 deletions
@@ -279,11 +279,6 @@ def get_accumulator_dtype(
     ]:
         torch_dtype = dtype0

-    if torch_dtype == torch.half:
-        if torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction:
-            return torch_dtype
-        else:
-            return torch.float
    if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
        return torch.float
    if torch_dtype == torch.int8:
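With the five lines above removed, get_accumulator_dtype no longer consults torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction. A rough sketch of the selection logic that remains, based only on the lines visible in this hunk (the int8 branch is cut off here, so its result is left out):

# Hedged sketch of the post-change behavior implied by the hunk above:
# half, bfloat16, and float32 inputs all accumulate in float32.
import torch


def accumulator_dtype_sketch(torch_dtype: torch.dtype) -> torch.dtype:
    if torch_dtype in (torch.float16, torch.bfloat16, torch.float):
        return torch.float
    # The real helper also handles torch.int8, but its return value is not
    # visible in this diff, so the sketch does not guess it.
    raise NotImplementedError(f"not covered by this sketch: {torch_dtype}")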

Comments (0)