CPU flash attention with mask #112381
Conversation
This PR needs a
So we expect PyTorch to have two different flash attention APIs. I think I should commit the changes after #110546 merges; is that okay?
For the support of attention mask, I suppose that using Maybe it's better to implement cc @jgong5
Agreed with @Valentine233. Overloading the meaning of
Hi @drisspg, do you have any opinion on the API to support CPU flash attention with mask?
Hey, so like @jgong5 said, I think it doesn't really make sense to add a new backend if we are going to be modifying the existing kernel code. If it makes it easier, we can decouple the signature of the kernel from CPU and CUDA. We should probably register a new function and call it "sdpa_fused_attention" or something, where we only register a CPU backend. I think this would also clean up the meta registration for SDPA, as we won't be forcing the CPU code to abide by the constraints of the CUDA code.
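For illustration only, here is a rough Python-level sketch (not part of this PR) of registering an operator with only a CPU implementation via torch.library. The name "sdpa_fused_attention" is taken from the comment above; the library name "mylib", the schema, and the delegating implementation are assumptions.

```python
import torch
import torch.nn.functional as F
from torch.library import Library

# Hypothetical library and schema, for illustration only; a real kernel
# would be registered from ATen rather than delegating back to SDPA.
lib = Library("mylib", "DEF")
lib.define(
    "sdpa_fused_attention(Tensor query, Tensor key, Tensor value, "
    "Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False) -> Tensor"
)

def sdpa_fused_attention_cpu(query, key, value, attn_mask=None,
                             dropout_p=0.0, is_causal=False):
    # Placeholder CPU implementation: forward to the public SDPA entry point.
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask,
        dropout_p=dropout_p, is_causal=is_causal)

# Register an implementation for the CPU dispatch key only.
lib.impl("sdpa_fused_attention", sdpa_fused_attention_cpu, "CPU")

q = k = v = torch.randn(2, 4, 8, 16)
out = torch.ops.mylib.sdpa_fused_attention(q, k, v)
```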
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as
Closing in favor of: #115913
Fixes #112380
As described in the issue, this PR adds mask support to the flash attention implementation on CPU.
First, the mask is supported by using cum_seq_[q/kv], so the code for flashAttentionKernel is modified.
Second, a user-facing interface needs to be added, so I choose to overload the _scaled_dot_product_flash_attention function.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10
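Not part of the PR itself, but a minimal sketch of the behavior a masked CPU path is expected to match, written against the public scaled_dot_product_attention API; the tensor shapes and the additive mask below are arbitrary choices for illustration.

```python
import math
import torch
import torch.nn.functional as F

# Query/key/value laid out as (batch, heads, seq_len, head_dim).
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

# Additive attention mask, broadcast over heads; block the last 8 key positions.
mask = torch.zeros(2, 1, 16, 16)
mask[..., 8:] = float("-inf")

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Reference computation: softmax(QK^T / sqrt(d) + mask) V.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + mask
ref = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))
```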