8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -14349,14 +14349,14 @@
  variants: function
  tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  dispatch:
    CPU: _scaled_dot_product_flash_attention_cpu
    CUDA: _scaled_dot_product_flash_attention_cuda
    NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
  tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
  device_check: NoCheck
  variants: function
  dispatch:
@@ -14375,13 +14375,13 @@
    CUDA: _scaled_dot_product_efficient_attention_backward_cuda
  tags: nondeterministic_seeded

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt? max_q, SymInt? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  variants: function
  dispatch:
    CUDA: _flash_attention_forward
  tags: nondeterministic_seeded

- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
  device_check: NoCheck
  variants: function
  dispatch:
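Note (not part of the diff): with this schema change, the registered op reports SymInt for max_q/max_k in its returns. A quick, illustrative way to confirm that on a build containing the change:

```python
import torch

# Print the registered schema for the flash-attention op; on a build with this
# change, the returns list `SymInt max_q, SymInt max_k` instead of `int`.
schema = torch.ops.aten._scaled_dot_product_flash_attention.default._schema
print(schema)
```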
@@ -220,8 +220,8 @@ std::tuple<
    Tensor,
    Tensor,
    Tensor,
    int64_t,
    int64_t,
    c10::SymInt,
    c10::SymInt,
    Tensor,
    Tensor,
    Tensor>
4 changes: 2 additions & 2 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -744,8 +744,8 @@ std::tuple<
    at::Tensor,
    at::Tensor,
    at::Tensor,
    int64_t,
    int64_t,
    c10::SymInt,
    c10::SymInt,
    at::Tensor,
    at::Tensor,
    at::Tensor>
2 changes: 1 addition & 1 deletion aten/src/ATen/native/transformers/cuda/attention.cu
@@ -668,7 +668,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
  }
  return std::make_tuple(std::move(proj), std::move(qkt));
}
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
48 changes: 48 additions & 0 deletions test/inductor/test_cuda_repro.py
@@ -5,13 +5,16 @@

import torch
import torch._dynamo.config as dynamo_config
import torch.backends.cuda
import torch.nn.functional as F
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx_inner
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
    DeterministicGuard,
    freeze_rng_state,
@@ -982,6 +985,51 @@ def fn(x, y, z):

        self.assertEqual(ref, res)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
    )
    def test_flash_attention_dynamic(self):
        class Model(nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

                self.q = nn.Linear(1024, 1024)
                self.k = nn.Linear(1024, 1024)
                self.v = nn.Linear(1024, 1024)

            def forward(self, x):
                batch_size, seq_len, _ = x.size()

                queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
                keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
                values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)

                attn = F.scaled_dot_product_attention(
                    queries,
                    keys,
                    values,
                )

                return attn

        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")

        model = Model().cuda().half()
        model = torch.compile(model, backend=cnts, dynamic=True)

        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
            input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
            input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)

            out1 = model(input1)
            out2 = model(input2)
            out3 = model(input3)

        self.assertEqual(cnts.frame_count, 1)

    @config.patch({"triton.cudagraphs": True})
    def test_index_put_no_fallback_cudagraph(self):
        def fn(x, y, z):
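One way to see why frame_count stays at 1 in the test above (illustrative, not part of the PR): with SymInt max_q/max_k, the sequence dimension is carried as a symbolic size, so the 512/513/514-length inputs all reuse one compiled graph. Dynamic-shape logging makes this visible:

```python
import logging

import torch

# Illustrative: enable dynamic-shape logs (roughly equivalent to running with
# TORCH_LOGS="+dynamic") before calling the compiled model; the sequence
# length then shows up as a symbol such as s0 rather than a baked-in constant.
torch._logging.set_logs(dynamic=logging.DEBUG)
```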
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -2764,9 +2764,9 @@
  output_differentiability: [True, False, False, False]
  query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)

- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
  output_differentiability: [True, False, False, False, False, False, False, False, False]
  query, key, value: _scaled_dot_product_flash_attention_backward(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
  query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)

# - name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor query_padded, Tensor key_padded, Tensor value_padded, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
# output_differentiability: [True, False, False, False, False, False, False, False]
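Because the derivative formula now routes through the `_symint` backward variant, the symbolic max_q/max_k survive into the backward graph as well. A minimal sketch of what this enables (assumes a CUDA device with flash attention available; not code from the PR):

```python
import torch
import torch.nn.functional as F


def loss(q, k, v):
    return F.scaled_dot_product_attention(q, k, v).sum()


compiled = torch.compile(loss, dynamic=True)

for seq_len in (512, 513, 514):
    q = torch.rand(
        2, 8, seq_len, 64, device="cuda", dtype=torch.float16, requires_grad=True
    )
    # Forward and backward should both reuse a single compiled graph across
    # the three sequence lengths, since max_q/max_k stay symbolic end to end.
    compiled(q, q, q).backward()
```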
2 changes: 1 addition & 1 deletion torch/_C/return_types.pyi.in
@@ -16,7 +16,7 @@ from typing import (
    Union,
)

from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor
from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor, SymInt
from torch.types import (
    _bool,
    _device,
6 changes: 5 additions & 1 deletion torch/_inductor/ir.py
@@ -4037,8 +4037,12 @@ def generate_output(output, indices):
                )
            elif isinstance(output, int):
                return output
            elif isinstance(output, torch.SymInt):
                return output.node.expr
            else:
                assert output is None, "FallbackKernel output type is not supported"
                assert (
                    output is None
                ), f"FallbackKernel output type {type(output)} is not supported"
                return None

        outputs = generate_output(example_output, [])
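For reference (an illustration, not code from the PR): a torch.SymInt created under dynamic shapes wraps a SymNode whose .expr attribute is the underlying sympy expression, and that expression is what the new FallbackKernel branch returns for symbolic outputs such as max_q/max_k:

```python
import torch


def unwrap_symint(val):
    # Hypothetical helper mirroring the new branch in generate_output: a SymInt
    # is lowered to its sympy expression; plain Python ints pass through as-is.
    if isinstance(val, torch.SymInt):
        return val.node.expr
    return val
```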
4 changes: 2 additions & 2 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -228,8 +228,8 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    at::Tensor* ret3_tensor = new at::Tensor(std::move(r3));
    *ret3 = tensor_pointer_to_tensor_handle(ret3_tensor);
  }
  *ret4 = r4;
  *ret5 = r5;
  *ret4 = r4.expect_int();
  *ret5 = r5.expect_int();
  at::Tensor* ret6_tensor = new at::Tensor(std::move(r6));
  *ret6 = tensor_pointer_to_tensor_handle(ret6_tensor);
  at::Tensor* ret7_tensor = new at::Tensor(std::move(r7));
2 changes: 1 addition & 1 deletion torchgen/api/python.py
@@ -1129,7 +1129,7 @@ def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
"::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
"::std::vector<at::Tensor>",
# Needed for flash attention forw/backward
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t,int64_t,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
"at::Scalar",
"bool",
"int64_t",