[SDPA] Call _sdp_attention in nn.functional.mha #89470
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89470. Note: links to docs will display an error until the docs builds have completed. ✅ No failures as of commit 41b3f91. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed from 32b4332 to 603616c.
// Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math
const double scaling_factor = ::sqrt(::sqrt(static_cast<double>(embed_size)));
const auto embed_size = SymFloat(query_.sym_size(-1));
// const double scaling_factor = ::sqrt(::sqrt(static_cast<double>(embed_size)));
TODO: different options I was working through; these need to be removed before landing, but I'm curious whether this is the best way to do things.
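For context, a minimal sketch (not the PR's actual code, and with made-up shapes) of why pre-scaling q and k by the fourth root of the head dimension is numerically equivalent to the usual division of the score matrix by sqrt(d), while keeping intermediate magnitudes smaller:

```python
import torch

# Hypothetical shapes for illustration: (batch, heads, seq_len, head_dim)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
d = q.size(-1)

# Conventional scaling: form q @ k^T, then divide by sqrt(d).
scores_post = (q @ k.transpose(-2, -1)) / d**0.5

# Pre-scaling each operand by d**0.25 gives the same result, since
# (q / d**0.25) @ (k / d**0.25)^T == (q @ k^T) / sqrt(d),
# but the intermediate values stay smaller.
scale = d**0.25
scores_pre = (q / scale) @ (k / scale).transpose(-2, -1)

assert torch.allclose(scores_post, scores_pre, atol=1e-5)
```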
return os;
}

SymFloat SymFloat::sqrt() const {
Needs double-checking and placement guidance.
attn_output_weights = softmax(attn_output_weights, dim=-1)
if dropout_p > 0.0:
    attn_output_weights = dropout(attn_output_weights, p=dropout_p)
if attn_mask.size(0) == 1:
Hacky... the API between SDP and nn.functional.mha has a decent amount of impedance mismatch, requiring all this transposing and viewing fluff. Ideally this would also work with nested tensors out of the box, but I need to do a once-over of this forward to understand the gap.
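To make the mismatch concrete, a rough sketch with hypothetical sizes (and using the later public scaled_dot_product_attention name rather than the private _scaled_dot_product_attention this PR wires up) of the view/reshape round-trip between the flattened (N * H, L, E) layout used with torch.bmm inside multi_head_attention_forward and the (N, H, L, E) layout the SDPA kernels expect:

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: batch N, heads H, sequence length L, head dim E.
N, H, L, E = 2, 8, 16, 64

# Flattened (N * H, L, E) layout, as used with torch.bmm in the existing path.
q = torch.randn(N * H, L, E)
k = torch.randn(N * H, L, E)
v = torch.randn(N * H, L, E)

# SDPA expects (N, H, L, E), so the call site has to view on the way in...
attn = F.scaled_dot_product_attention(
    q.view(N, H, L, E), k.view(N, H, L, E), v.view(N, H, L, E)
)

# ...and flatten back on the way out for the rest of the function.
attn = attn.reshape(N * H, L, E)
```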
@BowenBao @abock Sorry for pinging you directly, but I am getting a test failure, which can be found on the hud: https://hud.pytorch.org/pr/89470. I am not sure what the next steps are to enable ONNX support. I tried reading through the wiki but didn't find anything very fruitful. Any guidance would be much appreciated.
Force-pushed from 603616c to 22acd66.
TIL we have a c10::Join
c10/core/SymFloat.h (Outdated)
Have you tried writing something like:
namespace std {
template <>
SymFloat sqrt<SymFloat>(const SymFloat& self) {
...
}
}
Not sure if this would work exactly, but maybe we can get std::sqrt() to Just Work on symfloats, so we don't need any code changes in the future
I can't seem to match the function template specialization to a function template. I tried this for both std:: and c10::complex.
@mikekgfb has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
torch/nn/functional.py (Outdated)
The transpose() calls necessary to pull this off might cost us significant overhead.
There's an impedance mismatch not just between functional.MHA and sdp_attention, but also between nn.MHA and functional.MHA.
nn.MHA treats batch_first as the preferred format, but functional.MHA appears to only support non-batch_first. I see two options:
1 - pass in the batch_first variable (somewhat non-preferred, because we might end up with control flow as a function of an input, namely batch_first).
2 - have an inverted-polarity MHA that prefers batch_first and use that one internally, then have nn.MHA and a legacy-compatibility functional.MHA call the new implementation, with a possible path to eventually deprecating the old interface (see the sketch below).
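A rough sketch of option 2 with hypothetical function names, a single head, no projections, and the later public scaled_dot_product_attention API standing in for the real internals: the batch-first implementation does the work, and the legacy seq-first entry point becomes a thin transpose wrapper.

```python
import torch
import torch.nn.functional as F

def _mha_batch_first(q, k, v):
    # Hypothetical internal implementation preferring batch-first (N, L, E)
    # inputs; single-head and no projections, purely for illustration.
    return F.scaled_dot_product_attention(q, k, v)

def mha_seq_first(q, k, v):
    # Legacy (L, N, E) interface kept as a thin compatibility wrapper:
    # transpose in, run the batch-first path, transpose back out.
    out = _mha_batch_first(q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1))
    return out.transpose(0, 1)

x = torch.randn(16, 2, 64)  # (L, N, E)
assert mha_seq_first(x, x, x).shape == x.shape
```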
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@cpuhrsch has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from e10a728 to d62aa79.
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary: Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met. cc VitalyFedyunin ngimel
X-link: pytorch/pytorch#89470
Reviewed By: cpuhrsch
Differential Revision: D41625335
Pulled By: drisspg
fbshipit-source-id: dcf79e9d51d1bb0b1649f6bf0a8b0e2869170874

Summary: Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met. cc VitalyFedyunin ngimel
Pull Request resolved: pytorch#89470
Reviewed By: cpuhrsch
Differential Revision: D41625335
Pulled By: drisspg
fbshipit-source-id: 460b5cadfbd2e8f8a21fb46bce92fe831984ee02
Force-pushed from 858d7d4 to f4690a6.
This pull request was exported from Phabricator. Differential Revision: D41625335
Summary: Pull Request resolved: pytorch#6038
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met. cc VitalyFedyunin ngimel
X-link: pytorch/pytorch#89470
Reviewed By: cpuhrsch
Differential Revision: D41625335
Pulled By: drisspg
fbshipit-source-id: 7bfe67cbc52d545faa0eefa7600f39a1685d01e4
This pull request was exported from Phabricator. Differential Revision: D41625335
Force-pushed from f4690a6 to dce73fa.
Summary: X-link: pytorch/glow#6038
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met. cc VitalyFedyunin ngimel
Pull Request resolved: pytorch#89470
Reviewed By: cpuhrsch
Differential Revision: D41625335
Pulled By: drisspg
fbshipit-source-id: cd7d010a6c325618e0df9dd75246a291451c8021

Summary: X-link: pytorch/glow#6038
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met. cc VitalyFedyunin ngimel
Pull Request resolved: pytorch#89470
Reviewed By: mostafaelhoushi, cpuhrsch
Differential Revision: D41625335
Pulled By: drisspg
fbshipit-source-id: 1723c11739fc73963bd9be8dc04f45e5abda79c0

Summary: Pull Request resolved: pytorch#6038
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met. cc VitalyFedyunin ngimel
X-link: pytorch/pytorch#89470
Reviewed By: mostafaelhoushi, cpuhrsch
Differential Revision: D41625335
Pulled By: drisspg
fbshipit-source-id: 44947eb40c48a8530d84155c0b11020278155bd8
This pull request was exported from Phabricator. Differential Revision: D41625335
Force-pushed from dce73fa to 41b3f91.
Summary: Pull Request resolved: #6038
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met. cc VitalyFedyunin ngimel
X-link: pytorch/pytorch#89470
Reviewed By: mostafaelhoushi, cpuhrsch
Differential Revision: D41625335
Pulled By: drisspg
fbshipit-source-id: c3ce8e1fbec25af249e6c8c8cda3086fdddaf558
@pytorchbot merge -f 'Landed internally'
(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
# Summary
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.
Pull Request resolved: pytorch#89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
…orch#89847)" This reverts commit b9afa92. Reverted pytorch#89847 on behalf of https://github.com/jeanschmidt due to Need to revert this commit as it is causing conflict when reverting pytorch#89470
This reverts commit 4d7ec30. Reverted pytorch#89470 on behalf of https://github.com/jeanschmidt due to breaking internal builds
# Summary
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.
Pull Request resolved: pytorch#89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
Summary
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.
cc @VitalyFedyunin @ngimel
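For reference, a minimal sketch of the substitution this summary describes, using the later public torch.nn.functional.scaled_dot_product_attention name (the PR itself calls the private _scaled_dot_product_attention) and made-up shapes:

```python
import math
import torch
import torch.nn.functional as F

# Made-up (N, H, L, E) inputs.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Essence of the inline block being replaced: explicit softmax(q k^T / sqrt(E)) v.
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
ref = torch.softmax(scores, dim=-1) @ v

# The consolidated call; the dispatcher can select a fused kernel
# (flash / memory-efficient attention) when the inputs qualify.
out = F.scaled_dot_product_attention(q, k, v)

assert torch.allclose(ref, out, atol=1e-4)
```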