Commit 49d63a7
Yang Chen
[inductor] support _scaled_dot_product_flash_attention fallback
This PR supports _scaled_dot_product_flash_attention fallback kernel.
Note that in the abi_compatible mode, we retrieve outputs by passing
output argument pointers rather than relying on std::get.
It also fixes an issue related to dynamic shapes, where we wrongfully
query undefined dynamic symbols.
ghstack-source-id: 3c51dab
Pull Request resolved: #1100031 parent e42d450 commit 49d63a7
File tree
5 files changed
+184
-25
lines changed- test/inductor
- torch
- _inductor
- codegen
- csrc/inductor/aoti_torch
- c
5 files changed
+184
-25
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
538 | 538 | | |
539 | 539 | | |
540 | 540 | | |
| 541 | + | |
| 542 | + | |
| 543 | + | |
| 544 | + | |
| 545 | + | |
| 546 | + | |
| 547 | + | |
| 548 | + | |
| 549 | + | |
| 550 | + | |
| 551 | + | |
| 552 | + | |
| 553 | + | |
| 554 | + | |
| 555 | + | |
| 556 | + | |
| 557 | + | |
| 558 | + | |
| 559 | + | |
| 560 | + | |
| 561 | + | |
| 562 | + | |
| 563 | + | |
| 564 | + | |
| 565 | + | |
| 566 | + | |
| 567 | + | |
| 568 | + | |
| 569 | + | |
| 570 | + | |
| 571 | + | |
| 572 | + | |
| 573 | + | |
| 574 | + | |
| 575 | + | |
541 | 576 | | |
542 | | - | |
| 577 | + | |
543 | 578 | | |
544 | 579 | | |
545 | 580 | | |
546 | 581 | | |
547 | 582 | | |
548 | | - | |
| 583 | + | |
549 | 584 | | |
550 | 585 | | |
551 | 586 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
290 | 290 | | |
291 | 291 | | |
292 | 292 | | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
293 | 298 | | |
294 | 299 | | |
295 | 300 | | |
| |||
581 | 586 | | |
582 | 587 | | |
583 | 588 | | |
| 589 | + | |
584 | 590 | | |
585 | 591 | | |
586 | 592 | | |
| |||
589 | 595 | | |
590 | 596 | | |
591 | 597 | | |
| 598 | + | |
592 | 599 | | |
593 | 600 | | |
594 | 601 | | |
| |||
599 | 606 | | |
600 | 607 | | |
601 | 608 | | |
| 609 | + | |
602 | 610 | | |
603 | 611 | | |
604 | 612 | | |
| |||
617 | 625 | | |
618 | 626 | | |
619 | 627 | | |
620 | | - | |
| 628 | + | |
621 | 629 | | |
622 | 630 | | |
623 | 631 | | |
| |||
637 | 645 | | |
638 | 646 | | |
639 | 647 | | |
| 648 | + | |
| 649 | + | |
| 650 | + | |
640 | 651 | | |
641 | 652 | | |
642 | 653 | | |
| |||
1182 | 1193 | | |
1183 | 1194 | | |
1184 | 1195 | | |
1185 | | - | |
| 1196 | + | |
| 1197 | + | |
| 1198 | + | |
| 1199 | + | |
| 1200 | + | |
1186 | 1201 | | |
1187 | 1202 | | |
1188 | 1203 | | |
| |||
1402 | 1417 | | |
1403 | 1418 | | |
1404 | 1419 | | |
| 1420 | + | |
| 1421 | + | |
| 1422 | + | |
| 1423 | + | |
| 1424 | + | |
| 1425 | + | |
| 1426 | + | |
| 1427 | + | |
| 1428 | + | |
| 1429 | + | |
| 1430 | + | |
| 1431 | + | |
| 1432 | + | |
| 1433 | + | |
| 1434 | + | |
| 1435 | + | |
| 1436 | + | |
| 1437 | + | |
| 1438 | + | |
| 1439 | + | |
| 1440 | + | |
| 1441 | + | |
| 1442 | + | |
| 1443 | + | |
| 1444 | + | |
| 1445 | + | |
| 1446 | + | |
| 1447 | + | |
| 1448 | + | |
1405 | 1449 | | |
1406 | 1450 | | |
1407 | | - | |
1408 | | - | |
1409 | | - | |
1410 | | - | |
1411 | | - | |
1412 | | - | |
1413 | | - | |
1414 | | - | |
1415 | | - | |
1416 | | - | |
1417 | | - | |
1418 | | - | |
1419 | | - | |
| 1451 | + | |
1420 | 1452 | | |
1421 | 1453 | | |
1422 | 1454 | | |
| |||
1461 | 1493 | | |
1462 | 1494 | | |
1463 | 1495 | | |
1464 | | - | |
1465 | | - | |
| 1496 | + | |
| 1497 | + | |
| 1498 | + | |
| 1499 | + | |
| 1500 | + | |
| 1501 | + | |
1466 | 1502 | | |
1467 | 1503 | | |
1468 | 1504 | | |
| |||
1584 | 1620 | | |
1585 | 1621 | | |
1586 | 1622 | | |
| 1623 | + | |
| 1624 | + | |
| 1625 | + | |
| 1626 | + | |
| 1627 | + | |
1587 | 1628 | | |
1588 | 1629 | | |
1589 | 1630 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3627 | 3627 | | |
3628 | 3628 | | |
3629 | 3629 | | |
| 3630 | + | |
| 3631 | + | |
| 3632 | + | |
| 3633 | + | |
3630 | 3634 | | |
3631 | 3635 | | |
3632 | 3636 | | |
| |||
3879 | 3883 | | |
3880 | 3884 | | |
3881 | 3885 | | |
3882 | | - | |
| 3886 | + | |
| 3887 | + | |
3883 | 3888 | | |
3884 | 3889 | | |
3885 | 3890 | | |
| |||
3899 | 3904 | | |
3900 | 3905 | | |
3901 | 3906 | | |
3902 | | - | |
| 3907 | + | |
3903 | 3908 | | |
3904 | 3909 | | |
3905 | 3910 | | |
| |||
3908 | 3913 | | |
3909 | 3914 | | |
3910 | 3915 | | |
3911 | | - | |
3912 | | - | |
3913 | | - | |
3914 | | - | |
| 3916 | + | |
| 3917 | + | |
| 3918 | + | |
| 3919 | + | |
3915 | 3920 | | |
3916 | 3921 | | |
3917 | 3922 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
152 | 152 | | |
153 | 153 | | |
154 | 154 | | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
155 | 174 | | |
156 | 175 | | |
157 | 176 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
| 9 | + | |
9 | 10 | | |
10 | 11 | | |
11 | 12 | | |
12 | 13 | | |
| 14 | + | |
13 | 15 | | |
14 | 16 | | |
15 | 17 | | |
16 | 18 | | |
17 | 19 | | |
18 | 20 | | |
| 21 | + | |
19 | 22 | | |
20 | 23 | | |
21 | 24 | | |
| |||
182 | 185 | | |
183 | 186 | | |
184 | 187 | | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
185 | 244 | | |
186 | 245 | | |
187 | 246 | | |
| |||
0 commit comments