[aotinductor] support _scaled_dot_product_flash_attention fallback #110003
Conversation
This PR supports the _scaled_dot_product_flash_attention fallback kernel. Note that in abi_compatible mode, we retrieve outputs by passing output argument pointers rather than relying on std::get. It also fixes an issue related to dynamic shapes, where we incorrectly query undefined dynamic symbols. [ghstack-poisoned]
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110003. Note: links to docs will display an error until the docs builds have been completed. ⏳ No failures, 22 pending as of commit 58c4f39 with merge base e42d450. This comment was automatically generated by Dr. CI and updates every 15 minutes.
I think I will mostly defer to @desertfire here. Is there something specifically you want me to review?
Since the PR touches the ABI part (shim.h and shim_common.cpp), I think you may want to take a look? Thanks!
AtenTensorHandle query,
AtenTensorHandle key,
AtenTensorHandle value,
...);
Variadic function in the ABI?! Seems suspicious. Tell me more?
Agree, looks kinda suspicious. Maybe we need to define something like:
aoti_torch__scaled_dot_product_flash_attention_6()
aoti_torch__scaled_dot_product_flash_attention_7()
aoti_torch__scaled_dot_product_flash_attention_8()
Naively, I would have expected it to take all the optional arguments and force you to fill in the defaults.
torch.ops.aten._scaled_dot_product_flash_attention takes a couple of default arguments (dropout_p, is_causal, return_debug_mask and scale) that may not be present in the IR. For example, given torch.ops.aten._scaled_dot_product_flash_attention(q, k, v), at https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L3853 we only see three args for q, k, v; all the default arguments are missing. Consequently, at codegen time we don't know the default values to use when generating the _scaled_dot_product_flash_attention fallback. (Question: is there any way we could get the default argument values when we create the IR for this FallbackKernel?)
Note that we could fill in the default values for _scaled_dot_product_flash_attention by explicitly detecting the kernel name in the wrapper code, but I think that may not be general enough. For example, we might end up with a number of special rules for various fallback kernels in the wrapper.
Since C doesn't support default arguments, I went with the variadic-function solution.
ordered_kwargs_for_cpp_kernel should give you the default values, which are read from the schema.
ordered_kwargs_for_cpp_kernel should give you the default values, which are read from the schema.
Hmm, we don't use the schema for _scaled_dot_product_flash_attention. Do we want to use the schema for it, similar to other kernels like ConvolutionUnary?
Agree, looks kinda suspicious. Maybe we need to define something like:
aoti_torch__scaled_dot_product_flash_attention_6()
aoti_torch__scaled_dot_product_flash_attention_7()
aoti_torch__scaled_dot_product_flash_attention_8()
Hmm, that's probably less ideal, because we would need a special rule for the scaled_dot_product_flash_attention fallback in the wrapper code to generate these.
Naively, I would have expected it to take all the optional arguments and force you to fill in the defaults.
Yeah, makes sense.
Per the discussion, I replaced the variadic function with one that has all the default arguments filled in. This change requires replacing torch.ops.aten._scaled_dot_product_flash_attention with torch.nn.functional.scaled_dot_product_attention; only the latter provides a schema that can be used to retrieve the default values for the optional arguments.
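For reference, a rough sketch of what the resulting non-variadic shim declaration could look like, pieced together from the parameter and return-value snippets quoted elsewhere in this review; the exact parameter order, names, and the encoding of the optional scale in the final code may differ:

// Sketch only: the ordering and the scale encoding below are assumptions,
// not the final declaration landed in shim.h.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,        // schema default 0.0, filled in by codegen
    bool is_causal,          // schema default false, filled in by codegen
    bool return_debug_mask,  // schema default false, filled in by codegen
    double* scale,           // hypothetical nullable pointer standing in for c10::optional<double>
    AtenTensorHandle* ret0,  // returns new reference
    AtenTensorHandle* ret1,  // returns new reference
    AtenTensorHandle* ret2,  // returns new reference, may be null
    AtenTensorHandle* ret3,  // returns new reference, may be null
    int64_t* ret4,
    int64_t* ret5,
    AtenTensorHandle* ret6,  // returns new reference
    AtenTensorHandle* ret7,  // returns new reference
    AtenTensorHandle* ret8); // returns new reference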
ezyang left a comment:
no variadic functions allowed unless you convince me it's a good idea
auto ret = at::_scaled_dot_product_flash_attention(
    *query_tensor,
    *key_tensor,
    *value_tensor,
    dropout_p,
    is_causal,
    return_debug_mask,
    scale);

at::Tensor* ret0_tensor = new at::Tensor(std::move(std::get<0>(ret)));
*ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
at::Tensor* ret1_tensor = new at::Tensor(std::move(std::get<1>(ret)));
*ret1 = tensor_pointer_to_tensor_handle(ret1_tensor);
// ret2 and ret3 may be null
if (ret2) {
  at::Tensor* ret2_tensor = new at::Tensor(std::move(std::get<2>(ret)));
  *ret2 = tensor_pointer_to_tensor_handle(ret2_tensor);
}
if (ret3) {
  at::Tensor* ret3_tensor = new at::Tensor(std::move(std::get<3>(ret)));
  *ret3 = tensor_pointer_to_tensor_handle(ret3_tensor);
}
*ret4 = std::get<4>(ret);
*ret5 = std::get<5>(ret);
at::Tensor* ret6_tensor = new at::Tensor(std::move(std::get<6>(ret)));
*ret6 = tensor_pointer_to_tensor_handle(ret6_tensor);
at::Tensor* ret7_tensor = new at::Tensor(std::move(std::get<7>(ret)));
*ret7 = tensor_pointer_to_tensor_handle(ret7_tensor);
at::Tensor* ret8_tensor = new at::Tensor(std::move(std::get<8>(ret)));
*ret8 = tensor_pointer_to_tensor_handle(ret8_tensor);
For brevity we can use structured binding, e.g.:
auto [ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7, ret8] = at::_scaled_dot_product_flash_attention(
*query_tensor,
*key_tensor,
*value_tensor,
dropout_p,
is_causal,
return_debug_mask,
scale);
double dropout_p = 0.0,
bool is_causal = false,
bool return_debug_mask = false,
c10::optional<double> scale = c10::nullopt) {
Is this supposed to be a C-friendly API? Why c10::optional then?
Is this supposed to be a C-friendly API? Why c10::optional then?
We provide a C interface; the implementation can be in C++.
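A minimal sketch of that split, assuming a hypothetical nullable-pointer encoding for the optional scale at the C boundary (aoti_example_sdpa is a made-up name for illustration, not a real shim entry point):

// The C-visible signature uses only plain types; the C++ body behind the
// extern "C" boundary is free to rebuild a c10::optional from the pointer.
extern "C" AOTITorchError aoti_example_sdpa(
    AtenTensorHandle query,
    double* scale /* may be null */) {
  c10::optional<double> opt_scale =
      scale ? c10::optional<double>(*scale) : c10::nullopt;
  (void)query;      // a real implementation would unwrap the handle and call into ATen
  (void)opt_scale;
  return AOTI_TORCH_SUCCESS;
}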
AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
Do we use it in this diff?
Do we use it in this diff?
Yes, we do.
self.first_device_guard = True
self.supports_intermediate_hooks = True
self.expr_printer = pexpr
self.defined_symbols = set()
Add comments to explain what this is for
I still don't really understand the fix here, and I would like to understand it. Could you describe what goes wrong and how this fixes it?
The original issue showed up with a large model. Let me come up with a simple test to demonstrate it.
torch/_inductor/codegen/wrapper.py (outdated):

def generate_c_shim_fallback_kernel_call(self, fallback_kernel, args):
    output_args = []
    output_raii_handles = []
    output_name_base = fallback_kernel.get_name()
We probably need a counter to distinguish different fallback_kernel_calls?
We probably need a counter to distinguish different fallback_kernel_calls?
I think different fallback kernel calls already get unique buffer names, buf0, buf1, etc.?
torch/_inductor/codegen/wrapper.py (outdated):

self.generate_c_shim_extern_kernel_call(kernel, args)
if isinstance(extern_kernel, ir.FallbackKernel):
    self.generate_c_shim_fallback_kernel_call(extern_kernel, args)
else:
This path wasn't tested, and is a to-do item on my radar. You can just remove it, and change generate_c_shim_fallback_kernel_call to generate_c_shim_extern_kernel_alloc_call
return f"reinterpret_tensor({', '.join(args)})"

def codegen_multi_output(self, name, value):
    # if V.graph.aot_mode and name in set(V.graph.get_output_names()):
What is this comment for? Consider adding a comment to explain what happens when config.aot_inductor.abi_compatible is True.
It was accidentally left in; I will remove it.
va_start(args, value);

double dropout_p = 0.0;
if (num_inputs >= 4) {
This means we will run aoti_torch__scaled_dot_product_flash_attention_internal multiple times when we have num_inputs==7?
Ah, good catch! Will fix it.
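For what it's worth, a sketch of how the trailing varargs could be consumed so the internal function is invoked exactly once (hypothetical, and moot once the variadic signature is replaced by explicit defaults; the num_inputs and ret0..ret8 names are assumed from the surrounding snippets, and bool is promoted to int in varargs):

va_list args;
va_start(args, value);
double dropout_p = 0.0;
bool is_causal = false;
bool return_debug_mask = false;
bool has_scale = false;
double scale = 0.0;
if (num_inputs >= 4) dropout_p = va_arg(args, double);
if (num_inputs >= 5) is_causal = static_cast<bool>(va_arg(args, int));
if (num_inputs >= 6) return_debug_mask = static_cast<bool>(va_arg(args, int));
if (num_inputs >= 7) { scale = va_arg(args, double); has_scale = true; }
va_end(args);
// Single call after all optional arguments have been read.
return aoti_torch__scaled_dot_product_flash_attention_internal(
    query, key, value, dropout_p, is_causal, return_debug_mask,
    has_scale ? c10::optional<double>(scale) : c10::nullopt,
    ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7, ret8);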
int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8, // returns new reference
Looks like we have to put the return values at the beginning of the parameters because of varargs, so we should just revert #109834 for consistency. We can still change it at this point.
});
}

static AOTITorchError aoti_torch__scaled_dot_product_flash_attention_internal(
Nit: move this to the anonymous namespace.
I think static already gives it internal linkage, but yes, I will move it to the anonymous namespace for consistency.
    return_debug_mask,
    scale);

at::Tensor* ret0_tensor = new at::Tensor(std::move(r0));
I know I have been using this in other places, but now seems a good time to write a utility function for this new + tensor_pointer_to_tensor_handle pattern.
Yes, will do it in a follow-up PR. Thanks.
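Such a helper could be as small as the sketch below (the name new_tensor_handle is made up here for illustration):

// Hypothetical helper: moves a returned at::Tensor onto the heap and wraps
// it in an AtenTensorHandle, replacing the repeated
// `new at::Tensor(std::move(...))` + tensor_pointer_to_tensor_handle pattern.
inline AtenTensorHandle new_tensor_handle(at::Tensor&& tensor) {
  at::Tensor* out = new at::Tensor(std::move(tensor));
  return tensor_pointer_to_tensor_handle(out);
}

// Usage, matching the snippet above:
// *ret0 = new_tensor_handle(std::move(r0));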
for dim, shape in enumerate(shapes):
    shape = V.graph.sizevars.simplify(shape)
    if shape in needed:
        self.defined_symbols.add(shape)
This seems wrong. shape might be s0+s1+1.
Perhaps you want .free_symbols?
shape is guarded by needed, which is set to V.graph.sizevars.free_symbols() at line 576 above, so we won't generate bad code for non-symbol expressions such as s0 + s1.
I am worried we will miss symbols, though: s0+s1 in {s0} will return False, so defined_symbols will be wrong.
AtenTensorHandle* ret2, // returns new reference
AtenTensorHandle* ret3, // returns new reference
int64_t* ret4,
int64_t* ret5,
dynamic shapes?!
Those two fields are max_seqlen_batch_q and max_seqlen_batch_kv, which are returned from the sdpa kernel (pytorch/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp, lines 245 to 246 at 36eb1bb):
max_seqlen_batch_q,
max_seqlen_batch_kv,
I think they can be dynamic, given they are related to sequence length and batch size?
int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8 // returns new reference
If this is perf sensitive, you're going to want an out variant of this kernel at some point
Yes, good point. Thanks.
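For illustration, an out variant in this ABI would presumably take caller-allocated output handles instead of returning new references, roughly along these lines (entirely hypothetical, not part of this PR; the name and parameter list are made up):

// Hypothetical out-variant shape: the caller allocates the output tensors
// and the shim writes into them, avoiding the extra heap allocations above.
AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention_out(
    AtenTensorHandle out_attention,   // pre-allocated by the caller
    AtenTensorHandle out_logsumexp,   // pre-allocated by the caller
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    bool is_causal);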
*ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
at::Tensor* ret1_tensor = new at::Tensor(std::move(r1));
*ret1 = tensor_pointer_to_tensor_handle(ret1_tensor);
// ret2 and ret3 may be null
huh? Shouldn't we guarantee we call this with valid pointers?
These are for cases where the corresponding outputs are None (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L3883). In such cases, I chose to pass nullptr because I think it matches the Python IR more closely. We could pass in a dummy valid pointer instead; I am fine with either way. What's your preference? Thanks.
So are you saying something like: because we specialize on things being None or not, at compile time we KNOW whether or not ret2/ret3 return None, and so if they are returning None we don't have to pass in a pointer because we know it's always going to be None? I guess this is the variadic thing all over again; technically, we could have just had separate overloads depending on whether ret2/ret3 return None or not.
I guess this is fine as is, but I'd like it if we asserted that ret2_tensor was in fact undefined if ret2 was null.
So are you saying something like, because we specialize on things being None or not, at compile time we KNOW whether or not ret2/ret3 return None, and so if they are returning None we don't have to pass in a pointer because we know it's always going to be none?
Yes, this is exactly the case.
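A small sketch of the assertion being asked for, using the structured-binding names from the snippet quoted earlier (r2 is assumed to be the third result, following r0/r1) and the standard TORCH_CHECK macro:

// ret2 is null only when codegen has determined this output is None,
// so the kernel's corresponding result should be an undefined tensor.
if (ret2) {
  at::Tensor* ret2_tensor = new at::Tensor(std::move(r2));
  *ret2 = tensor_pointer_to_tensor_handle(ret2_tensor);
} else {
  TORCH_CHECK(!r2.defined(),
      "ret2 is null but the kernel returned a defined tensor");
}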
See #110527.
Got it. Thanks for addressing this.
We are going to leave this in, but our expectation is that for dense tensors the int returns will never actually get used.
Question: if we wanted to change this for dense tensors and land something like #110546 (after a rebase and some error fixes), would that be a problem FC/BC-wise?
Not for AOTI, as this isn't deployed yet.
Closing this in favor of #110085.
Stack from ghstack (oldest at bottom):
This PR supports the _scaled_dot_product_flash_attention fallback kernel.
Note that in abi_compatible mode, we retrieve outputs by passing output argument pointers rather than relying on std::get.
It also fixes an issue related to dynamic shapes, where we incorrectly query undefined dynamic symbols.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @kadeng @muchulee8 @aakhundov