
Commit 21d32be

Update on "Change flash attention outputs to be SymInt instead of int"
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
[ghstack-poisoned]
1 parent eec2383 commit 21d32be
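
Background note (an editorial aside, not part of the commit): under torch.compile with dynamic shapes, traced tensor sizes are torch.SymInt objects rather than plain Python ints, so integer outputs of the flash-attention op (presumably its max-sequence-length values) also need a symbolic representation. A minimal sketch of the int-vs-SymInt distinction, using only public torch.compile APIs; the names print_backend and scale_by_seq_len are illustrative:

import torch

# Eager mode: sizes are plain Python ints.
x = torch.randn(5, 512, 1024)
assert isinstance(x.size(1), int)

# A custom backend that prints the captured FX graph and then runs it eagerly.
def print_backend(gm, example_inputs):
    gm.print_readable()
    return gm.forward

@torch.compile(backend=print_backend, dynamic=True)
def scale_by_seq_len(t):
    # Under dynamic=True, t.size(1) is traced as a SymInt (e.g. s1),
    # not the concrete 512, so downstream code must accept SymInt
    # wherever it used to expect int.
    return t * t.size(1)

scale_by_seq_len(x)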

File tree

2 files changed: +53, -1 lines changed

test/inductor/test_cuda_repro.py

Lines changed: 48 additions & 0 deletions
@@ -5,13 +5,16 @@
 
 import torch
 import torch._dynamo.config as dynamo_config
+import torch.backends.cuda
+import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx_inner
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
     freeze_rng_state,
@@ -982,6 +985,51 @@ def fn(x, y, z):
 
         self.assertEqual(ref, res)
 
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
+    )
+    def test_flash_attention_dynamic(self):
+        class Model(nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+                self.q = nn.Linear(1024, 1024)
+                self.k = nn.Linear(1024, 1024)
+                self.v = nn.Linear(1024, 1024)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.size()
+
+                queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+                keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+                values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+
+                attn = F.scaled_dot_product_attention(
+                    queries,
+                    keys,
+                    values,
+                )
+
+                return attn
+
+        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
+
+        model = Model().cuda().half()
+        model = torch.compile(model, backend=cnts, dynamic=True)
+
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=False, enable_mem_efficient=False
+        ):
+            input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
+            input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
+            input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)
+
+            out1 = model(input1)
+            out2 = model(input2)
+            out3 = model(input3)
+
+            self.assertEqual(cnts.frame_count, 1)
+
     @config.patch({"triton.cudagraphs": True})
     def test_index_put_no_fallback_cudagraph(self):
         def fn(x, y, z):
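
Aside on what the new test asserts: with dynamic=True, Dynamo should compile the model once and reuse that single graph for sequence lengths 512, 513, and 514 (frame_count == 1); before this change, the int-typed flash-attention outputs presumably either forced specialization on the concrete sequence length or tripped the FallbackKernel assert shown in the ir.py hunk below. A rough standalone sketch of the same recompilation check outside the unittest harness, assuming a CUDA device with flash-attention support and that torch._dynamo.config.error_on_recompile is available in this build:

import torch
import torch._dynamo
import torch.backends.cuda
import torch.nn.functional as F

# Turn any recompilation into a hard error instead of a silent new graph.
torch._dynamo.config.error_on_recompile = True

def attention(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

compiled = torch.compile(attention, dynamic=True)

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    for seq_len in (512, 513, 514):
        # (batch, heads, seq_len, head_dim) in fp16, as flash attention expects.
        q = torch.rand(5, 8, seq_len, 128, device="cuda", dtype=torch.float16)
        compiled(q, q, q)  # raises under error_on_recompile if seq_len re-specializes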

torch/_inductor/ir.py

Lines changed: 5 additions & 1 deletion
@@ -4037,8 +4037,12 @@ def generate_output(output, indices):
                 )
             elif isinstance(output, int):
                 return output
+            elif isinstance(output, torch.SymInt):
+                return output.node.expr
             else:
-                assert output is None, "FallbackKernel output type is not supported"
+                assert (
+                    output is None
+                ), f"FallbackKernel output type {type(output)} is not supported"
                 return None
 
         outputs = generate_output(example_output, [])
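
For illustration only, the dispatch added above can be read as a small standalone helper (a hypothetical function, not Inductor's actual API): a SymInt wraps a SymNode whose .expr is the underlying sympy expression (for example s0), which is what Inductor's symbolic machinery works with, while plain ints and None keep their existing handling.

import torch

def unwrap_int_output(output):
    """Sketch of the FallbackKernel output handling added in this diff."""
    if isinstance(output, int):
        # Concrete Python int: usable as-is.
        return output
    elif isinstance(output, torch.SymInt):
        # Symbolic int: hand back the underlying sympy expression (e.g. s0).
        return output.node.expr
    else:
        assert output is None, f"unsupported output type {type(output)}"
        return None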
