support batch size=0 for flash attention #166318
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/166318
Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit e3e06d9 with merge base 7ce723d. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Review thread on aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp (outdated, resolved):
    const int seqlen_k = k.size(1);
    const int num_heads_k = k.size(2);

    if (batch_size == 0) {
@soulitzer what are the semantics for outputs/grads when the batch is empty: should they be zeros or empty tensors?
hmm probably not a huge difference when tensors are zero-numel
I'll change it to empty_like, since filling with zeros afterwards isn't necessary.
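For reference, a minimal Python illustration (not part of the PR) of why the choice between empty_like and zeros_like is immaterial for zero-numel tensors:

```python
import torch

# Query-shaped tensor with batch_size=0; the other dims are arbitrary examples.
q = torch.empty(0, 128, 8, 64)

# With zero elements there is nothing to fill, so empty_like and zeros_like
# yield tensors with identical shape, dtype, and (trivially) contents.
a = torch.empty_like(q)
b = torch.zeros_like(q)

print(a.shape)            # torch.Size([0, 128, 8, 64])
print(a.numel())          # 0
print(torch.equal(a, b))  # True -- there are no elements to differ
```

The only practical difference is that empty_like skips the (here zero-cost) fill, which is why the thread settles on it.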
Force-pushed from a94a66f to 5ec1bd0.
@liangel-02 has imported this pull request. If you are a Meta employee, you can view this in D85592445.
Force-pushed from 5ec1bd0 to e3e06d9.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #165944
Summary
Today, if we attempt to run flash attention with batch_size 0, we get the error "RuntimeError: batch size must be positive". This PR fixes this by returning early with empty tensors in the forward and backward.
Test plan
python test/test_transformers.py -k test_scaled_dot_product_attention
- Added a case for batch_size=0.
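A hedged sketch of the kind of check such a test case could perform. The shapes, dtype, and backend-forcing context manager here are assumptions for illustration, and running it requires a CUDA device with flash attention support:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "cuda"
# batch_size=0 inputs in (batch, num_heads, seq_len, head_dim) layout.
q = torch.randn(0, 8, 128, 64, device=device, dtype=torch.float16, requires_grad=True)
k = torch.randn(0, 8, 128, 64, device=device, dtype=torch.float16, requires_grad=True)
v = torch.randn(0, 8, 128, 64, device=device, dtype=torch.float16, requires_grad=True)

# Force the flash-attention backend so the fixed kernel path is exercised.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

# With this PR, the forward returns an empty output instead of raising
# "batch size must be positive", and the backward produces empty grads.
assert out.shape == (0, 8, 128, 64)
out.sum().backward()
assert q.grad is not None and q.grad.shape == q.shape
```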