
Conversation

@drisspg
Contributor

@drisspg drisspg commented Nov 22, 2022

Summary

Replaces the inline block of code in nn.functional.mha with _scaled_dot_product_attention. This function allows the fused kernels to be called if all the required input conditions are met.
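For context, a minimal sketch of what the replacement amounts to: the hand-written matmul/softmax/dropout block versus a single call that can dispatch to fused kernels when the inputs qualify. The sketch uses the public scaled_dot_product_attention API purely for illustration; the PR itself wires up the private _scaled_dot_product_attention, whose signature differed slightly at the time.

# Illustrative only, not the PR's exact diff.
import math
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, dropout_p=0.0):
    # q, k, v: (batch, num_heads, seq_len, head_dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    weights = F.softmax(scores, dim=-1)
    if dropout_p > 0.0:
        weights = F.dropout(weights, p=dropout_p)
    return weights @ v

q = k = v = torch.randn(2, 8, 16, 64)
out_manual = manual_attention(q, k, v)
# The block above collapses into one call that can hit the fused kernels
# (FlashAttention / memory-efficient attention) when the inputs qualify.
out_fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(out_manual, out_fused, atol=1e-5))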

cc @VitalyFedyunin @ngimel

@pytorch-bot

pytorch-bot bot commented Nov 22, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89470

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 41b3f91:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg force-pushed the call__sdp_in_mha_functional branch 2 times, most recently from 32b4332 to 603616c on November 23, 2022 22:04
@drisspg drisspg marked this pull request as ready for review November 23, 2022 22:36
@drisspg drisspg requested a review from cpuhrsch November 23, 2022 22:36
// Scale q,k before matmul for stability see https://tinyurl.com/sudb9s96 for math
const double scaling_factor = ::sqrt(::sqrt(static_cast<double>(embed_size)));
const auto embed_size = SymFloat(query_.sym_size(-1));
// const double scaling_factor = ::sqrt(::sqrt(static_cast<double>(embed_size)));
Contributor Author

@drisspg drisspg Nov 23, 2022

TODO: these are different options I was working through and need to be removed before landing,
but I'm curious whether this is the best way to do things.
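For reference, the sqrt(sqrt(embed_size)) factor in the hunk above is the usual 1/sqrt(d) attention scaling applied as a division of q and k by d ** 0.25 each before the matmul, which keeps the intermediate products smaller. A quick numerical check in plain PyTorch (hypothetical shapes, not the C++ under review):

import torch

d = 64
q, k = torch.randn(4, 16, d), torch.randn(4, 16, d)

scaling_factor = d ** 0.25  # sqrt(sqrt(embed_size)) from the hunk above
scores_pre = (q / scaling_factor) @ (k / scaling_factor).transpose(-2, -1)
scores_post = (q @ k.transpose(-2, -1)) / d ** 0.5

print(torch.allclose(scores_pre, scores_post, atol=1e-5))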

return os;
}

SymFloat SymFloat::sqrt() const {
Contributor Author

Needs double checking and placement guidance

attn_output_weights = softmax(attn_output_weights, dim=-1)
if dropout_p > 0.0:
    attn_output_weights = dropout(attn_output_weights, p=dropout_p)
if attn_mask.size(0) == 1:
Contributor Author

@drisspg drisspg Nov 23, 2022

Hacky.. there is a decent amount of impedance mismatch between the SDP API and nn.functional.mha, requiring all this transposing and viewing fluff. Ideally this would also work with nested tensors out of the box, but I need to do a once-over of this forward to understand the gap.
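A rough sketch of the reshaping in question (hypothetical shapes and names, not the actual functional.py code): the MHA internals carry (batch * num_heads, seq_len, head_dim) tensors, while the SDP op wants a 4D (batch, num_heads, seq_len, head_dim) layout and hands the same layout back, so the call site has to view on the way in and permute/reshape on the way out.

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, head_dim = 2, 8, 16, 64
q = k = v = torch.randn(bsz * num_heads, tgt_len, head_dim)

# View into the 4D layout the SDP op expects (public API used for illustration).
q4 = q.view(bsz, num_heads, tgt_len, head_dim)
k4 = k.view(bsz, num_heads, tgt_len, head_dim)
v4 = v.view(bsz, num_heads, tgt_len, head_dim)
out = F.scaled_dot_product_attention(q4, k4, v4)

# ...and back to the (tgt_len * bsz, embed_dim) layout the rest of MHA expects.
attn_output = out.permute(2, 0, 1, 3).reshape(tgt_len * bsz, num_heads * head_dim)
print(attn_output.shape)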

@drisspg
Contributor Author

drisspg commented Nov 23, 2022

@BowenBao @abock Sorry for pinging you directly but I am getting a test failure for:

test/onnx/test_models_onnxruntime.py::TestModelsONNXRuntime_is_script_False::test_transformer_encoder - 
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::_scaled_dot_product_attention' to ONNX opset version 14 is not supported. 
Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

Which can be found on the hud: https://hud.pytorch.org/pr/89470

I am not sure what the next steps are to enable ONNX support. I tried reading through the wiki but didn't find anything very fruitful. Any guidance would be much appreciated.

@drisspg drisspg added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 24, 2022
@drisspg drisspg force-pushed the call__sdp_in_mha_functional branch from 603616c to 22acd66 on November 28, 2022 17:47
Contributor

TIL we have a c10::Join

Contributor

Have you tried writing something like:

namespace std {
  template <>
  SymFloat sqrt<SymFloat>(const SymFloat& self) {
    ...
  }
}

Not sure if this would work exactly, but maybe we can get std::sqrt() to Just Work on symfloats, so we don't need any code changes in the future

Contributor Author

I can't seem to get the explicit specialization to match a declared function template. Tried this for both std:: and c10::complex.

@facebook-github-bot
Contributor

@mikekgfb has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg drisspg added module: performance Issues related to performance, either of kernel code or framework glue better-engineering Relatively self-contained tasks for better engineering contributors labels Nov 28, 2022
@pytorch-bot pytorch-bot bot added the release notes: onnx torch.onnx related changes that should show up in the release notes label Nov 28, 2022
Contributor

@mikekgfb mikekgfb Nov 28, 2022

The transpose() calls necessary to pull this off might cost us significant overhead (sketched below).

There's an impedance mismatch not just between functional.MHA and sdp_attention, but also between nn.MHA and functional.MHA.

nn.MHA treats batch_first as the preferred format, but functional.MHA appears to support only the non-batch_first layout. I see two options:
1 - Pass in the batch_first variable (somewhat non-preferred, because we might end up with control flow that is a function of an input, namely batch_first).
2 - Have an inverted-polarity MHA that prefers batch_first and use it internally, then have nn.MHA and the legacy-compatibility functional.MHA call the new implementation, with a possible path to eventually deprecating the old interface.
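To make the batch_first point concrete, a simplified single-head sketch (hypothetical helpers, not the nn.MultiheadAttention source): a batch_first module sitting on top of a seq-first functional API pays one transpose on the way in and another on the way out.

import torch

def seq_first_attention(q, k, v):
    # Stand-in for functional.mha: expects and returns (seq, batch, embed).
    scores = torch.einsum("sbe,tbe->bst", q, k) / q.size(-1) ** 0.5
    return torch.einsum("bst,tbe->sbe", scores.softmax(-1), v)

def batch_first_attention(q, k, v):
    # What an nn.MHA-style module with batch_first=True has to do today.
    q, k, v = (x.transpose(0, 1) for x in (q, k, v))  # (B, S, E) -> (S, B, E)
    out = seq_first_attention(q, k, v)
    return out.transpose(0, 1)  # back to (B, S, E)

x = torch.randn(2, 16, 64)  # (batch, seq, embed)
print(batch_first_attention(x, x, x).shape)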

@drisspg
Contributor Author

drisspg commented Nov 28, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@facebook-github-bot
Contributor

@cpuhrsch has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@albanD albanD removed their request for review November 30, 2022 17:45
@drisspg drisspg force-pushed the call__sdp_in_mha_functional branch from e10a728 to d62aa79 on November 30, 2022 19:26
@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

drisspg added a commit to drisspg/glow that referenced this pull request Nov 30, 2022
Summary:
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

cc VitalyFedyunin ngimel

X-link: pytorch/pytorch#89470

Reviewed By: cpuhrsch

Differential Revision: D41625335

Pulled By: drisspg

fbshipit-source-id: dcf79e9d51d1bb0b1649f6bf0a8b0e2869170874
drisspg added a commit to drisspg/pytorch that referenced this pull request Nov 30, 2022
Summary:
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

cc VitalyFedyunin ngimel

Pull Request resolved: pytorch#89470

Reviewed By: cpuhrsch

Differential Revision: D41625335

Pulled By: drisspg

fbshipit-source-id: 460b5cadfbd2e8f8a21fb46bce92fe831984ee02
@drisspg drisspg force-pushed the call__sdp_in_mha_functional branch from 858d7d4 to f4690a6 on November 30, 2022 21:32
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D41625335

drisspg added a commit to drisspg/glow that referenced this pull request Dec 1, 2022
Summary:
Pull Request resolved: pytorch#6038

Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

cc VitalyFedyunin ngimel

X-link: pytorch/pytorch#89470

Reviewed By: cpuhrsch

Differential Revision: D41625335

Pulled By: drisspg

fbshipit-source-id: 7bfe67cbc52d545faa0eefa7600f39a1685d01e4
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D41625335

@drisspg drisspg force-pushed the call__sdp_in_mha_functional branch from f4690a6 to dce73fa on December 1, 2022 04:00
drisspg added a commit to drisspg/pytorch that referenced this pull request Dec 1, 2022
Summary:
X-link: pytorch/glow#6038

Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

cc VitalyFedyunin ngimel

Pull Request resolved: pytorch#89470

Reviewed By: cpuhrsch

Differential Revision: D41625335

Pulled By: drisspg

fbshipit-source-id: cd7d010a6c325618e0df9dd75246a291451c8021
Summary:
X-link: pytorch/glow#6038

Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

cc VitalyFedyunin ngimel

Pull Request resolved: pytorch#89470

Reviewed By: mostafaelhoushi, cpuhrsch

Differential Revision: D41625335

Pulled By: drisspg

fbshipit-source-id: 1723c11739fc73963bd9be8dc04f45e5abda79c0
drisspg added a commit to drisspg/glow that referenced this pull request Dec 2, 2022
Summary:
Pull Request resolved: pytorch#6038

Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

cc VitalyFedyunin ngimel

X-link: pytorch/pytorch#89470

Reviewed By: mostafaelhoushi, cpuhrsch

Differential Revision: D41625335

Pulled By: drisspg

fbshipit-source-id: 44947eb40c48a8530d84155c0b11020278155bd8
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D41625335

@drisspg drisspg force-pushed the call__sdp_in_mha_functional branch from dce73fa to 41b3f91 on December 2, 2022 16:47
facebook-github-bot pushed a commit to pytorch/glow that referenced this pull request Dec 2, 2022
Summary:
Pull Request resolved: #6038

Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

cc VitalyFedyunin ngimel

X-link: pytorch/pytorch#89470

Reviewed By: mostafaelhoushi, cpuhrsch

Differential Revision: D41625335

Pulled By: drisspg

fbshipit-source-id: c3ce8e1fbec25af249e6c8c8cda3086fdddaf558
@facebook-github-bot
Contributor

@pytorchbot merge -f 'Landed internally'

(Initiating merge automatically since Phabricator Diff has merged, using force because this PR might not pass merge_rules.json but landed internally)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
# Summary
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

Pull Request resolved: pytorch#89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
…orch#89847)"

This reverts commit b9afa92.

Reverted pytorch#89847 on behalf of https://github.com/jeanschmidt due to Need to revert this commit as it is causing conflict when reverting pytorch#89470
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
This reverts commit 4d7ec30.

Reverted pytorch#89470 on behalf of https://github.com/jeanschmidt due to breaking internal builds
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
# Summary
Replaces the inline block of code in nn.functional.mha with `_scaled_dot_product_attention`. This function allows the fused kernels to be called if all the required input conditions are met.

Pull Request resolved: pytorch#89470
Approved by: https://github.com/cpuhrsch, https://github.com/mikekgfb
@drisspg drisspg changed the title Call _sdp_attention in nn.functional.mha [SDPA] Call _sdp_attention in nn.functional.mha Jan 10, 2023

Labels

better-engineering - Relatively self-contained tasks for better engineering contributors
ciflow/trunk - Trigger trunk jobs on your pull request
Merged
module: performance - Issues related to performance, either of kernel code or framework glue
release notes: onnx - torch.onnx related changes that should show up in the release notes
Reverted
