
Conversation

@imzhuhl
Contributor

@imzhuhl imzhuhl commented Oct 30, 2023

Fixes #112380

As described in the issue, I try to add mask support to the flash attention implementation on CPU.

First, the mask is supported by using cum_seq_[q/kv], so the code for flashAttentionKernel is modified.
Second, an interface needs to be added so that users can call it; I chose to overload the _scaled_dot_product_flash_attention function.
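For context, here is a rough sketch (not part of this PR's diff) of the user-facing call this change is meant to serve: scaled_dot_product_attention with an attn_mask on CPU tensors. The shapes and the causal boolean mask below are illustrative only.

```python
import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64)   # (batch, heads, seq_q, head_dim), CPU tensors
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
mask = torch.ones(128, 128, dtype=torch.bool).tril()  # broadcastable boolean attention mask

# With mask support in the CPU flash attention kernel, this call could be served
# by the fused kernel instead of the unfused math fallback.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```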

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@pytorch-bot

pytorch-bot bot commented Oct 30, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/112381

Note: Links to docs will display an error until the docs builds have been completed.

❌ 17 New Failures

As of commit a50f8dd with merge base 0d669f0:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Oct 30, 2023
@github-actions
Contributor

This PR needs a `release notes:` label

If your changes are user facing and intended to be a part of release notes, please use a label starting with `release notes:`.

If not, please add the `topic: not user facing` label.

To add a label, you can comment to pytorchbot, for example
`@pytorchbot label "topic: not user facing"`

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@zou3519 zou3519 requested a review from drisspg October 30, 2023 14:39
@zou3519 zou3519 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Oct 30, 2023
@drisspg
Contributor

drisspg commented Oct 30, 2023

I think we want to do this, but we need to properly split the APIs, i.e. add this: #110546

And this, and then you would add an impl to the "nested" registration for CPU.

cc @cpuhrsch

@imzhuhl
Contributor Author

imzhuhl commented Oct 31, 2023

> I think we want to do this, but we need to properly split the APIs, i.e. add this: #110546
>
> And this, and then you would add an impl to the "nested" registration for CPU.
>
> cc @cpuhrsch

So we expect PyTorch to have two different flash attention APIs. I think I should commit the changes after #110546 merges. Is that okay?

@jgong5 jgong5 requested a review from Valentine233 October 31, 2023 09:50
@albanD albanD removed their request for review October 31, 2023 15:27
@Valentine233
Collaborator

Valentine233 commented Nov 1, 2023

For the support of attention mask, I suppose that using cum_seq_q or cum_seq_k is not appropriate, because they do not mean the same thing as an attention mask.

Maybe it's better to implement a CPU version of _scaled_dot_product_efficient_attention, as it has attn_bias as an input? There are some early discussions here: #103826 (comment). I could help do this if needed.
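To make the distinction concrete, here is a rough sketch (not PyTorch internals) of what each of these inputs describes; all values are made up for the example.

```python
import torch

# attn_bias / attn_mask: an additive (or boolean) term applied to the attention scores
scores = torch.randn(2, 8, 16, 16)          # (batch, heads, seq_q, seq_kv)
attn_bias = torch.zeros(2, 1, 16, 16)
attn_bias[..., 8:] = float("-inf")          # e.g. mask out padded key positions
weights = torch.softmax(scores + attn_bias, dim=-1)

# cum_seq_q / cum_seq_k: cumulative sequence lengths that describe where each
# variable-length sequence starts inside a packed (total_tokens, ...) batch
seq_lens = torch.tensor([5, 11])
cum_seq_q = torch.nn.functional.pad(seq_lens.cumsum(0), (1, 0))  # tensor([0, 5, 16])
```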

cc @jgong5

@jgong5
Collaborator

jgong5 commented Nov 1, 2023

> For the support of attention mask, I suppose that using cum_seq_q or cum_seq_k is not appropriate, because they do not mean the same thing as an attention mask.
>
> Maybe it's better to implement a CPU version of _scaled_dot_product_efficient_attention, as it has attn_bias as an input? There are some early discussions here: #103826. I could help do this if needed.
>
> cc @jgong5

Agreed with @Valentine233. Overloading the meaning of cum_seq_q or cum_seq_k for the attention mask doesn't seem the right way to support the mask. Since the algorithms for flash attention and efficient attention are very similar, I don't see the benefit of implementing another version for efficient attention on CPU. Would it be a viable option to provide an overload of _scaled_dot_product_flash_attention that supports the attention mask with an added attn_mask arg on CPU? @drisspg
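To make the suggestion concrete, here is a rough Python reference of the numerics such a mask-aware overload would have to match; the function name and signature below are hypothetical and are not the actual ATen schema discussed in this thread.

```python
import math
import torch

def sdpa_with_mask_reference(q, k, v, attn_mask=None, scale=None):
    # q, k, v: (batch, heads, seq, head_dim); attn_mask: boolean or additive float mask
    scale = scale if scale is not None else 1.0 / math.sqrt(q.size(-1))
    scores = q @ k.transpose(-2, -1) * scale
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # boolean mask: False positions are excluded from attention
            scores = scores.masked_fill(~attn_mask, float("-inf"))
        else:
            # float mask: acts as an additive bias on the scores
            scores = scores + attn_mask
    return torch.softmax(scores, dim=-1) @ v
```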

@Valentine233
Collaborator

Hi @drisspg, do you have any opinion on the API to support CPU flash attention with mask?

@drisspg
Contributor

drisspg commented Nov 6, 2023

Hey, so like @jgong5 said, I think it doesn't really make sense to add a new backend if we are going to be modifying the existing kernel code.

If it makes it easier, we can decouple the signature of the kernel from CPU and CUDA. We could probably register a new stem function and call it "sdpa_fused_attention" or something, where we only register a CPU backend.

I think this would also clean up the meta registration for SDPA, as we won't be forcing the CPU code to abide by the constraints of the CUDA code.
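For illustration only, here is how "register only a CPU backend" can look using the Python torch.library API with a made-up op in a custom namespace; the real change would go through native_functions.yaml and C++ kernel registrations, and "sdpa_fused_attention" is just the placeholder name from this comment.

```python
import math
import torch

lib = torch.library.Library("mylib", "DEF")
lib.define("sdpa_fused_attention(Tensor q, Tensor k, Tensor v, Tensor? attn_mask) -> Tensor")

def sdpa_fused_attention_cpu(q, k, v, attn_mask=None):
    # Stand-in math implementation; the point is only that the op has a CPU
    # kernel and nothing else, so its signature is free to differ from CUDA's.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if attn_mask is not None:
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

lib.impl("sdpa_fused_attention", sdpa_fused_attention_cpu, "CPU")

q = k = v = torch.randn(1, 2, 8, 16)
out = torch.ops.mylib.sdpa_fused_attention(q, k, v, None)
```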

@github-actions
Contributor

github-actions bot commented Jan 5, 2024

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jan 5, 2024
@drisspg
Contributor

drisspg commented Jan 5, 2024

Closing in favor of: #115913

@drisspg drisspg closed this Jan 5, 2024

Labels

- module: cpu (CPU specific problem, e.g. perf, algorithm)
- open source
- Stale
- triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

Need support CPU flash attention with mask

6 participants