@@ -62,11 +62,30 @@ def test_addmm(self):
6262 def foo (input , weight , bias ):
6363 return torch .addmm (bias , input , weight )
6464
65- foo (
65+ inps = (
6666 torch .randn (20 , 33 , device = "cuda" ),
6767 torch .randn (33 , 16 , device = "cuda" ),
6868 torch .randn (20 , 16 , device = "cuda" ),
6969 )
70+
71+ foo (* inps )
72+ # Autotuning checks correctness of each version
73+ self .assertEqual (counters ["inductor" ]["select_algorithm_autotune" ], 1 )
74+
75+ @patch .object (select_algorithm , "VERIFY" , dict (atol = 5e-2 , rtol = 5e-2 ))
76+ @patches
77+ def test_addmm_fp16 (self ):
78+ @torch .compile
79+ def foo (input , weight , bias ):
80+ return torch .addmm (bias , input , weight )
81+
82+ inps = (
83+ torch .randn (2 , 320 , device = "cuda" , dtype = torch .half ),
84+ torch .randn (320 , 320 , device = "cuda" , dtype = torch .half ).t (),
85+ torch .empty (320 , device = "cuda" , dtype = torch .half ),
86+ )
87+
88+ foo (* inps )
7089 # Autotuning checks correctness of each version
7190 self .assertEqual (counters ["inductor" ]["select_algorithm_autotune" ], 1 )
7291
0 commit comments