Skip to content

Commit 49d63a7

Browse files
author
Yang Chen
committed
[inductor] support _scaled_dot_product_flash_attention fallback
This PR supports the _scaled_dot_product_flash_attention fallback kernel. Note that in abi_compatible mode, we retrieve outputs by passing output-argument pointers rather than relying on std::get. It also fixes an issue related to dynamic shapes, where we would wrongly query undefined dynamic symbols. ghstack-source-id: 3c51dab. Pull Request resolved: #110003
1 parent e42d450 commit 49d63a7

File tree

5 files changed

+184
-25
lines changed

5 files changed

+184
-25
lines changed

test/inductor/test_aot_inductor.py

Lines changed: 37 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -538,14 +538,49 @@ def forward(self, x, y):
538538
constraints=constraints,
539539
)
540540

541+
# scaled_dot_product_flash_attention
542+
def test_sdpa(self):
543+
class Repro(torch.nn.Module):
544+
def __init__(self):
545+
super().__init__()
546+
547+
def forward(self, q, k, v):
548+
return torch.nn.functional.scaled_dot_product_attention(q, k, v)[0]
549+
550+
example_inputs = (
551+
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
552+
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
553+
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
554+
)
555+
self.check_model(Repro(), example_inputs)
556+
557+
def test_sdpa_2(self):
558+
class Repro(torch.nn.Module):
559+
def __init__(self):
560+
super().__init__()
561+
562+
def forward(self, q, k, v, x):
563+
t = torch.nn.functional.scaled_dot_product_attention(
564+
q, k, v, is_causal=True
565+
)[0]
566+
return x + t
567+
568+
example_inputs = (
569+
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
570+
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
571+
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
572+
torch.randn(1, 48, 64, 64, dtype=torch.bfloat16, device="cuda"),
573+
)
574+
self.check_model(Repro(), example_inputs)
575+
541576

542-
class AOTInductorTestABICompatibile(TestCase):
577+
class AOTInductorTestABICompatible(TestCase):
543578
abi_compatible = True
544579
check_model = check_model
545580
check_model_with_multiple_inputs = check_model_with_multiple_inputs
546581

547582

548-
copy_tests(AOTInductorTestsTemplate, AOTInductorTestABICompatibile, "abi_compatible")
583+
copy_tests(AOTInductorTestsTemplate, AOTInductorTestABICompatible, "abi_compatible")
549584

550585

551586
class AOTInductorTestNonABICompatible(TestCase):

torch/_inductor/codegen/wrapper.py

Lines changed: 58 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -290,6 +290,11 @@ def __init__(self):
290290
self.first_device_guard = True
291291
self.supports_intermediate_hooks = True
292292
self.expr_printer = pexpr
293+
# Not all the dynamic symbols will be used in the generated code. This
294+
# set contains those actually being defined by something like
295+
# "{self.declare_shape} s0 = ...". It ensures that we are not going to
296+
# emit queries for undefined symbols.
297+
self.defined_symbols = set()
293298

294299
self.write_header()
295300
self.write_prefix()
@@ -581,6 +586,7 @@ def is_expr(x):
581586
for name, shape in graph_inputs_expr:
582587
shape = V.graph.sizevars.simplify(shape)
583588
if shape in needed:
589+
self.defined_symbols.add(shape)
584590
needed.remove(shape)
585591
code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
586592

