Commit a7bf8ac

To vectorize long datatype as mask index

ghstack-source-id: d020ab9
Pull Request resolved: #91076

1 parent: 57dcd93
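
In short, this commit teaches the inductor CPU vectorization checker to handle int64 (long) values that are only used as mask indices: constants and index expressions that provably fit in int32 are narrowed and emitted as at::vec::Vectorized<int>, and masked loads are blended through a new flag_to_float_vec helper. The sketch below reproduces the workload the new test exercises; it is a minimal illustration assuming a CPU build of PyTorch with the inductor backend, not code from the commit.

# Minimal sketch of the pattern this commit targets (mirrors the new
# test_maxpool2d_cpu_only test): max_pool2d over a channels-last tensor,
# compiled for CPU with the inductor backend. The int64 indices feeding the
# masked loads are what the checker can now narrow to int32 and vectorize.
import torch

maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
x = torch.randn(10, 32, 20, 20).to(memory_format=torch.channels_last)

compiled = torch.compile(maxpool, backend="inductor")
out = compiled(x)
assert torch.allclose(out, maxpool(x), equal_nan=True)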

File tree

5 files changed: +280 −19 lines changed

test/inductor/test_torchinductor.py

Lines changed: 17 additions & 0 deletions

@@ -5270,6 +5270,23 @@ def fn(x):
             assert same(fn(x)[0], compiled([x])[0], equal_nan=True)
             assert metrics.generated_cpp_vec_kernel_count == 1
 
+    @unittest.skipIf(
+        not codecache.valid_vec_isa_list(), "Does not support vectorization"
+    )
+    @patch("torch.cuda.is_available", lambda: False)
+    def test_maxpool2d_cpu_only(self):
+        input = torch.randn(10, 32, 20, 20).to(memory_format=torch.channels_last)
+        maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        def func(x):
+            return maxpool(x)
+
+        with patch.object(config.cpp, "simdlen", None):
+            graph = torch.compile(func, backend="inductor")
+            graph(input)
+            assert same(graph(input), func(input), equal_nan=True)
+            assert metrics.generated_cpp_vec_kernel_count == 1
+
     @unittest.skipIf(
         not codecache.valid_vec_isa_list(), "Does not support vectorization"
     )

torch/_inductor/codegen/cpp.py

Lines changed: 222 additions & 18 deletions

@@ -7,9 +7,11 @@
 from pathlib import Path
 from typing import Dict, List
 
+import numpy
 import sympy
 
 import torch
+import torch.fx
 from torch._prims_common import is_float_dtype
 
 from .. import codecache, config, ir, metrics
@@ -19,6 +21,7 @@
 from .common import (
     BracesBuffer,
     CppWrapperKernelArgs,
+    CSEVariable,
     DeferredIndentedBuffer,
     ExprPrinter,
     IndentedBuffer,
@@ -231,6 +234,34 @@ def erf(x):
     def sqrt(x):
         return f"{x}.sqrt()"
 
+    @staticmethod
+    def eq(x, y):
+        return f"{x} == {y}"
+
+    @staticmethod
+    def ne(x, y):
+        return f"{x} != {y}"
+
+    @staticmethod
+    def lt(x, y):
+        return f"{x} < {y}"
+
+    @staticmethod
+    def gt(x, y):
+        return f"{x} > {y}"
+
+    @staticmethod
+    def le(x, y):
+        return f"{x} <= {y}"
+
+    @staticmethod
+    def ge(x, y):
+        return f"{x} >= {y}"
+
+    @staticmethod
+    def and_(x, y):
+        return f"{x} & {y}"
+
     @staticmethod
     def rsqrt(x):
         return f"{x}.rsqrt()"
@@ -285,17 +316,19 @@ def reciprocal(a):
 
     @staticmethod
     def constant(val, dtype):
+        proposed_dtype = V.interpreter.current_node.meta["dtype"]
         if val == float("inf"):
-            quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
+            quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
         elif val == float("-inf"):
-            quote = f"-std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::infinity()"
+            quote = f"-std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::infinity()"
         elif math.isnan(val):
-            quote = f"std::numeric_limits<{DTYPE_TO_CPP[dtype]}>::quiet_NaN()"
+            quote = f"std::numeric_limits<{DTYPE_TO_CPP[proposed_dtype]}>::quiet_NaN()"
         elif val is True or val is False:
-            quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({str(val).lower()})"
+            quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({str(val).lower()})"
         else:
-            quote = f"static_cast<{DTYPE_TO_CPP[dtype]}>({repr(val)})"
-        return f"at::vec::Vectorized<{DTYPE_TO_CPP[dtype]}>({quote})"
+            quote = f"static_cast<{DTYPE_TO_CPP[proposed_dtype]}>({repr(val)})"
+
+        return f"at::vec::Vectorized<{DTYPE_TO_CPP[proposed_dtype]}>({quote})"
 
     @staticmethod
     def relu(x):
@@ -370,6 +403,40 @@ def expm1(x):
     def log1p(x):
         return f"{x}.log1p()"
 
+    @staticmethod
+    def masked(mask, body, other):
+        assert V.interpreter.current_node.meta["is_masked_load"]
+        code = BracesBuffer()
+
+        var = V.kernel.cse.newvar()
+        if other == float("-inf"):
+            code.writeline(
+                f"auto {var} = at::vec::Vectorized<float>(-std::numeric_limits<float>::infinity());"
+            )
+        elif other == float("inf"):
+            code.writeline(
+                f"auto {var} = at::vec::Vectorized<float>(std::numeric_limits<float>::infinity());"
+            )
+        elif isinstance(other, float):
+            code.writeline(f"auto {var} = at::vec::Vectorized<float>({other});")
+        else:
+            code.writeline(f"auto {var} = at::vec::Vectorized<float>({other!r});")
+        with V.kernel.swap_buffers(code), code.indent():
+            result = body()
+            zero_val = "at::vec::Vectorized<float>(0)"
+            float_mask = f"flag_to_float_vec({mask})"
+            blendv = f"decltype({result})::blendv({var}, {result}, {float_mask} != {zero_val})"
+            code.writeline(f"{var} = {blendv};")
+        V.kernel.compute.splice(code)
+        return var
+
+    @staticmethod
+    def index_expr(expr, dtype):
+        assert dtype == torch.int64
+        assert V.interpreter.current_node.meta["dtype"] == torch.int32
+        assert V.interpreter.current_node.meta["most_inner_loop_irrevelant"]
+        return f"at::vec::Vectorized<int>(static_cast<int>({cexpr(V.kernel.rename_indexing(expr))}))"
+
 
 class CppOverrides(OpOverrides):
     """Map element-wise ops to C++"""
@@ -905,8 +972,47 @@ def __init__(self, args, num_threads):
                 self.fast_vec_list.append(k)
         self.exit_stack = contextlib.ExitStack()
 
+        # Cache all the load results
+        self.load_results: list[CSEVariable] = []
+        self.load_supported_dtypes: list[torch.dtype] = [
+            torch.float,
+            torch.float32,
+            torch.bool,
+            torch.uint8,
+            torch.long,
+        ]
+        self.store_supported_dtypes: list[torch.dtype] = [torch.float, torch.float32]
+        # Cache the dtypes of the store operation. If the store is mixing dtypes, the
+        # vectorization would not support it as it is hard to determine the vec dtype
+        self.store_dtypes: list[torch.dtype] = []
+        # The dtype used for vectorization
+        self.vec_dtype: torch.dtype = torch.float32
+
+    def decide_vec_dtype(self):
+        n_store_dtypes = len(self.store_dtypes)
+        if n_store_dtypes == 1:
+            self.vec_dtype = self.store_dtypes[0]
+
+        return self.vec_dtype
+
+    def is_indirect_indexing(self, index: sympy.Expr):
+        for _load_res in self.load_results:
+            # The index expression contains a value that is loaded from memory
+            if index.count(sympy_symbol(_load_res.name)) > 0:
+                return True
+        return False
+
     def is_legal_data_access(self, var: sympy.Symbol, index: sympy.Expr):
-        return self.is_var_irrevelant(var, index) or self.is_single_step_var(var, index)
+        _indirect_indexing = self.is_indirect_indexing(index)
+        if _indirect_indexing:
+            return False
+
+        _loop_var_irrevelant = self.is_var_irrevelant(var, index)
+        _single_step = self.is_single_step_var(var, index)
+        if not _single_step and not _loop_var_irrevelant:
+            return False
+
+        return True
 
     def could_vec(self, name: str, index: sympy.Expr):
         assert self.itervars is not None
@@ -918,21 +1024,40 @@ def could_vec(self, name: str, index: sympy.Expr):
         return self.is_legal_data_access(most_inner_var, index)
 
     def load(self, name: str, index: sympy.Expr):
-        if not V.graph.get_dtype(name) in [
-            torch.float,
-            torch.float32,
-            torch.bool,
-            torch.uint8,
-        ]:
+        load_type = V.graph.get_dtype(name)
+        current_node: torch.fx.Node = V.interpreter.current_node
+        current_node.meta["dtype"] = load_type
+
+        var = self.cse.newvar()
+        self.load_results.append(var)
+
+        if not V.graph.get_dtype(name) in self.load_supported_dtypes:
             self.simd_vec = False
-            return self.simd_vec
+            return var
+
+        def is_mask():
+            user_nodes = current_node.users
+            for __node in user_nodes.keys():
+                _node: torch.fx.Node = __node
+                if _node.target not in ["where", "masked"]:
+                    return False
+            return True
+
+        current_node.meta["is_mask"] = is_mask()
 
         index = self.rename_indexing(index)
        self.simd_vec = self.simd_vec and self.could_vec(name, index)
-        return self.simd_vec
+        return var
 
     def store(self, name, index, value, mode=None):
-        if not V.graph.get_dtype(name) in [torch.float, torch.float32]:
+        store_dtype = V.graph.get_dtype(name)
+
+        current_node: torch.fx.Node = V.interpreter.current_node
+        current_node.meta["dtype"] = store_dtype
+
+        store_dtype = torch.float if store_dtype == torch.float32 else store_dtype
+        self.store_dtypes.append(store_dtype)
+        if store_dtype not in [torch.float, torch.float32]:
             self.simd_vec = False
             return self.simd_vec
 
@@ -957,6 +1082,27 @@ def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
             self.simd_vec = False
         return self.simd_vec
 
+    def is_load_only_block(self, sub_graph: torch.fx.Graph):
+        # The sub-graph only contains "placeholder", "output", "get_index", "load"
+        is_load_only = False
+        load_dtype = None
+        skip_io_nodes = ["placeholder", "output"]
+        for _node in sub_graph.nodes:
+            if _node.op in skip_io_nodes:
+                continue
+
+            if _node.target not in ["load", "get_index"]:
+                # The body contains a non-load node
+                is_load_only = False
+                break
+
+            if _node.target == "load":
+                _, name, _ = _node.args
+                load_dtype = V.graph.get_dtype(name)
+                is_load_only = True
+
+        return is_load_only, load_dtype
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         assert self._orig_wrapper_code is not None
         # Restore the wrapper_code
@@ -999,15 +1145,60 @@ def reduction(name, dtype, src_dtype, reduction_type, index, value):
 
             @staticmethod
            def constant(val, dtype):
+                current_node: torch.fx.Node = V.interpreter.current_node
+                current_node.meta["dtype"] = dtype
+                i32_iinfo = numpy.iinfo(numpy.int32)
+                if (
+                    dtype == torch.int64
+                    and val <= i32_iinfo.max
+                    and val >= i32_iinfo.min
+                ):
+                    current_node.meta["dtype"] = torch.int32
+                f64_iinfo = numpy.finfo(numpy.float32)
+                if (
+                    dtype == torch.double
+                    and val <= f64_iinfo.max
+                    and val >= f64_iinfo.min
+                ):
+                    current_node.meta["dtype"] = torch.float32
+
                 supported_dtype = (torch.float32, torch.int32)
-                is_supported_dtype = dtype in (supported_dtype)
+                is_supported_dtype = current_node.meta["dtype"] in (supported_dtype)
                 if not is_supported_dtype:
                     self.simd_vec = False
                 return is_supported_dtype
 
             @staticmethod
             def index_expr(expr, dtype):
-                self.simd_vec = False
+                current_node: torch.fx.Node = V.interpreter.current_node
+
+                loop_range = {}
+                assert len(self.ranges) == len(self.itervars)
+                for idx in range(len(self.ranges)):
+                    loop_range[self.itervars[idx]] = self.ranges[idx]
+                expr_val = sympy.simplify(sympy_subs(expr, loop_range))
+                i32_iinfo = numpy.iinfo(numpy.int32)
+                if (
+                    dtype == torch.int64
+                    and expr_val <= i32_iinfo.max
+                    and expr_val >= i32_iinfo.min
+                ):
+                    current_node.meta["dtype"] = torch.int32
+                else:
+                    self.simd_vec = False
+
+                # Pick the most inner loop variable since we always vectorize the
+                # most inner loop
+                most_inner_var = self.itervars[-1]
+                most_inner_loop_irrevelant = self.is_var_irrevelant(
+                    most_inner_var, expr
+                )
+                if not most_inner_loop_irrevelant:
+                    self.simd_vec = False
+                current_node.meta[
+                    "most_inner_loop_irrevelant"
+                ] = most_inner_loop_irrevelant
+
                 tmp_var = self.cse.newvar()
                 return tmp_var
 
@@ -1018,11 +1209,24 @@ def indirect_indexing(index_var):
 
             @staticmethod
             def masked(mask, body, other):
+                current_node: torch.fx.Node = V.interpreter.current_node
+                is_masked_load, load_dtype = self.is_load_only_block(body.graph)
+                current_node.meta["dtype"] = load_dtype
+                current_node.meta["is_masked_load"] = is_masked_load
+
+                self.simd_vec = is_masked_load and current_node.meta["dtype"] in [
+                    torch.float32,
+                    torch.float,
+                ]
+
                 tmp_var = self.cse.newvar()
                 return tmp_var
 
             @staticmethod
             def to_dtype(x, dtype):
+                current_node: torch.fx.Node = V.interpreter.current_node
+                current_node["dtype"] = dtype
+
                 if dtype != torch.bool:
                     self.simd_vec = False
                 return x
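
To make the range check in the new index_expr handler concrete: the checker substitutes the loop extents into the index expression and only narrows int64 to int32 (keeping the kernel vectorizable) when the largest reachable value fits. The standalone sketch below uses plain sympy in place of the kernel's sympy_subs/itervars plumbing; the example expression and loop ranges are made up for illustration.

# Standalone sketch of the int64 -> int32 narrowing test performed by the
# checker's index_expr above. Plain sympy stands in for inductor's sympy_subs
# and itervars/ranges; the index expression and loop extents are illustrative.
import numpy
import sympy

i0, i1 = sympy.symbols("i0 i1", integer=True, nonnegative=True)
index = 640 * i0 + i1            # a flattened 2-D index, as inductor might build
loop_ranges = {i0: 32, i1: 640}  # upper bounds of the enclosing loops

# Bound the largest value the index can reach by substituting the loop extents.
max_val = int(index.subs(loop_ranges))

i32 = numpy.iinfo(numpy.int32)
if i32.min <= max_val <= i32.max:
    print("fits in int32: emit at::vec::Vectorized<int>(static_cast<int>(...))")
else:
    print("does not fit: disable vectorization for this kernel")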

torch/_inductor/codegen/cpp_prefix.h

Lines changed: 24 additions & 0 deletions

@@ -69,3 +69,27 @@ void flag_to_float(const T* src, float* dst, int64_t n) {
     dst_u32[i] = *(src + i) ? 0xFFFFFFFF : 0;
   }
 }
+
+#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
+template <typename SRC>
+inline at::vec::Vectorized<float> flag_to_float_vec(
+    at::vec::Vectorized<SRC>& src) {
+  assert(
+      at::vec::Vectorized<float>::size() == at::vec::Vectorized<SRC>::size());
+  at::vec::Vectorized<float> res_vec(0);
+#pragma unroll
+  for (int i = 0; i < at::vec::Vectorized<float>::size(); i++) {
+    res_vec[i] = src[i] ? 0xFFFFFFFF : 0;
+  }
+}
+
+template <>
+inline at::vec::Vectorized<float> flag_to_float_vec(
+    at::vec::Vectorized<int>& src) {
+#if defined(CPU_CAPABILITY_AVX2)
+  return at::vec::Vectorized<float>(_mm256_cvtepi32_ps(src));
+#else
+  return at::vec::Vectorized<float>(_mm512_cvtepi32_ps(src));
+#endif
+}
+#endif
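
The flag_to_float_vec helper added here converts an integer flag vector into a float bitmask so it can drive at::vec::Vectorized<float>::blendv, which is how the generated masked-load code (see CppVecOverrides.masked in cpp.py above) consumes it. The compile-only sketch below shows that pattern; it assumes the ATen vec headers and the cpp_prefix.h definitions are in scope on an AVX2/AVX512 build, and the function and variable names are illustrative rather than generated code.

// Compile-only sketch of the masked-load blend that CppVecOverrides.masked
// emits, consuming the flag_to_float_vec helper added above. Assumes the ATen
// vectorization headers plus the cpp_prefix.h definitions are in scope on an
// AVX2/AVX512 build; masked_load_example and its arguments are illustrative.
#include <limits>
#include <ATen/cpu/vec/vec.h>

inline at::vec::Vectorized<float> masked_load_example(
    at::vec::Vectorized<int> mask_vec,      // per-lane 0 / non-zero flags
    at::vec::Vectorized<float> loaded) {    // value loaded under the mask
  // "other": the value used where the mask is false, e.g. -inf for max pooling.
  auto other = at::vec::Vectorized<float>(
      -std::numeric_limits<float>::infinity());
  // Turn the integer flags into a float bitmask, then blend lane by lane.
  auto float_mask = flag_to_float_vec(mask_vec);
  auto zero = at::vec::Vectorized<float>(0);
  return at::vec::Vectorized<float>::blendv(other, loaded, float_mask != zero);
}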
