[SDPA] Use scaled_dot_product_attention within attention.cpp #87312
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87312
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 1 Pending as of commit 100c3c4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
const auto dim_per_head = D / num_head;
...
if (query.is_same(key) && key.is_same(value) && !need_weights) {
By using qkv_projection below we might be able to broaden the applicability of _scaled_dot_product_attention to just !need_weights
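A minimal Python sketch of this point (shapes and tensor names are made up, and it uses the public `scaled_dot_product_attention`, which at the time of this PR was the private `at::_scaled_dot_product_attention`): the fused path only needs the attention weights to be unrequested, not that query, key, and value alias the same tensor.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=2, heads=4, seq_len=8, head_dim=16.
q = torch.randn(2, 4, 8, 16)
k = torch.randn(2, 4, 8, 16)  # key need not alias query
v = torch.randn(2, 4, 8, 16)

# Equivalent to the !need_weights condition: SDPA never materializes the
# attention matrix, so it applies to cross-attention as well as self-attention.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```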
Force-pushed dec4bf4 to 596e945
chunks[0] = (chunks[0].view({x_size_0, -1, num_head, sdp_dim_per_head}))
                .transpose(1, 2);
chunks[1] = (chunks[1].view({x_size_0, -1, num_head, sdp_dim_per_head}))
                .transpose(1, 2);
chunks[2] = (chunks[2].view({x_size_0, -1, num_head, sdp_dim_per_head}))
                .transpose(1, 2);

auto y = at::_scaled_dot_product_attention(
    chunks[0], chunks[1], chunks[2], mask, 0.0, need_weights, false);
Note that the latest implementation from @danthe3rd allows passing the tensor without having to reshape anything, avoiding a call to contiguous both before and after the operator. This can provide significant speedups, by the way.
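For reference, a hedged Python sketch of the reshape shown above (dimensions are made up): the packed qkv output is split and each chunk is viewed/transposed into `[batch, num_heads, seq_len, head_dim]` before the SDPA call. As noted, an implementation that accepts the unpermuted layout directly would avoid the extra contiguous copies around the operator.

```python
import torch
import torch.nn.functional as F

batch, seq, embed_dim, num_head = 2, 8, 64, 4
head_dim = embed_dim // num_head

# Packed qkv projection output: [batch, seq, 3 * embed_dim]
qkv = torch.randn(batch, seq, 3 * embed_dim)
chunks = qkv.chunk(3, dim=-1)

# Mirror of the attention.cpp reshape: [B, S, D] -> [B, num_head, S, head_dim].
# The transpose yields a non-contiguous view; a kernel that consumes this
# layout as-is avoids .contiguous() copies before and after the op.
q, k, v = (c.view(batch, -1, num_head, head_dim).transpose(1, 2) for c in chunks)

y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)
print(y.shape)  # torch.Size([2, 4, 8, 16])
```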
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed 5f42406 to d64515c
We could also error out here
I think that's a good idea for now, as long as there's no risk we'll take this path from attention.cpp.
Force-pushed e5a3ce9 to 7f206e1
#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
  }
#endif
  x = std::get<0>(at::_native_multi_head_attention(
I think this needs to be here, before the avoided dispatch. The previous _native_multi_head_attention path worked for all tensor subclasses, so we now go through the dispatcher to correctly dispatch to CPU or CUDA.
danthe3rd left a comment
I'm wondering if we need some more tests for this specific behavior (e.g., a nested tensor with seqlen=1).
// TODO: When 0 seq_len nested tensors are allowed we need to guard against this
int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] * tensor_stride_ptr[(i - 1) * tensor_stride_0];
int64_t current_offset_constant = (tensor_offsets[i] - tensor_offsets[i - 1]) / previous_numel;
Should we add a check that (tensor_offsets[i] - tensor_offsets[i - 1]) % previous_numel == 0 before dividing here - and also before the loop?
Yes, we very much do, but after speaking with Christian, this PR intentionally keeps the SDP inclusion in native MHA narrow; a follow-up PR will expand the scope.
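For clarity, a Python re-statement of the suggested guard (the function name and the flat offset/size/stride lists are made up for illustration, not ATen's actual nested-tensor metadata layout): the gap between consecutive constituents must be an exact multiple of the previous constituent's numel, otherwise the integer division silently truncates.

```python
def check_uniform_offsets(tensor_offsets, sizes, strides):
    constants = []
    for i in range(1, len(tensor_offsets)):
        # numel of the previous constituent, mirroring the C++ loop
        previous_numel = sizes[i - 1] * strides[i - 1]
        delta = tensor_offsets[i] - tensor_offsets[i - 1]
        # Suggested check: reject non-uniform offsets before dividing.
        assert delta % previous_numel == 0, "non-uniform nested tensor offsets"
        constants.append(delta // previous_numel)
    return constants

print(check_uniform_offsets([0, 128, 256], sizes=[8, 8, 8], strides=[16, 16, 16]))  # [1, 1]
```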
Force-pushed 029ab31 to 6c6a8ac
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed 4c9c905 to 18dc15c
Force-pushed 619f511 to 0aef791
test/test_native_mha.py (outdated)
        need_weights=need_weights,
        average_attn_weights=average_attn_weights,
    )

def test_native_multihead_self_attention(self, device, dtype, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=False
This test is currently failing for the cases where it can run the fused SDP, i.e. need_weights=False.
test/test_transformers.py (outdated)
math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous()
math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous()

self.assertEqual(math_ref_test, math_ref_lp_test, atol=4e-1, rtol=4e-1)
This comparison is between the math_ref run in fp32 and math_ref run in fp16, and is used to define a reasonable epsilon for fp16-to-fp32 comparisons.
The second assert compares fused_sdp_fp16 against math_ref_fp32 and ensures that it stays within the same bounds as the math_ref comparison.
Also note that we are scaling the uniform distribution up to the range (-10, 10).
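A rough Python sketch of that testing pattern (the shapes, seed, input range, and tolerances are illustrative, not the values in test_transformers.py, and it assumes fp16 is supported for these ops on the device in use): the first assert calibrates the fp16-vs-fp32 epsilon on the math reference, the second holds the (possibly fused) fp16 SDPA to the same bound.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def math_sdpa(q, k, v):
    # Plain "math" reference: softmax(QK^T / sqrt(d)) @ V
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

# The real test scales inputs up to (-10, 10); a narrower range keeps this
# sketch robust across devices.
q32, k32, v32 = (2 * torch.rand(2, 4, 8, 16) - 1 for _ in range(3))
q16, k16, v16 = (t.half() for t in (q32, k32, v32))

math_ref = math_sdpa(q32, k32, v32)             # fp32 math reference
math_ref_lp = math_sdpa(q16, k16, v16).float()  # same math in fp16

# First comparison: defines a reasonable epsilon for fp16-vs-fp32 results.
torch.testing.assert_close(math_ref, math_ref_lp, atol=4e-1, rtol=4e-1)

# Second comparison: the fp16 SDPA result must stay within the same bounds
# against the fp32 math reference.
sdpa_lp = F.scaled_dot_product_attention(q16, k16, v16).float()
torch.testing.assert_close(math_ref, sdpa_lp, atol=4e-1, rtol=4e-1)
```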
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed 101e478 to 100c3c4
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Hey @drisspg.
Summary
Use the private _scaled_dot_product_attention to support _native_multiheaded_attention. _SDP provides access to fused kernels when certain conditions are met, enabling a speed-up for MHA.
cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: pytorch#87312
Approved by: https://github.com/cpuhrsch
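As an illustration of the user-facing effect (a hedged sketch, not part of this PR's diff): with need_weights=False in inference mode, nn.MultiheadAttention can take the native fast path, which after this change routes through _scaled_dot_product_attention and its fused kernels when the conditions are met.

```python
import torch

# batch_first self-attention module; in eval/no-grad mode with
# need_weights=False the native fast path can use the fused SDP kernels.
mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=4, batch_first=True).eval()
x = torch.randn(2, 8, 64)

with torch.no_grad():
    out, weights = mha(x, x, x, need_weights=False)

print(out.shape)  # torch.Size([2, 8, 64])
print(weights)    # None: weights are not materialized, enabling the fused path
```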