@@ -92,8 +92,6 @@ def test_max_autotune_cutlass_threshold(self):
         if torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -141,8 +139,6 @@ def test_max_autotune_precompile(self):
         if torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -170,7 +166,6 @@ def test_aoti_rerun_with_different_shapes(self):
         Compile with one shape, then re-run with different input shapes
         """
         max_autotune_gemm_backends = "CUTLASS"
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

         class MyModel(torch.nn.Module):
             def forward(self, a, b):
@@ -216,7 +211,6 @@ def forward(self, a, b):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_diff_matmul_share_same_kernel(self, dynamic):
         max_autotune_gemm_backends = "CUTLASS"
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False

         class MyModel(torch.nn.Module):
             def __init__(self):
@@ -267,8 +261,6 @@ def test_max_autotune_cutlass_backend_regular_mm(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -312,8 +304,6 @@ def test_max_autotune_cutlass_backend_regular_mm_streamk(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return a @ b

@@ -356,16 +346,11 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion(
         self,
         dynamic: bool = False,
         max_autotune_gemm_backends: str = "CUTLASS",
-        mixed_precision=False,
         fp16=True,
         expected_fuse_count=0,
         mm: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
         batch_size: Optional[int] = None,
     ):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (
-            mixed_precision
-        )
-
         # Note: The ops that are available
         # also depend on the alignment of the shapes
         # so if these shapes don't all align to at least 8 elements
@@ -400,36 +385,14 @@ def _test_max_autotune_cutlass_backend_epilogue_fusion(
         ), f"Expected fuse count of {expected_fuse_count} but got {actual_count}"
         torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)

-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_simple_fusion_fp16(self):
-        def mm(a, b):
-            return (a @ b) * 3.0
-
-        # The pointwise ops seem to be pre-fused into a single Pointwise
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @unittest.skipIf(torch.version.hip, "HIP not supported")
     def test_max_autotune_cutlass_backend_simple_fusion_fp16_fp32acc(self):
         def mm(a, b):
             return (a @ b) * 3.0

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_chained_fusion_fp16(self):
-        def mm(a, b):
-            return (a @ b) * 3.3 - 1.234
-
-        # The pointwise ops seem to be pre-fused into a single Pointwise
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -439,17 +402,7 @@ def mm(a, b):
             return (a @ b) * 3.3 - 1.234

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
-        )
-
-    @unittest.skipIf(not SM90OrLater, "need sm_90")
-    @unittest.skipIf(torch.version.hip, "HIP not supported")
-    def test_max_autotune_cutlass_backend_relu_fusion_fp16(self):
-        def mm(a, b):
-            return torch.nn.functional.relu((a @ b) * 3.3 - 1.234)
-
-        self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=False, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -460,7 +413,7 @@ def mm(a, b):

         # The pointwise ops seem to be pre-fused into a single Pointwise
         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -471,7 +424,7 @@ def mm(a, b):

         # The pointwise ops seem to be pre-fused into a single Pointwise
         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     @unittest.skipIf(not SM90OrLater, "need sm_90")
@@ -482,15 +435,14 @@ def mm(a, b):
             return (a @ b).to(torch.float32) * 0.00001

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     def test_max_autotune_cutlass_backend_simple_bmm(self):
         def bmm(a, b):
             return torch.bmm(a, b)

         self._test_max_autotune_cutlass_backend_epilogue_fusion(  # test bmm
-            mixed_precision=False,
             fp16=True,
             expected_fuse_count=0,
             mm=bmm,
@@ -504,7 +456,7 @@ def mm(a, b):
             return (a @ b) / b.size(1)

         self._test_max_autotune_cutlass_backend_epilogue_fusion(
-            mixed_precision=True, fp16=True, expected_fuse_count=0, mm=mm
+            fp16=True, expected_fuse_count=0, mm=mm
         )

     # TODO: Enable dynamic test cases when dynamic support is added.
@@ -522,8 +474,6 @@ def test_max_autotune_cutlass_backend_mm_bias(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b, bias):
             return torch.nn.functional.linear(a, b, bias)

@@ -558,8 +508,6 @@ def test_max_autotune_cutlass_backend_addmm(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def addmm(x, a, b, alpha, beta):
             return torch.addmm(x, a, b, alpha=alpha, beta=beta)

@@ -597,8 +545,6 @@ def compare_results(

     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_addmm_with_expanded_bias(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 bias = torch.zeros(
@@ -671,8 +617,6 @@ def mm(a, b):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_force_cutlass_backend_aoti_dynamic(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 return x @ w
@@ -709,8 +653,6 @@ def forward(self, x, w):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_force_cutlass_backend_aoti_cexpr_codegen(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 x0, x1 = x.shape
@@ -752,8 +694,6 @@ def forward(self, x, w):
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     def test_aoti_workspace_ptr(self):
-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         class MyModel(torch.nn.Module):
             def forward(self, x, w):
                 return x @ w
@@ -798,8 +738,6 @@ def test_max_autotune_cutlass_backend_mixed_mm(
         if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
             return

-        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
-
         def mm(a, b):
             return torch.mm(a, b.to(torch.half))

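For reference, a minimal sketch of the global switch these tests toggle: torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction controls whether fp16 matmuls may use reduced-precision reductions (when False, reductions accumulate in full fp32 precision). The helper name, shapes, and save/restore pattern below are illustrative assumptions, not code from this file.

import torch

def matmul_with_fp32_accumulation(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: disable reduced-precision reductions for one matmul,
    # then restore the previous global setting so it does not leak elsewhere.
    prev = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
    try:
        return a @ b
    finally:
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = prev

if __name__ == "__main__":
    if torch.cuda.is_available():
        a = torch.randn(64, 64, device="cuda", dtype=torch.half)
        b = torch.randn(64, 64, device="cuda", dtype=torch.half)
        print(matmul_with_fp32_accumulation(a, b).shape)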