 from pathlib import Path
 from typing import Dict, List
 
+import numpy
 import sympy
 
 import torch
+import torch.fx
 from torch._prims_common import is_float_dtype
 
 from .. import codecache, config, ir, metrics
 from .common import (
     BracesBuffer,
     CppWrapperKernelArgs,
+    CSEVariable,
     DeferredIndentedBuffer,
     ExprPrinter,
     IndentedBuffer,
@@ -231,6 +234,34 @@ def erf(x):
     def sqrt(x):
         return f"{x}.sqrt()"
 
+    @staticmethod
+    def eq(x, y):
+        return f"{x} == {y}"
+
+    @staticmethod
+    def ne(x, y):
+        return f"{x} != {y}"
+
+    @staticmethod
+    def lt(x, y):
+        return f"{x} < {y}"
+
+    @staticmethod
+    def gt(x, y):
+        return f"{x} > {y}"
+
+    @staticmethod
+    def le(x, y):
+        return f"{x} <= {y}"
+
+    @staticmethod
+    def ge(x, y):
+        return f"{x} >= {y}"
+
+    @staticmethod
+    def and_(x, y):
+        return f"{x} & {y}"
+
     @staticmethod
     def rsqrt(x):
         return f"{x}.rsqrt()"
@@ -285,17 +316,19 @@ def reciprocal(a):
 
     @staticmethod
     def constant(val, dtype):
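+        # Prefer the dtype that the vectorization checker recorded on the
+        # current FX node (see the checker's constant() below) over the
+        # argument dtype, e.g. int64 demoted to int32 when the value fits.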
+        proposed_dtype = V.interpreter.current_node.meta["dtype"]
         if val == float("inf"):
-            quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
+            quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
         elif val == float("-inf"):
-            quote = f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
+            quote = f"-std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
         elif math.isnan(val):
-            quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()"
+            quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::quiet_NaN()"
         elif val is True or val is False:
-            quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({str(val).lower()})"
+            quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({str(val).lower()})"
         else:
-            quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({repr(val)})"
-        return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({quote})"
+            quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({repr(val)})"
+
+        return f"at::vec::Vectorized<{DTYPE_TO_CPP[proposed_dtype]}>({quote})"
 
     @staticmethod
     def relu(x):
@@ -370,6 +403,40 @@ def expm1(x):
     def log1p(x):
         return f"{x}.log1p()"
 
+    @staticmethod
+    def masked(mask, body, other):
+        assert V.interpreter.current_node.meta["is_masked_load"]
+        code = BracesBuffer()
+
+        var = V.kernel.cse.newvar()
+        if other == float("-inf"):
+            code.writeline(
+                f"auto {var} = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());"
+            )
+        elif other == float("inf"):
+            code.writeline(
+                f"auto {var} = at::vec::Vectorized<float>(std::numeric_limits<float>::infinity());"
+            )
+        elif isinstance(other, float):
+            code.writeline(f"auto {var} = at::vec::Vectorized<float>({other});")
+        else:
+            code.writeline(f"auto {var} = at::vec::Vectorized<float>({other!r});")
+        with V.kernel.swap_buffers(code), code.indent():
+            result = body()
+            zero_val = "at::vec::Vectorized<float>(0)"
+            float_mask = f"flag_to_float_vec({mask})"
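+            # blendv keeps, per lane, the body() result where the float mask is
+            # non-zero and falls back to the "other" value elsewhere.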
+            blendv = f"decltype({result})::blendv({var}, {result}, {float_mask} != {zero_val})"
+            code.writeline(f"{var} = {blendv};")
+        V.kernel.compute.splice(code)
+        return var
+
+    @staticmethod
+    def index_expr(expr, dtype):
+        assert dtype == torch.int64
+        assert V.interpreter.current_node.meta["dtype"] == torch.int32
+        assert V.interpreter.current_node.meta["most_inner_loop_irrevelant"]
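+        # The asserts above guarantee that the index is invariant w.r.t. the
+        # innermost (vectorized) loop and fits in int32, so it can be
+        # broadcast as a constant int32 vector.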
+        return f"at::vec::Vectorized<int>(static_cast<int>({cexpr(V.kernel.rename_indexing(expr))}))"
+
 
 class CppOverrides(OpOverrides):
     """Map element-wise ops to C++"""
@@ -905,8 +972,47 @@ def __init__(self, args, num_threads):
                 self.fast_vec_list.append(k)
         self.exit_stack = contextlib.ExitStack()
 
+        # Cache all the load results
+        self.load_results: List[CSEVariable] = []
+        self.load_supported_dtypes: List[torch.dtype] = [
+            torch.float,
+            torch.float32,
+            torch.bool,
+            torch.uint8,
+            torch.long,
+        ]
+        self.store_supported_dtypes: List[torch.dtype] = [torch.float, torch.float32]
+        # Cache the dtypes of the store operations. If the stores mix dtypes,
+        # vectorization is not supported, as it is hard to determine the vec dtype.
+        self.store_dtypes: List[torch.dtype] = []
+        # The dtype used for vectorization
+        self.vec_dtype: torch.dtype = torch.float32
+
+    def decide_vec_dtype(self):
+        n_store_dtypes = len(self.store_dtypes)
+        if n_store_dtypes == 1:
+            self.vec_dtype = self.store_dtypes[0]
+
+        return self.vec_dtype
+
+    def is_indirect_indexing(self, index: sympy.Expr):
+        for _load_res in self.load_results:
+            # The index expression contains a value that is loaded from memory
+            if index.count(sympy_symbol(_load_res.name)) > 0:
+                return True
+        return False
+
     def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr):
-        return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index)
+        _indirect_indexing = self.is_indirect_indexing(index)
+        if _indirect_indexing:
+            return False
+
+        _loop_var_irrevelant = self.is_var_irrevelant(var, index)
+        _single_step = self.is_single_step_var(var, index)
+        if not _single_step and not _loop_var_irrevelant:
+            return False
+
+        return True
 
     def could_vec(self, name: str, index: sympy.Expr):
         assert self.itervars is not None
@@ -918,21 +1024,40 @@ def could_vec(self, name: str, index: sympy.Expr):
         return self.is_legal_data_access(most_inner_var, index)
 
     def load(self, name: str, index: sympy.Expr):
-        if not V.graph.get_dtype(name) in [
-            torch.float,
-            torch.float32,
-            torch.bool,
-            torch.uint8,
-        ]:
+        load_type = V.graph.get_dtype(name)
+        current_node: torch.fx.Node = V.interpreter.current_node
+        current_node.meta["dtype"] = load_type
+
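+        # Record the CSE variable of every load so that is_indirect_indexing()
+        # can detect index expressions that depend on loaded values.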
+        var = self.cse.newvar()
+        self.load_results.append(var)
+
+        if load_type not in self.load_supported_dtypes:
             self.simd_vec = False
-            return self.simd_vec
+            return var
+
+        def is_mask():
+            # True only if every user of this load is a mask consumer
+            for _node in current_node.users:
+                if _node.target not in ["where", "masked"]:
+                    return False
+            return True
+
+        current_node.meta["is_mask"] = is_mask()
 
         index = self.rename_indexing(index)
         self.simd_vec = self.simd_vec and self.could_vec(name, index)
-        return self.simd_vec
+        return var
 
     def store(self, name, index, value, mode=None):
-        if not V.graph.get_dtype(name) in [torch.float, torch.float32]:
+        store_dtype = V.graph.get_dtype(name)
+
+        current_node: torch.fx.Node = V.interpreter.current_node
+        current_node.meta["dtype"] = store_dtype
+
+        store_dtype = torch.float if store_dtype == torch.float32 else store_dtype
+        self.store_dtypes.append(store_dtype)
+        if store_dtype not in [torch.float, torch.float32]:
             self.simd_vec = False
             return self.simd_vec
 
@@ -957,6 +1082,27 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
             self.simd_vec = False
         return self.simd_vec
 
+    def is_load_only_block(self, sub_graph: torch.fx.Graph):
+        # A load-only sub-graph contains nothing but "placeholder", "output",
+        # "get_index", and "load" nodes
+        is_load_only = False
+        load_dtype = None
+        skip_io_nodes = ["placeholder", "output"]
+        for _node in sub_graph.nodes:
+            if _node.op in skip_io_nodes:
+                continue
+
+            if _node.target not in ["load", "get_index"]:
+                # The body contains a node other than a load
+                is_load_only = False
+                break
+
+            if _node.target == "load":
+                _, name, _ = _node.args
+                load_dtype = V.graph.get_dtype(name)
+                is_load_only = True
+
+        return is_load_only, load_dtype
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         assert self._orig_wrapper_code is not None
         # Restore the wrapper_code
@@ -999,15 +1145,60 @@ def reduction(name, dtype, src_dtype, reduction_type, index, value):
 
             @staticmethod
             def constant(val, dtype):
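+                # Record the constant's dtype on the FX node, demoting it to a
+                # narrower dtype when the value fits (int64 -> int32,
+                # double -> float32) so the constant can take the vectorized path.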
+                current_node: torch.fx.Node = V.interpreter.current_node
+                current_node.meta["dtype"] = dtype
+                i32_iinfo = numpy.iinfo(numpy.int32)
+                if (
+                    dtype == torch.int64
+                    and val <= i32_iinfo.max
+                    and val >= i32_iinfo.min
+                ):
+                    current_node.meta["dtype"] = torch.int32
+                f32_finfo = numpy.finfo(numpy.float32)
+                if (
+                    dtype == torch.double
+                    and val <= f32_finfo.max
+                    and val >= f32_finfo.min
+                ):
+                    current_node.meta["dtype"] = torch.float32
+
                 supported_dtype = (torch.float32, torch.int32)
-                is_supported_dtype = dtype in (supported_dtype)
+                is_supported_dtype = current_node.meta["dtype"] in supported_dtype
                 if not is_supported_dtype:
                     self.simd_vec = False
                 return is_supported_dtype
 
             @staticmethod
             def index_expr(expr, dtype):
-                self.simd_vec = False
+                current_node: torch.fx.Node = V.interpreter.current_node
+
+                assert len(self.ranges) == len(self.itervars)
+                loop_range = dict(zip(self.itervars, self.ranges))
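+                # Substitute each loop variable with its range to estimate the
+                # extreme value of the index expression, then check that it
+                # fits in int32.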
+                expr_val = sympy.simplify(sympy_subs(expr, loop_range))
+                i32_iinfo = numpy.iinfo(numpy.int32)
+                if (
+                    dtype == torch.int64
+                    and expr_val <= i32_iinfo.max
+                    and expr_val >= i32_iinfo.min
+                ):
+                    current_node.meta["dtype"] = torch.int32
+                else:
+                    self.simd_vec = False
+
+                # Pick the innermost loop variable, since we always vectorize
+                # the innermost loop
+                most_inner_var = self.itervars[-1]
+                most_inner_loop_irrevelant = self.is_var_irrevelant(most_inner_var, expr)
+                if not most_inner_loop_irrevelant:
+                    self.simd_vec = False
+                current_node.meta["most_inner_loop_irrevelant"] = most_inner_loop_irrevelant
+
                 tmp_var = self.cse.newvar()
                 return tmp_var
 
@@ -1018,11 +1209,24 @@ def indirect_indexing(index_var):
 
             @staticmethod
             def masked(mask, body, other):
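+                # Vectorized masked() is supported only when the body is a pure
+                # float load block; record the decision and the loaded dtype on
+                # the FX node for the vectorized masked() override to consume.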
+                current_node: torch.fx.Node = V.interpreter.current_node
+                is_masked_load, load_dtype = self.is_load_only_block(body.graph)
+                current_node.meta["dtype"] = load_dtype
+                current_node.meta["is_masked_load"] = is_masked_load
+
+                self.simd_vec = is_masked_load and current_node.meta["dtype"] in [
+                    torch.float32,
+                    torch.float,
+                ]
+
                 tmp_var = self.cse.newvar()
                 return tmp_var
 
             @staticmethod
             def to_dtype(x, dtype):
+                current_node: torch.fx.Node = V.interpreter.current_node
+                current_node.meta["dtype"] = dtype
+
                 if dtype != torch.bool:
                     self.simd_vec = False
                 return x