
support batch size=0 for flash attention #166318

Closed

liangel-02 wants to merge 1 commit into main from sdpa-bs-zero

Conversation

@liangel-02 (Contributor):

Fixes #165944

Summary

Today, attempting to run flash attention with batch_size=0 fails with `RuntimeError: batch size must be positive`. This PR fixes that by returning early with empty tensors in both the forward and the backward.
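
For context, a minimal repro of the failure this addresses (shapes and backend selection are illustrative; assumes a CUDA build where flash attention is available):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Zero-batch inputs in SDPA layout: (batch, num_heads, seq_len, head_dim)
q = torch.randn(0, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(0, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(0, 8, 128, 64, device="cuda", dtype=torch.float16)

# Force the flash attention backend so this code path is what gets exercised
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    # Before this PR: RuntimeError: batch size must be positive
    # After this PR: early return with an empty output of shape (0, 8, 128, 64)
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([0, 8, 128, 64])
```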

Test plan
`python test/test_transformers.py -k test_scaled_dot_product_attention` — added a case for batch_size=0.
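
A rough sketch of what the batch_size=0 case exercises (hypothetical helper, not the actual code added to test_transformers.py):

```python
import torch
import torch.nn.functional as F

def check_sdpa_batch_size_zero(device="cuda", dtype=torch.float16):
    # Hypothetical check, not the actual test added in this PR
    shape = (0, 8, 128, 64)  # batch_size == 0
    q = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
    k = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)
    v = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)

    # Forward: should early-return an empty tensor instead of raising
    out = F.scaled_dot_product_attention(q, k, v)
    assert out.shape == shape and out.numel() == 0

    # Backward: grads should likewise come back as zero-numel tensors
    out.sum().backward()
    assert q.grad is not None and q.grad.shape == shape
```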

@liangel-02 added the topic: improvements (topic category) and module: sdpa (All things related to torch.nn.functional.scaled_dot_product_attention) labels on Oct 27, 2025
@pytorch-bot (bot) commented on Oct 27, 2025:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166318

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e3e06d9 with merge base 7ce723d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@liangel-02 added the release notes: nn (release notes category) label on Oct 27, 2025
@liangel-02 requested a review from drisspg on October 27, 2025 17:44
@liangel-02 marked this pull request as ready for review on October 27, 2025 17:44
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);

if (batch_size == 0) {
Contributor:
@soulitzer what are the semantics for the outputs/grads in the empty case: are they 0s or empty tensors?

Contributor:
hmm probably not a huge difference when tensors are zero-numel

Contributor Author:
I'll change it to empty_like since filling with zeros afterwards isn't necessary.
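
For zero-numel tensors the two choices are observationally identical; only the wasted fill differs. A quick illustration (shapes are arbitrary):

```python
import torch

# batch_size == 0, so there are no elements to fill either way
a = torch.empty(0, 8, 128, 64)
b = torch.zeros(0, 8, 128, 64)

assert a.numel() == 0 and b.numel() == 0
assert torch.equal(a, b)  # trivially equal: same shape, no elements
```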

@meta-codesync (bot) commented on Oct 27, 2025:

@liangel-02 has imported this pull request. If you are a Meta employee, you can view this in D85592445.

@liangel-02 (Contributor Author):

@pytorchbot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Oct 28, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

- ciflow/trunk: Trigger trunk jobs on your pull request
- Merged
- module: sdpa: All things related to torch.nn.functional.scaled_dot_product_attention
- release notes: nn: release notes category
- topic: improvements: topic category


Development

Successfully merging this pull request may close these issues:

Support Batchsize 0 for flash attention by early return
