 
 import torch
 import torch._dynamo.config as dynamo_config
+import torch.backends.cuda
+import torch.nn.functional as F
 from torch import nn
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import rand_strided
 from torch._dynamo.utils import same
 from torch._inductor import config
 from torch._inductor.compile_fx import compile_fx_inner
 from torch.fx.experimental.proxy_tensor import make_fx
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
     freeze_rng_state,
@@ -982,6 +985,51 @@ def fn(x, y, z):
 
         self.assertEqual(ref, res)
 
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
+    )
+    def test_flash_attention_dynamic(self):
+        class Model(nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+                self.q = nn.Linear(1024, 1024)
+                self.k = nn.Linear(1024, 1024)
+                self.v = nn.Linear(1024, 1024)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.size()
+
+                queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+                keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+                values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
+
+                attn = F.scaled_dot_product_attention(
+                    queries,
+                    keys,
+                    values,
+                )
+
+                return attn
+
+        cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
+
+        model = Model().cuda().half()
+        model = torch.compile(model, backend=cnts, dynamic=True)
+
+        with torch.backends.cuda.sdp_kernel(
+            enable_flash=True, enable_math=False, enable_mem_efficient=False
+        ):
+            input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
+            input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
+            input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)
+
+            out1 = model(input1)
+            out2 = model(input2)
+            out3 = model(input3)
+
+        self.assertEqual(cnts.frame_count, 1)
+
     @config.patch({"triton.cudagraphs": True})
     def test_index_put_no_fallback_cudagraph(self):
         def fn(x, y, z):
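
Note: on newer PyTorch builds `torch.backends.cuda.sdp_kernel` is deprecated in favor of `torch.nn.attention.sdpa_kernel`. A minimal sketch of the equivalent backend selection, assuming a build that ships `torch.nn.attention` (the `model` / `input1` names refer to the test above):

from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict scaled_dot_product_attention to the flash-attention backend,
# mirroring enable_flash=True, enable_math=False, enable_mem_efficient=False.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = model(input1)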