1919from .common import (
2020 BracesBuffer ,
2121 CppWrapperKernelArgs ,
22+ CSEVariable ,
2223 DeferredIndentedBuffer ,
2324 ExprPrinter ,
2425 IndentedBuffer ,
@@ -901,8 +902,47 @@ def __init__(self, args, num_threads):
901902 self .fast_vec_list .append (k )
902903 self .exit_stack = contextlib .ExitStack ()
903904
905+ # Cache all the load results
906+ self .load_results : list [CSEVariable ] = []
907+ self .load_supported_dtypes : list [torch .dtype ] = [
908+ torch .float ,
909+ torch .float32 ,
910+ torch .bool ,
911+ torch .uint8 ,
912+ torch .long ,
913+ ]
914+ self .store_supported_dtypes : list [torch .dtype ] = [torch .float , torch .float32 ]
915+ # Cache the dtypes of the store operations. If the stores mix dtypes,
916+ # vectorization is not supported because it is hard to determine the vec dtype
917+ self .store_dtypes : list [torch .dtype ] = []
918+ # The dtype is used for vectorization
919+ self .vec_dtype : torch .dtype = torch .float32
920+
def decide_vec_dtype(self):
    """Pick the dtype to vectorize with and return it.

    When exactly one entry has been recorded in ``self.store_dtypes``, that
    entry becomes ``self.vec_dtype``. For any other number of entries
    (zero, or more than one) ``self.vec_dtype`` is left as it was.

    Returns:
        The (possibly updated) value of ``self.vec_dtype``.
    """
    recorded = self.store_dtypes
    if len(recorded) != 1:
        # Nothing recorded, or several entries: keep the current choice.
        return self.vec_dtype

    self.vec_dtype = recorded[0]
    return self.vec_dtype
927+
def is_indirect_indexing(self, index: sympy.Expr):
    """Tell whether ``index`` refers to a value produced by a recorded load.

    Each entry of ``self.load_results`` is turned into a symbol from its
    ``name`` and looked up inside ``index``; one hit is enough.

    Args:
        index: The index expression to inspect.

    Returns:
        ``True`` if the expression contains the symbol of at least one
        recorded load result, ``False`` otherwise.
    """
    return any(
        index.count(sympy_symbol(loaded.name)) > 0 for loaded in self.load_results
    )
934+
def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr):
    """Decide whether accessing ``index`` along loop variable ``var`` is legal.

    The access is rejected when ``self.is_indirect_indexing(index)`` reports
    that the index depends on a loaded value. Otherwise it is accepted when
    ``var`` is either irrelevant to ``index`` or a single-step variable of
    it, as reported by ``self.is_var_irrevelant`` and
    ``self.is_single_step_var``.

    Args:
        var: The loop variable being checked.
        index: The index expression of the access.

    Returns:
        ``True`` if the access is legal, ``False`` otherwise.
    """
    # The flag was previously negated before being tested, which rejected
    # every direct access and let indirect ones through.
    if self.is_indirect_indexing(index):
        return False

    loop_var_irrelevant = self.is_var_irrevelant(var, index)
    single_step = self.is_single_step_var(var, index)
    return bool(single_step or loop_var_irrelevant)
906946
907947 def could_vec (self , name : str , index : sympy .Expr ):
908948 assert self .itervars is not None
@@ -914,21 +954,22 @@ def could_vec(self, name: str, index: sympy.Expr):
914954 return self .is_legal_data_access (most_inner_var , index )
915955
def load(self, name: str, index: sympy.Expr):
    """Record a load of buffer ``name`` and update ``self.simd_vec``.

    A fresh variable is obtained from ``self.cse.newvar()`` and appended to
    ``self.load_results`` on every call. ``self.simd_vec`` is then cleared
    when the dtype reported by ``V.graph.get_dtype(name)`` is not in
    ``self.load_supported_dtypes``; otherwise the index is renamed through
    ``self.rename_indexing`` and ``self.simd_vec`` is and-ed with
    ``self.could_vec(name, index)``.

    Args:
        name: Name of the buffer being loaded.
        index: Index expression of the load.

    Returns:
        The newly created variable standing for the load result.
    """
    loaded = self.cse.newvar()
    self.load_results.append(loaded)

    if V.graph.get_dtype(name) in self.load_supported_dtypes:
        index = self.rename_indexing(index)
        # Keep the short-circuit: could_vec is only consulted while
        # simd_vec is still truthy.
        self.simd_vec = self.simd_vec and self.could_vec(name, index)
    else:
        self.simd_vec = False

    return loaded
929967
930968 def store (self , name , index , value , mode = None ):
931- if not V .graph .get_dtype (name ) in [torch .float , torch .float32 ]:
969+ store_dtype = V .graph .get_dtype (name )
970+ store_dtype = torch .float if store_dtype == torch .float32 else store_dtype
971+ self .store_dtypes .append (store_dtype )
972+ if store_dtype not in [torch .float , torch .float32 ]:
932973 self .simd_vec = False
933974 return self .simd_vec
934975
0 commit comments