Conversation
> `` # `flash-attn` ``
> `FLASH = "flash"`
> `FLASH_VARLEN = "flash_varlen"`
> `FLASH_HUB = "flash_hub"`
Flash Attention is stable, so we don't have to mark it private like FA3.
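For context, a minimal sketch of the naming convention being discussed (the `_FLASH_3_HUB` member and the exact string values are assumptions; only the `FLASH*` members above appear in the quoted diff): the stable `flash-attn` Hub backend gets a public enum name, while the still-experimental FA3 Hub backend is "private" by convention.

```python
from enum import Enum


class AttentionBackendName(str, Enum):
    # `flash-attn` (stable): public member names, no leading underscore
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"
    FLASH_HUB = "flash_hub"

    # `flash-attn` 3 is still experimental, so its Hub-backed member keeps a
    # leading underscore (assumed name, mirroring the FA3 integration)
    _FLASH_3_HUB = "_flash_3_hub"
```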
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
MekkCyber left a comment:
Very cool integration 🔥! I just left some nits.
> `fa3_interface_hub = _get_fa3_from_hub()`
> `flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func`
> `fa_interface_hub = _get_fa_from_hub()`
> `flash_attn_func_hub = fa_interface_hub.flash_attn_func`
Why are we fetching both kernels here?
Because of the way the APIs for attention backends are designed, and also to support `torch.compile` with fullgraph traceability (when possible).
We will let it grow a bit and, based on feedback, revisit how to better deal with this.
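To illustrate the fullgraph point, a rough sketch (the repo ids and helper name below are assumptions, not the PR's exact code): both Hub kernels are resolved once, outside the attention op, so the function that `torch.compile` traces is just a call to an already-bound callable with no Hub lookup or import logic in it.

```python
# Rough sketch, not the actual diffusers implementation.
from kernels import get_kernel  # Hugging Face `kernels` package

# Resolve both Hub kernels up front (repo ids below are assumptions).
fa_interface_hub = get_kernel("kernels-community/flash-attn")
fa3_interface_hub = get_kernel("kernels-community/flash-attn3")

flash_attn_func_hub = fa_interface_hub.flash_attn_func
flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func


def _flash_attention_hub(query, key, value, is_causal=False):
    # By the time torch.compile traces this, `flash_attn_func_hub` is already a
    # plain callable, so the trace never hits any download/import machinery.
    return flash_attn_func_hub(query, key, value, causal=is_causal)
```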
| FLASH = "flash" | ||
| FLASH_VARLEN = "flash_varlen" | ||
| FLASH_HUB = "flash_hub" | ||
| # FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet. |
Is this related to the kernel, or does it just need more time to be integrated?
We don't have models that use varlen.
@sayakpaul Qwen-Image uses varlen. Also, native fused QKV+MLP attention requires the varlen function.
@DN6 a gentle ping on this one.
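To make the varlen point concrete, a small illustration (not code from this PR) of the upstream `flash_attn_varlen_func` interface: it consumes packed, unpadded tokens plus cumulative sequence lengths, which is what per-sample sequence lengths and fused QKV paths need and what the padded `flash_attn_func` interface does not cover.

```python
# Illustration only, using the upstream flash-attn varlen API (not this PR's code).
import torch
from flash_attn import flash_attn_varlen_func

# Two samples of different lengths, packed into one tensor without padding.
seqlens = torch.tensor([77, 154], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(
    torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0), value=0
)
total_tokens, num_heads, head_dim = int(seqlens.sum()), 24, 128

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()),
    max_seqlen_k=int(seqlens.max()),
)
print(out.shape)  # (total_tokens, num_heads, head_dim)
```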
> `raise`
>
> `def _get_fa3_from_hub():`
This is a very thin wrapper. I would just call `_get_from_hub("fa3")` directly in `attention_dispatch`.
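A sketch of what that simplification could look like (the registry and repo ids here are assumptions rather than the PR's exact code): keep one generic loader and call it with the desired key at the use site instead of going through a per-kernel wrapper.

```python
# Sketch only; names and repo ids are assumptions, not the PR's exact code.
from kernels import get_kernel

_KERNEL_REPOS = {
    "fa": "kernels-community/flash-attn",    # assumed repo id
    "fa3": "kernels-community/flash-attn3",  # assumed repo id
}


def _get_from_hub(key: str):
    return get_kernel(_KERNEL_REPOS[key])


# In attention_dispatch, call the generic helper directly instead of keeping a
# thin `_get_fa3_from_hub()` wrapper around it:
fa3_interface_hub = _get_from_hub("fa3")
flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
```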
> `def _get_fa3_from_hub():`
>
> `def _get_from_hub(key: str):`
Suggested change:

> `- def _get_from_hub(key: str):`
> `+ def _get_kernel_from_hub(key: str):`
Closing for #12439.
What does this PR do?
Follow-up of #12236.
Testing code:
> [!TIP]
> Works with `torch.compile` (fullgraph compatible).

I have tested the code on H100 and A100, and it works.