[core] support sage attention + FA2 through kernels #12439
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
MekkCyber
left a comment
Very cool! I will try to look into the torch compile compatibility. As for the other variants, they are the same as sageattn. What I mean is that sageattn is just a wrapper that dispatches to the correct kernel depending on the hardware used: https://github.com/thu-ml/SageAttention/blob/main/sageattention/core.py#L140
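As a rough illustration of that hardware-based dispatch, here is a minimal sketch; the compute-capability thresholds and variant names below are assumptions for illustration, not SageAttention's exact mapping:

```python
import torch


def pick_sage_variant() -> str:
    # Illustrative only: map the GPU's compute capability to a kernel family,
    # mirroring the hardware-based dispatch that sageattn performs internally.
    major, minor = torch.cuda.get_device_capability()
    arch = major * 10 + minor
    if arch >= 90:
        return "qk_int8_pv_fp8_cuda_sm90"  # Hopper-class GPUs
    if arch >= 89:
        return "qk_int8_pv_fp8_cuda"       # Ada-class GPUs
    return "qk_int8_pv_fp16_cuda"          # Ampere and earlier
```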
So, you mean we shouldn't need to have different dispatched functions like this?
Yes, I think we don't need that because it depends on the hardware. For example, if a user chooses:
```python
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
```
I don't see their usage, hence removed.
FYI, I've ported SageAttention to the Python stable ABI (abi3) and the libtorch stable ABI, which should simplify building for HF Kernels. There are also some refactors in my main branch to simplify building. If someone can maintain the build system, then I no longer need to maintain my repo :)
This PR is ready to be reviewed now. As discussed with @MekkCyber over DMs, we're disabling […]. In order for us to support it with […]. I think we should be good with the PR.

Cc: @MekkCyber @DN6
@DN6 it should be up for another review. I have updated the test suite and ensured that the tests pass successfully as well. PTAL.
DN6
left a comment
Good to merge. But we need to remove the parallel config check and use supports_context_parallel=False
```python
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    if _parallel_config:
```
Use:
```python
@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_HUB,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=False,
)
```

It will raise an error when trying to enable parallelism with this backend. This check isn't needed.
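For illustration, a sketch of what the registered op could look like once the runtime guard is removed. It assumes it lives inside diffusers' attention dispatch module, where the registry, constraint helpers, and ParallelConfig are defined; the function name is an assumption and the body is elided:

```python
# Sketch only; a fragment as it would appear inside the attention dispatch module.
from typing import Optional

import torch


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_HUB,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=False,
)
def _flash_attention_hub(  # function name is an assumption
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    return_lse: bool = False,
    _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
    # No runtime `if _parallel_config:` guard here: registering the backend with
    # supports_context_parallel=False makes the registry raise if context
    # parallelism is requested for this backend.
    ...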
What does this PR do?
Code to test (SAGE):
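The snippet is collapsed in this view; below is a minimal sketch of this kind of test, assuming a Flux pipeline and that the Hub-kernels Sage backend is selected via set_attention_backend (the backend string "sage_hub", model id, and prompt are assumptions):

```python
# Sketch only: backend string, model id, and prompt are assumptions.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route the transformer's attention through the Sage kernel fetched via `kernels`.
pipe.transformer.set_attention_backend("sage_hub")

image = pipe(
    "a photo of a dog dressed as an astronaut",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("sage_hub.png")
```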
Result:

FA2:
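Similarly, a sketch for the FA2-through-kernels path, reusing the pipeline from the Sage sketch above; the backend string is inferred from the AttentionBackendName.FLASH_HUB name in the review thread and is an assumption:

```python
# Switch the same pipeline to the Hub-provided FlashAttention-2 backend.
pipe.transformer.set_attention_backend("flash_hub")

image = pipe(
    "a photo of a dog dressed as an astronaut",
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flash_hub.png")
```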
Notes
torch.compile support when using sage attention, like we have for flash and flash 3. Currently, this fails.
Code to test:
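The repro snippet is collapsed; a hedged sketch of the failing combination, continuing from the Sage example above (the "sage_hub" backend string is an assumption):

```python
# Enable the Sage backend and then compile the transformer.
pipe.transformer.set_attention_backend("sage_hub")
pipe.transformer.compile(fullgraph=True)

# Compilation is triggered on the first forward pass; with the Sage backend this
# currently errors out (see the pastebin trace below).
_ = pipe("a photo of a dog", num_inference_steps=2).images[0]
```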
Error: https://pastebin.com/3HS6HNzR
There are other sageattn variants (see here), which would be cool to expose from the Hub kernel.

Cc: @MekkCyber