@@ -589,6 +595,7 @@ def is_expr(x):
589595
for dim, shape in enumerate(shapes):
590596
shape = V.graph.sizevars.simplify(shape)
591597
if shape in needed:
598+
self.defined_symbols.add(shape)
592599
needed.remove(shape)
593600
code.writeline(
594601
f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
@@ -599,6 +606,7 @@ def is_expr(x):
599606
for dim, shape in enumerate(shapes):
600607
shape = V.graph.sizevars.simplify(shape)
601608
if shape in needed:
609+
self.defined_symbols.add(shape)
602610
needed.remove(shape)
603611
code.writeline(
604612
f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
@@ -617,7 +625,7 @@ def codegen_python_sizevar(self, x: Expr) -> str:
617625
def codegen_sizevar(self, x: Expr) -> str:
618626
return self.codegen_python_sizevar(x)
619627

620-
def codegen_tuple_access(self, basename: str, index: str) -> str:
628+
def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
621629
return f"{basename}[{index}]"
622630

623631
def codegen_python_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
@@ -637,6 +645,9 @@ def codegen_reinterpret_view(self, name, size, stride, offset, writer) -> str:
637645
offset = self.codegen_sizevar(offset)
638646
return f"reinterpret_tensor({name}, {size}, {stride}, {offset})"
639647

648+
def codegen_multi_output(self, name, value):
649+
self.writeline(f"{self.declare}{name} = {value}{self.ending}")
650+
640651
def benchmark_compiled_module(self, output):
641652
def add_fake_input(name, shape, stride, device, dtype):
642653
output.writeline(
@@ -1182,7 +1193,11 @@ def write_wrapper_decl(self):
11821193

11831194
if V.graph.aot_mode:
11841195
self.prefix.writeline("inputs.clear();")
1185-
dynamic_symbols = V.graph.sizevars.free_symbols()
1196+
dynamic_symbols = [
1197+
s
1198+
for s in V.graph.sizevars.free_symbols()
1199+
if s in self.defined_symbols
1200+
]
11861201
for dim in dynamic_symbols:
11871202
self.prefix.writeline(
11881203
f'auto dim_{dim} = find_dynamic_dim("{dim}");'
@@ -1402,21 +1417,38 @@ def generate_c_shim_extern_kernel_call(self, kernel, args):
14021417
kernel = "aoti_torch_" + kernel.split("::")[-1]
14031418
self.writeline(f"AOTI_TORCH_ERROR_CODE_CHECK({kernel}({', '.join(args)}));")
14041419

1420+
def generate_c_shim_extern_kernel_alloc_call(self, extern_kernel, args):
1421+
output_args = []
1422+
output_raii_handles = []
1423+
output_name_base = extern_kernel.get_name()
1424+
for idx, output in enumerate(extern_kernel.outputs):
1425+
if isinstance(output, ir.MultiOutput):
1426+
name = f"{output.get_name()}"
1427+
output_handle_name = f"{name}_handle"
1428+
assert (
1429+
output.indices[0][1] == idx
1430+
), f"expected {output.indices[1]=} == {idx=} for {output_name_base=}"
1431+
self.writeline(f"AtenTensorHandle {output_handle_name};")
1432+
output_args.append(f"&{output_handle_name}")
1433+
output_raii_handles.append(
1434+
f"RAIIAtenTensorHandle {name}({output_handle_name});"
1435+
)
1436+
elif isinstance(output, int):
1437+
output_name = f"{output_name_base}_{idx}"
1438+
self.writeline(f"int64_t {output_name} = {output};")
1439+
output_args.append(f"&{output_name}")
1440+
elif output is None:
1441+
output_args.append("nullptr")
1442+
else:
1443+
raise NotImplementedError("unsupported type of {output=}")
1444+
args = args + output_args
1445+
self.generate_c_shim_extern_kernel_call(extern_kernel.kernel, args)
1446+
for raii_handle in output_raii_handles:
1447+
self.writeline(raii_handle)
1448+
14051449
def generate_extern_kernel_alloc(self, extern_kernel, args):
14061450
if V.graph.aot_mode and config.aot_inductor.abi_compatible:
1407-
output_name = extern_kernel.get_name()
1408-
self.writeline(f"AtenTensorHandle {output_name};")
1409-
kernel = extern_kernel.kernel
1410-
size = self.codegen_shape_tuple(tuple(extern_kernel.get_size()))
1411-
stride = self.codegen_shape_tuple(tuple(extern_kernel.get_stride()))
1412-
args = [
1413-
f"&{output_name}",
1414-
str(len(extern_kernel.get_size())), # ndim
1415-
self.codegen_int_array_var(size),
1416-
self.codegen_int_array_var(stride),
1417-
] + args
1418-
# TODO: support extern kernel that allocates
1419-
self.generate_c_shim_extern_kernel_call(kernel, args)
1451+
self.generate_c_shim_extern_kernel_alloc_call(extern_kernel, args)
14201452
else:
14211453
super().generate_extern_kernel_alloc(extern_kernel, args)
14221454

@@ -1461,8 +1493,12 @@ def add_benchmark_harness(self, output):
14611493
def codegen_sizevar(self, x: Expr) -> str:
14621494
return self.expr_printer(V.graph.sizevars.simplify(x))
14631495

1464-
def codegen_tuple_access(self, basename: str, index: str) -> str:
1465-
return f"std::get<{index}>({basename})"
1496+
def codegen_tuple_access(self, basename: str, name: str, index: str) -> str:
1497+
if V.graph.aot_mode and config.aot_inductor.abi_compatible:
1498+
# in the abi_compatible mode, outputs are returned via arguments
1499+
return name
1500+
else:
1501+
return f"std::get<{index}>({basename})"
14661502

14671503
def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
14681504
parts = list(map(self.codegen_sizevar, shape))
@@ -1584,6 +1620,11 @@ def codegen_reinterpret_view(self, name, size, stride, offset, writer) -> str:
15841620
args = [name, size, stride, offset]
15851621
return f"reinterpret_tensor({', '.join(args)})"
15861622

1623+
def codegen_multi_output(self, name, value):
1624+
# if V.graph.aot_mode and name in set(V.graph.get_output_names()):
1625+
if not config.aot_inductor.abi_compatible:
1626+
super().codegen_multi_output(name, value)
1627+
15871628
def generate_extern_kernel_args_decl_if_needed(
15881629
self, op_overload, raw_args, output_args
15891630
):

torch/_inductor/ir.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3627,6 +3627,10 @@ def __init__(
36273627
tuple(tensor_args),
36283628
tuple(nontensor_args),
36293629
)
3630+
# We need output buffers for generating kernel arguments in the
3631+
# abi-compatible mode, where we retrieve outputs by passing each individual
3632+
# output through the abi-compatible interface.
3633+
self.outputs = []
36303634
self.use_cpp_op_schema = False
36313635

36323636
self.op_overload = kernel
@@ -3879,7 +3883,8 @@ def generate_output(output, indices):
38793883
assert output is None, "FallbackKernel output type is not supported"
38803884
return None
38813885

3882-
return generate_output(example_output, [])
3886+
packed.outputs = generate_output(example_output, [])
3887+
return packed.outputs
38833888

38843889
def apply_constraint(self):
38853890
return super().apply_constraint()
@@ -3899,7 +3904,7 @@ def codegen_list_tuple_access(self, basename, indices):
38993904
elif itype == tuple:
39003905
# cpp wrapper code needs to use std::get<> to access a tuple
39013906
tuple_access = V.graph.wrapper_code.codegen_tuple_access(
3902-
basename, str(i)
3907+
basename, self.get_name(), str(i)
39033908
)
39043909
return self.codegen_list_tuple_access(tuple_access, indices[1:])
39053910
else:
@@ -3908,10 +3913,10 @@ def codegen_list_tuple_access(self, basename, indices):
39083913
return basename
39093914

39103915
def codegen(self, wrapper):
3911-
line = V.graph.wrapper_code.declare
3912-
line += f"{self.get_name()} = {self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices)}"
3913-
line += V.graph.wrapper_code.ending
3914-
V.graph.wrapper_code.writeline(line)
3916+
V.graph.wrapper_code.codegen_multi_output(
3917+
self.get_name(),
3918+
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
3919+
)
39153920
self.codegen_size_asserts(V.graph.wrapper_code)
39163921

39173922
def __init__(self, layout, input, indices: List[Tuple[Any, ...]]):

torch/csrc/inductor/aoti_torch/c/shim.h

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,25 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob(
152152
AtenTensorHandle* ret // returns new reference
153153
);
154154

155+
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
156+
AtenTensorHandle query,
157+
AtenTensorHandle key,
158+
AtenTensorHandle value,
159+
double dropout_p,
160+
bool is_causal,
161+
bool return_debug_mask,
162+
double scale,
163+
AtenTensorHandle* ret0, // returns new reference
164+
AtenTensorHandle* ret1, // returns new reference
165+
AtenTensorHandle* ret2, // returns new reference
166+
AtenTensorHandle* ret3, // returns new reference
167+
int64_t* ret4,
168+
int64_t* ret5,
169+
AtenTensorHandle* ret6, // returns new reference
170+
AtenTensorHandle* ret7, // returns new reference
171+
AtenTensorHandle* ret8 // returns new reference
172+
);
173+
155174
AOTI_TORCH_EXPORT AOTITorchError
156175
aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst);
157176

torch/csrc/inductor/aoti_torch/shim_common.cpp

Lines changed: 59 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,16 +6,19 @@
66
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
77
#include <torch/csrc/inductor/aoti_torch/utils.h>
88
#include <torch/csrc/inductor/inductor_ops.h>
9+
#include <cstdarg>
910
#include <cstdint>
1011
#include <cstdio>
1112
#include <iostream>
1213
#include <memory>
14+
#include <tuple>
1315

1416
#ifndef AT_PER_OPERATOR_HEADERS
1517
#include <ATen/Functions.h>
1618
#else
1719

1820
#include <ATen/ops/_addmm_activation.h>
21+
#include <ATen/ops/_scaled_dot_product_flash_attention.h>
1922
#include <ATen/ops/addmm.h>
2023
#include <ATen/ops/as_strided.h>
2124
#include <ATen/ops/bmm.h>
@@ -182,6 +185,62 @@ AOTITorchError aoti_torch_create_tensor_from_blob(
182185
});
183186
}
184187

