
Conversation

Contributor

@drisspg drisspg commented Oct 19, 2022

Summary

Use the private _scaled_dot_product_attention to support _native_multi_head_attention. SDPA provides access to fused kernels when certain conditions are met, enabling a speedup for MHA.
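
For context, here is a minimal sketch of what the change amounts to, written against the now-public torch.nn.functional.scaled_dot_product_attention rather than the private at::_scaled_dot_product_attention this PR calls from C++; the packed qkv projection, helper name, and shapes are illustrative assumptions, not the attention.cpp code.

```python
import torch
import torch.nn.functional as F

def mha_via_sdpa(x, qkv_weight, out_weight, num_heads):
    # x: (batch, seq, embed_dim); qkv_weight: (3 * embed_dim, embed_dim)
    B, S, D = x.shape
    head_dim = D // num_heads
    # One packed projection produces q, k, v in a single matmul.
    q, k, v = F.linear(x, qkv_weight).chunk(3, dim=-1)
    # Split heads: (B, S, D) -> (B, num_heads, S, head_dim).
    q, k, v = (t.view(B, S, num_heads, head_dim).transpose(1, 2) for t in (q, k, v))
    # Fused kernels (FlashAttention / memory-efficient) are selected here when eligible.
    y = F.scaled_dot_product_attention(q, k, v)
    y = y.transpose(1, 2).reshape(B, S, D)
    return F.linear(y, out_weight)
```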

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki


pytorch-bot bot commented Oct 19, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/87312

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 100c3c4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.


const auto dim_per_head = D / num_head;

if (query.is_same(key) && key.is_same(value) && !need_weights) {
Contributor

@cpuhrsch cpuhrsch Oct 19, 2022


By using qkv_projection below, we might be able to broaden the applicability of _scaled_dot_product_attention to just the !need_weights case.
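
A rough sketch of the two gates being compared, in Python pseudocode rather than the attention.cpp logic (can_use_fused_sdpa is a made-up name for illustration):

```python
def can_use_fused_sdpa(query, key, value, need_weights):
    # Current condition: self-attention (q, k, v are the same tensor) and no weights requested.
    strict = query is key and key is value and not need_weights
    # Suggested broadening: once q, k, v are projected separately (qkv_projection),
    # the fused path could apply whenever attention weights are not requested.
    broadened = not need_weights
    return strict, broadened
```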

@drisspg drisspg added module: nestedtensor NestedTensor tag see issue #25032 ciflow/trunk Trigger trunk jobs on your pull request labels Oct 19, 2022
@drisspg drisspg force-pushed the Use_scaled_dot_product_attention_within_attention.cpp branch from dec4bf4 to 596e945 Compare October 20, 2022 16:07
Comment on lines 402 to 410
chunks[0] = (chunks[0].view({x_size_0, -1, num_head, sdp_dim_per_head}))
.transpose(1, 2);
chunks[1] = (chunks[1].view({x_size_0, -1, num_head, sdp_dim_per_head}))
.transpose(1, 2);
chunks[2] = (chunks[2].view({x_size_0, -1, num_head, sdp_dim_per_head}))
.transpose(1, 2);

auto y = at::_scaled_dot_product_attention(
chunks[0], chunks[1], chunks[2], mask, 0.0, need_weights, false);
Member


Note that the last implementation from @danthe3rd allows passing the tensors without having to reshape anything, avoiding a call to contiguous both before and after the operator. This can provide significant speedups, by the way.
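
A small self-contained illustration of the cost being referred to (shapes are arbitrary):

```python
import torch

# Splitting heads via view + transpose yields a non-contiguous tensor, so a kernel
# that requires contiguous inputs forces an extra copy before the call (and again
# after the inverse transpose on the way out).
x = torch.randn(2, 16, 8 * 64)                # (batch, seq, num_heads * head_dim)
heads = x.view(2, 16, 8, 64).transpose(1, 2)  # (batch, num_heads, seq, head_dim)
print(heads.is_contiguous())                  # False -> .contiguous() would copy
```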

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg drisspg force-pushed the Use_scaled_dot_product_attention_within_attention.cpp branch 2 times, most recently from 5f42406 to d64515c Compare October 25, 2022 21:57
Contributor Author


We could also error out here

Contributor


I think that's a good idea for now if there's no risk we'll take this path from attention.cpp

@drisspg drisspg force-pushed the Use_scaled_dot_product_attention_within_attention.cpp branch from e5a3ce9 to 7f206e1 Compare October 25, 2022 22:40
#if BETTER_TRANSFORMER_USE_FLASH_ATTENTION
}
#endif
x = std::get<0>(at::_native_multi_head_attention(
Contributor Author


I think this needs to be here, before the avoided dispatch. _native_multi_head_attention previously worked for all tensor subclasses, so we now go through the dispatcher to dispatch correctly to CPU or CUDA.

Contributor

@danthe3rd danthe3rd left a comment


I'm wondering if we need some more tests for this specific behavior (e.g. a nested tensor where one component has seq_len = 1).
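
As a sketch of the suggested case, here is how such a test input could be built with the public torch.nested API; the real test would push it through _native_multi_head_attention / the fused SDPA path, which is not reproduced here, so shapes and names are illustrative.

```python
import torch

embed_dim = 16
seq_lens = [1, 5, 7]  # note the length-1 entry
x = torch.nested.nested_tensor(
    [torch.randn(s, embed_dim) for s in seq_lens])

# Each component keeps its own sequence length, including the degenerate 1.
assert [t.size(0) for t in x.unbind()] == seq_lens
```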

tensor_stride_ptr[(i - 1) * tensor_stride_0];
// TODO: When 0 seq_len nested tensors are allowed we need to guard against this
int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] * tensor_stride_ptr[(i - 1) * tensor_stride_0];
int64_t current_offset_constant = (tensor_offsets[i] - tensor_offsets[i - 1]) / previous_numel;
Contributor


Should we add a check that (tensor_offsets[i] - tensor_offsets[i - 1]) % previous_numel == 0 before dividing here - and also before the loop?

Contributor Author


https://github.com/drisspg/pytorch/blob/029ab31c0602683821ebb6d9ce78be1fa70770f7/aten/src/ATen/native/transformers/cuda/sdp_utils.h#L65

Yes, we very much do, but I spoke to Christian and this PR essentially neuters the SDP inclusion in native MHA; a follow-up PR will then expand the scope.

@drisspg drisspg force-pushed the Use_scaled_dot_product_attention_within_attention.cpp branch from 029ab31 to 6c6a8ac Compare October 26, 2022 17:02
@pytorch pytorch deleted a comment from cpuhrsch Oct 26, 2022
@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@drisspg drisspg force-pushed the Use_scaled_dot_product_attention_within_attention.cpp branch from 4c9c905 to 18dc15c Compare October 27, 2022 22:30
@drisspg drisspg force-pushed the Use_scaled_dot_product_attention_within_attention.cpp branch from 619f511 to 0aef791 Compare October 28, 2022 17:25
need_weights=need_weights,
average_attn_weights=average_attn_weights,
)
def test_native_multihead_self_attention(self, device, dtype, use_nt, need_weights, average_attn_weights, use_padding=False, pad_all=False
Contributor Author


This test is currently failing for the cases where it can run the fused SDP, i.e. need_weights = False.

math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous()
math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous()

self.assertEqual(math_ref_test, math_ref_lp_test, atol=4e-1, rtol=4e-1)
Contributor Author


This comparison is between math_ref run in fp32 and math_ref run in fp16, and is used to define a reasonable epsilon for fp16-to-fp32 comparisons.

The second assert compares fused_sdp_fp16 against math_ref_32 and ensures that it is within the same bounds as the math_ref comparison.

Also note that we are scaling the uniform distribution up to the range (-10, 10).
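
A hedged sketch of that tolerance scheme, with illustrative shapes and without forcing any particular SDPA backend (so the low-precision call only stands in for the fused path); on CPU it falls back to bfloat16 since fp16 matmuls may be unsupported there.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
lp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Scale the uniform distribution up to the range (-10, 10).
q = k = v = torch.rand(2, 4, 32, 16, device=device) * 20 - 10

math_ref = F.scaled_dot_product_attention(q, k, v)                                  # fp32 reference
math_ref_lp = F.scaled_dot_product_attention(*(t.to(lp_dtype) for t in (q, k, v)))  # low-precision reference
fused_lp = F.scaled_dot_product_attention(*(t.to(lp_dtype) for t in (q, k, v)))     # stand-in for the fused path

# 1) Low-precision reference vs fp32 reference defines a reasonable epsilon.
torch.testing.assert_close(math_ref_lp.float(), math_ref, atol=4e-1, rtol=4e-1)
# 2) The (fused) low-precision result must land within the same bounds.
torch.testing.assert_close(fused_lp.float(), math_ref, atol=4e-1, rtol=4e-1)
```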

@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg drisspg force-pushed the Use_scaled_dot_product_attention_within_attention.cpp branch from 101e478 to 100c3c4 Compare October 28, 2022 23:42
@facebook-github-bot
Contributor

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Contributor Author

drisspg commented Oct 31, 2022

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions
Contributor

Hey @drisspg.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Nov 5, 2022
# Summary
Use the private _scaled_dot_product_attention to support _native_multi_head_attention. SDPA provides access to fused kernels when certain conditions are met, enabling a speedup for MHA.

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: pytorch#87312
Approved by: https://github.com/cpuhrsch
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
# Summary
Use the private _scaled_dot_product_attention to support _native_multi_head_attention. SDPA provides access to fused kernels when certain conditions are met, enabling a speedup for MHA.

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: pytorch#87312
Approved by: https://github.com/cpuhrsch
@drisspg drisspg changed the title Use scaled_dot_product_attention within attention.cpp [SDPA] Use scaled_dot_product_attention within attention.cpp Jan 10, 2023