Commit 90d0710

To vectorize long datatype as mask index

ghstack-source-id: 4527867
Pull Request resolved: #91076
1 parent 7669405 commit 90d0710
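
For context, here is a minimal sketch (an assumption for illustration, not code from the PR) of the kind of kernel this change targets: an int64 tensor is loaded only to be compared, so the long load effectively acts as a mask rather than as data. Before this change the CPP vectorization checker bailed out on such kernels because torch.long was not among the supported load dtypes. The function and tensor names below are hypothetical.

import torch

def masked_fill_like(x: torch.Tensor, flag: torch.Tensor) -> torch.Tensor:
    # flag is an int64 tensor; the comparison turns the long load into a
    # boolean mask that only gates which values of x are kept.
    return torch.where(flag > 0, x, torch.zeros_like(x))

# Illustrative usage with the Inductor-backed compiler.
compiled = torch.compile(masked_fill_like)
out = compiled(torch.randn(1024), torch.randint(0, 2, (1024,), dtype=torch.long))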

torch/_inductor/codegen/cpp.py

Lines changed: 51 additions & 10 deletions
@@ -19,6 +19,7 @@
 from .common import (
     BracesBuffer,
     CppWrapperKernelArgs,
+    CSEVariable,
     DeferredIndentedBuffer,
     ExprPrinter,
     IndentedBuffer,
@@ -901,8 +902,47 @@ def __init__(self, args, num_threads):
                 self.fast_vec_list.append(k)
         self.exit_stack = contextlib.ExitStack()
 
+        # Cache all the load results
+        self.load_results: list[CSEVariable] = []
+        self.load_supported_dtypes: list[torch.dtype] = [
+            torch.float,
+            torch.float32,
+            torch.bool,
+            torch.uint8,
+            torch.long,
+        ]
+        self.store_supported_dtypes: list[torch.dtype] = [torch.float, torch.float32]
+        # Cache the dtypes of the store operations. If the stores mix dtypes,
+        # vectorization is not supported since it is hard to determine the vec dtype
+        self.store_dtypes: list[torch.dtype] = []
+        # The dtype used for vectorization
+        self.vec_dtype: torch.dtype = torch.float32
+
+    def decide_vec_dtype(self):
+        n_store_dtypes = len(self.store_dtypes)
+        if n_store_dtypes == 1:
+            self.vec_dtype = self.store_dtypes[0]
+
+        return self.vec_dtype
+
+    def is_indirect_indexing(self, index: sympy.Expr):
+        for _load_res in self.load_results:
+            # The index expression contains a value that is loaded from memory
+            if index.count(sympy_symbol(_load_res.name)) > 0:
+                return True
+        return False
+
     def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr):
-        return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index)
+        _indirect_indexing = self.is_indirect_indexing(index)
+        if _indirect_indexing:
+            return False
+
+        _loop_var_irrevelant = self.is_var_irrevelant(var, index)
+        _single_step = self.is_single_step_var(var, index)
+        if not _single_step and not _loop_var_irrevelant:
+            return False
+
+        return True
 
     def could_vec(self, name: str, index: sympy.Expr):
         assert self.itervars is not None
@@ -914,21 +954,22 @@ def could_vec(self, name: str, index: sympy.Expr):
         return self.is_legal_data_access(most_inner_var, index)
 
     def load(self, name: str, index: sympy.Expr):
-        if not V.graph.get_dtype(name) in [
-            torch.float,
-            torch.float32,
-            torch.bool,
-            torch.uint8,
-        ]:
+        var = self.cse.newvar()
+        self.load_results.append(var)
+
+        if not V.graph.get_dtype(name) in self.load_supported_dtypes:
             self.simd_vec = False
-            return self.simd_vec
+            return var
 
         index = self.rename_indexing(index)
         self.simd_vec = self.simd_vec and self.could_vec(name, index)
-        return self.simd_vec
+        return var
 
     def store(self, name, index, value, mode=None):
-        if not V.graph.get_dtype(name) in [torch.float, torch.float32]:
+        store_dtype = V.graph.get_dtype(name)
+        store_dtype = torch.float if store_dtype == torch.float32 else store_dtype
+        self.store_dtypes.append(store_dtype)
+        if store_dtype not in [torch.float, torch.float32]:
             self.simd_vec = False
             return self.simd_vec
 
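
The indirect-indexing guard added above can be shown in isolation. The following is a standalone sketch, assuming loaded values surface in index expressions as sympy symbols such as tmp0; it uses plain sympy.Symbol in place of Inductor's sympy_symbol helper, and the free-standing function is illustrative rather than part of the diff.

import sympy

def is_indirect_indexing(index: sympy.Expr, load_result_names: list[str]) -> bool:
    for name in load_result_names:
        # A count greater than zero means the index expression references a
        # value that was itself loaded from memory, i.e. indirect indexing.
        if index.count(sympy.Symbol(name)) > 0:
            return True
    return False

i, tmp0 = sympy.symbols("i tmp0")
print(is_indirect_indexing(i * 4, ["tmp0"]))      # False: plain affine index
print(is_indirect_indexing(tmp0 + 16, ["tmp0"]))  # True: index depends on a loaded value

When the check fires, is_legal_data_access returns False, so could_vec reports the access as non-vectorizable and the kernel stays on the scalar path.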