188+
AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
189+
AtenTensorHandle query,
190+
AtenTensorHandle key,
191+
AtenTensorHandle value,
192+
double dropout_p,
193+
bool is_causal,
194+
bool return_debug_mask,
195+
double scale,
196+
AtenTensorHandle* ret0, // returns new reference
197+
AtenTensorHandle* ret1, // returns new reference
198+
AtenTensorHandle* ret2, // returns new reference
199+
AtenTensorHandle* ret3, // returns new reference
200+
int64_t* ret4,
201+
int64_t* ret5,
202+
AtenTensorHandle* ret6, // returns new reference
203+
AtenTensorHandle* ret7, // returns new reference
204+
AtenTensorHandle* ret8 // returns new reference
205+
) {
206+
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
207+
at::Tensor* query_tensor = tensor_handle_to_tensor_pointer(query);
208+
at::Tensor* key_tensor = tensor_handle_to_tensor_pointer(key);
209+
at::Tensor* value_tensor = tensor_handle_to_tensor_pointer(value);
210+
auto [r0, r1, r2, r3, r4, r5, r6, r7, r8] =
211+
at::_scaled_dot_product_flash_attention(
212+
*query_tensor,
213+
*key_tensor,
214+
*value_tensor,
215+
dropout_p,
216+
is_causal,
217+
return_debug_mask,
218+
scale);
219+
220+
at::Tensor* ret0_tensor = new at::Tensor(std::move(r0));
221+
*ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
222+
at::Tensor* ret1_tensor = new at::Tensor(std::move(r1));
223+
*ret1 = tensor_pointer_to_tensor_handle(ret1_tensor);
224+
// ret2 and ret3 may be null
225+
if (ret2) {
226+
at::Tensor* ret2_tensor = new at::Tensor(std::move(r2));
227+
*ret2 = tensor_pointer_to_tensor_handle(ret2_tensor);
228+
}
229+
if (ret3) {
230+
at::Tensor* ret3_tensor = new at::Tensor(std::move(r3));
231+
*ret3 = tensor_pointer_to_tensor_handle(ret3_tensor);
232+
}
233+
*ret4 = r4;
234+
*ret5 = r5;
235+
at::Tensor* ret6_tensor = new at::Tensor(std::move(r6));
236+
*ret6 = tensor_pointer_to_tensor_handle(ret6_tensor);
237+
at::Tensor* ret7_tensor = new at::Tensor(std::move(r7));
238+
*ret7 = tensor_pointer_to_tensor_handle(ret7_tensor);
239+
at::Tensor* ret8_tensor = new at::Tensor(std::move(r8));
240+
*ret8 = tensor_pointer_to_tensor_handle(ret8_tensor);
241+
});
242+
}
243+
185244
// TODO: implement a more efficient version instead of calling into aten
186245
AOTITorchError aoti_torch_tensor_copy_(
187246
AtenTensorHandle src,

0 commit comments

Comments (0)