
Conversation

@chenyang78 commented Sep 25, 2023

Stack from ghstack (oldest at bottom):

This PR supports the _scaled_dot_product_flash_attention fallback kernel.
Note that in abi_compatible mode, we retrieve outputs by passing
output argument pointers rather than relying on std::get.

It also fixes an issue related to dynamic shapes, where we incorrectly
queried undefined dynamic symbols.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @kadeng @muchulee8 @aakhundov

pytorch-bot commented Sep 25, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110003

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 22 Pending

As of commit 58c4f39 with merge base e42d450:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

chenyang78 pushed a commit that referenced this pull request Sep 25, 2023
@ezyang (Contributor) commented Sep 25, 2023

I think I will mostly defer to @desertfire here. Is there something specifically you want me to review?

@chenyang78 (Author)

> I think I will mostly defer to @desertfire here. Is there something specifically you want me to review?

Since the PR touches the ABI part (shim.h and shim_common.cpp), I think you may want to take a look? Thanks!

AtenTensorHandle query,
AtenTensorHandle key,
AtenTensorHandle value,
...);
Contributor

Variadic function in the ABI?! Seems suspicious. Tell me more?

Contributor

Agree, looks kinda suspicious. Maybe we need to define something like:

aoti_torch__scaled_dot_product_flash_attention_6()
aoti_torch__scaled_dot_product_flash_attention_7()
aoti_torch__scaled_dot_product_flash_attention_8()

Contributor

Naively, I would have expected to take all optional arguments and force you to fill in the defaults.

Author

torch.ops.aten._scaled_dot_product_flash_attention takes several arguments with defaults (dropout_p, is_causal, return_debug_mask, and scale) that may not be present in the IR. For example, given torch.ops.aten._scaled_dot_product_flash_attention(q, k, v), at https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L3853 we only see the three args q, k, and v; all the default arguments are missing. Consequently, at codegen time we don't know the default values to use when generating the _scaled_dot_product_flash_attention fallback. (Question: is there any way we could get the default argument values when we create the IR for this FallbackKernel?)

Note that we could fill in default values for _scaled_dot_product_flash_attention by explicitly detecting the kernel name in the wrapper code, but I don't think that is general enough: we might end up with a number of special rules for various fallback kernels in the wrapper.

Since C doesn't support default arguments, I went with the variadic-function solution.
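
For illustration, a minimal sketch of the variadic approach described here; the function name, the HandleX type, and the plain int32_t error code are placeholders rather than the actual shim API, and only two of the optional arguments are shown:

#include <cstdarg>
#include <cstdint>

using HandleX = void*;  // hypothetical stand-in for AtenTensorHandle

// Trailing optional arguments are read from the va_list only when
// num_inputs says the caller supplied them; otherwise the schema
// defaults are used.
extern "C" int32_t sdpa_variadic_sketch(
    int64_t num_inputs, HandleX query, HandleX key, HandleX value, ...) {
  double dropout_p = 0.0;  // default
  int32_t is_causal = 0;   // default (false)

  va_list args;
  va_start(args, value);
  if (num_inputs >= 4) {
    dropout_p = va_arg(args, double);
  }
  if (num_inputs >= 5) {
    is_causal = va_arg(args, int32_t);  // bools are promoted to int in varargs
  }
  va_end(args);

  // ... dispatch to the real kernel with the resolved argument values ...
  (void)query; (void)key; (void)value;
  (void)dropout_p; (void)is_causal;
  return 0;
}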

Contributor

ordered_kwargs_for_cpp_kernel should give you the default values, which are read from the schema.

Author

> ordered_kwargs_for_cpp_kernel should give you the default values, which are read from the schema.

Hmm, we don't use schema for _scaled_dot_product_flash_attention. Do we want to use schema for it in a way similar to other kernels like ConvolutionUnary?

Author

> Agree, looks kinda suspicious. Maybe we need to define something like:
>
> aoti_torch__scaled_dot_product_flash_attention_6()
> aoti_torch__scaled_dot_product_flash_attention_7()
> aoti_torch__scaled_dot_product_flash_attention_8()

Hmm, that's probably less ideal, because we would need a special rule for the _scaled_dot_product_flash_attention fallback in the wrapper code to generate these.

Contributor

> Naively, I would have expected to take all optional arguments and force you to fill in the defaults.

Yeah, makes sense.

Author

Per the discussion, I replaced the variadic function with one whose default arguments are filled in explicitly. This change requires replacing torch.ops.aten._scaled_dot_product_flash_attention with torch.nn.functional.scaled_dot_product_attention, since only the latter provides a schema that can be used to retrieve the default values for the optional arguments.
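
As a rough sketch of that shape (not the exact shim signature; the HandleX type, the function names, and the has_scale/scale pair used to encode the optional scale are all hypothetical):

using HandleX = void*;  // hypothetical stand-in for AtenTensorHandle

// All optional arguments are explicit parameters; the wrapper codegen reads
// the defaults from the schema and fills them in at every call site.
extern "C" int sdpa_explicit_defaults_sketch(
    HandleX query, HandleX key, HandleX value,
    double dropout_p, int is_causal, int return_debug_mask,
    int has_scale, double scale) {
  // ... dispatch to the real kernel ...
  (void)query; (void)key; (void)value;
  (void)dropout_p; (void)is_causal; (void)return_debug_mask;
  (void)has_scale; (void)scale;
  return 0;
}

// What a generated call could look like when only (q, k, v) appear in the IR:
int call_site_sketch(HandleX q, HandleX k, HandleX v) {
  return sdpa_explicit_defaults_sketch(
      q, k, v,
      /*dropout_p=*/0.0, /*is_causal=*/0, /*return_debug_mask=*/0,
      /*has_scale=*/0, /*scale=*/0.0);
}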

@ezyang (Contributor) left a comment

no variadic functions allowed unless you convince me it's a good idea

Comment on lines 209 to 238
auto ret = at::_scaled_dot_product_flash_attention(
    *query_tensor,
    *key_tensor,
    *value_tensor,
    dropout_p,
    is_causal,
    return_debug_mask,
    scale);

at::Tensor* ret0_tensor = new at::Tensor(std::move(std::get<0>(ret)));
*ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
at::Tensor* ret1_tensor = new at::Tensor(std::move(std::get<1>(ret)));
*ret1 = tensor_pointer_to_tensor_handle(ret1_tensor);
// ret2 and ret3 may be null
if (ret2) {
  at::Tensor* ret2_tensor = new at::Tensor(std::move(std::get<2>(ret)));
  *ret2 = tensor_pointer_to_tensor_handle(ret2_tensor);
}
if (ret3) {
  at::Tensor* ret3_tensor = new at::Tensor(std::move(std::get<3>(ret)));
  *ret3 = tensor_pointer_to_tensor_handle(ret3_tensor);
}
*ret4 = std::get<4>(ret);
*ret5 = std::get<5>(ret);
at::Tensor* ret6_tensor = new at::Tensor(std::move(std::get<6>(ret)));
*ret6 = tensor_pointer_to_tensor_handle(ret6_tensor);
at::Tensor* ret7_tensor = new at::Tensor(std::move(std::get<7>(ret)));
*ret7 = tensor_pointer_to_tensor_handle(ret7_tensor);
at::Tensor* ret8_tensor = new at::Tensor(std::move(std::get<8>(ret)));
*ret8 = tensor_pointer_to_tensor_handle(ret8_tensor);
Contributor

For brevity we can use structured binding, e.g.:

auto [ret0, ret1, ret2, ret3, ret4, ret5, ret6, ret7, ret8] = at::_scaled_dot_product_flash_attention(
        *query_tensor,
        *key_tensor,
        *value_tensor,
        dropout_p,
        is_causal,
        return_debug_mask,
        scale);

double dropout_p = 0.0,
bool is_causal = false,
bool return_debug_mask = false,
c10::optional<double> scale = c10::nullopt) {
Contributor

Is this supposed to be a C-friendly API? Why c10::optional then?

Author

> Is this supposed to be a C-friendly API? Why c10::optional then?

We provide a C interface; the implementation can be in C++.
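
As a hedged sketch of that split (the names are illustrative, not the real shim functions): the exported symbol keeps a C-compatible signature, while an internal C++ helper is free to use c10::optional and default arguments.

#include <cstdint>
#include <c10/util/Optional.h>  // c10::optional, ships with the PyTorch headers

using HandleX = void*;  // hypothetical stand-in for AtenTensorHandle

// Internal C++ helper: default arguments and c10::optional are fine here.
static int32_t sdpa_impl_sketch(
    HandleX query, HandleX key, HandleX value,
    double dropout_p = 0.0, bool is_causal = false,
    c10::optional<double> scale = c10::nullopt) {
  (void)query; (void)key; (void)value;
  (void)dropout_p; (void)is_causal; (void)scale;
  return 0;
}

// Exported entry point: only C-compatible types cross the ABI boundary.
extern "C" int32_t sdpa_c_entry_sketch(
    HandleX query, HandleX key, HandleX value,
    double dropout_p, int32_t is_causal,
    int32_t has_scale, double scale) {
  return sdpa_impl_sketch(
      query, key, value, dropout_p, is_causal != 0,
      has_scale ? c10::optional<double>(scale) : c10::nullopt);
}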

AtenTensorHandle* ret // returns new reference
);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
Contributor

Do we use it in this diff?

Author

> Do we use it in this diff?

Yes, we do.


self.first_device_guard = True
self.supports_intermediate_hooks = True
self.expr_printer = pexpr
self.defined_symbols = set()
Contributor

Add comments to explain what this is for

Contributor

I still don't really understand the fix here, and I would like to understand it. Could you describe what goes wrong and how this fixes it?

Author

The original issue showed up with a large model. Let me come up with a simple test to demonstrate it.

def generate_c_shim_fallback_kernel_call(self, fallback_kernel, args):
    output_args = []
    output_raii_handles = []
    output_name_base = fallback_kernel.get_name()
Contributor

We probably need a counter to distinguish different fallback_kernel_calls?

Author

> We probably need a counter to distinguish different fallback_kernel_calls?

I think different fallback_kernel_calls already get unique buffer names (buf0, buf1, etc.)?

self.generate_c_shim_extern_kernel_call(kernel, args)
if isinstance(extern_kernel, ir.FallbackKernel):
    self.generate_c_shim_fallback_kernel_call(extern_kernel, args)
else:
Contributor

This path wasn't tested, and is a to-do item on my radar. You can just remove it, and change generate_c_shim_fallback_kernel_call to generate_c_shim_extern_kernel_alloc_call

return f"reinterpret_tensor({', '.join(args)})"

def codegen_multi_output(self, name, value):
# if V.graph.aot_mode and name in set(V.graph.get_output_names()):
Contributor

What is this comment for? Consider adding a comment to explain what happens when config.aot_inductor.abi_compatible is True.

Author

It was accidentally left here. Will remove it.

va_start(args, value);

double dropout_p = 0.0;
if (num_inputs >= 4) {
Contributor

This means we will run aoti_torch__scaled_dot_product_flash_attention_internal multiple times when we have num_inputs==7?

Author

Ah, good catch! Will fix it.

int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8, // returns new reference
Contributor

Looks like we have to put return values at the beginning of parameters because of varargs, so we should just revert #109834 for consistency. We can still change at this point.

Author

> Looks like we have to put return values at the beginning of parameters because of varargs, so we should just revert #109834 for consistency. We can still change at this point.

I will revert #109834 once the variadic solution is accepted.

});
}

static AOTITorchError aoti_torch__scaled_dot_product_flash_attention_internal(
Contributor

Nit: move this to the anonymous namespace.

Author

I think static already gives internal linkage, but yes, I will move it to the anonymous namespace for consistency.

chenyang78 pushed a commit that referenced this pull request Sep 25, 2023
return_debug_mask,
scale);

at::Tensor* ret0_tensor = new at::Tensor(std::move(r0));
Contributor

I know I have been using this in other places, but now seems like a good time to write a utility function for this new + tensor_pointer_to_tensor_handle pattern.

Author

Yes, will do it in a follow-up PR. Thanks.
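
For reference, one possible shape for such a helper, reusing at::Tensor, AtenTensorHandle, and tensor_pointer_to_tensor_handle from the surrounding shim code; the helper name itself is illustrative:

// Moves the result tensor onto the heap and hands back an owning handle,
// replacing the repeated new at::Tensor(...) + tensor_pointer_to_tensor_handle
// pattern at each output.
static AtenTensorHandle new_tensor_handle_sketch(at::Tensor&& tensor) {
  at::Tensor* heap_tensor = new at::Tensor(std::move(tensor));
  return tensor_pointer_to_tensor_handle(heap_tensor);
}

// The call sites above would then become, e.g.:
//   *ret0 = new_tensor_handle_sketch(std::move(r0));
//   *ret1 = new_tensor_handle_sketch(std::move(r1));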

for dim, shape in enumerate(shapes):
    shape = V.graph.sizevars.simplify(shape)
    if shape in needed:
        self.defined_symbols.add(shape)
Contributor

This seems wrong. shape might be s0+s1+1.

Perhaps you want .free_symbols?

Author

shape is guarded by needed, which is set to V.graph.sizevars.free_symbols() at line 576 above, so we won't generate bad code for non-symbol expressions such as s0 + s1.

Contributor

I am worried we will miss the symbols though.

s0+s1 in {s0} will return False, so the defined symbols will be wrong.

Author

@jansel The issue is addressed in #110411, where I removed the source of generating such bad code. There was some historical reason why we chose the previous approach. I think the changes in #110411 make the implementation much simpler and less error-prone.

@ezyang changed the title from "[inductor] support _scaled_dot_product_flash_attention fallback" to "[aotinductor] support _scaled_dot_product_flash_attention fallback" on Sep 27, 2023
AtenTensorHandle* ret2, // returns new reference
AtenTensorHandle* ret3, // returns new reference
int64_t* ret4,
int64_t* ret5,
Contributor

dynamic shapes?!

Author

Those two fields stand for max_seqlen_batch_q and max_seqlen_batch_kv, which are returned from the sdpa kernel:

max_seqlen_batch_q,
max_seqlen_batch_kv,

I think they can be dynamic given they are related to sequence length and batch?

int64_t* ret5,
AtenTensorHandle* ret6, // returns new reference
AtenTensorHandle* ret7, // returns new reference
AtenTensorHandle* ret8 // returns new reference
Contributor

If this is perf sensitive, you're going to want an out variant of this kernel at some point

Author

Yes, good point. Thanks.
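
For illustration, a hypothetical shape such an out variant could take (this is not an existing API): the caller would pre-allocate the outputs and the kernel would write into them instead of allocating and returning new references.

AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention_out(
    AtenTensorHandle out,        // pre-allocated attention output
    AtenTensorHandle logsumexp,  // pre-allocated logsumexp buffer
    AtenTensorHandle query,
    AtenTensorHandle key,
    AtenTensorHandle value,
    double dropout_p,
    int32_t is_causal);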

*ret0 = tensor_pointer_to_tensor_handle(ret0_tensor);
at::Tensor* ret1_tensor = new at::Tensor(std::move(r1));
*ret1 = tensor_pointer_to_tensor_handle(ret1_tensor);
// ret2 and ret3 may be null
Contributor

huh? Shouldn't we guarantee we call this with valid pointers?

Author

These are for cases where the corresponding outputs are None (https://github.com/pytorch/pytorch/blob/main/torch/_inductor/ir.py#L3883). In such cases, I chose to pass nullptr because I think it matches the Python IR more closely. We could pass in a dummy valid pointer instead; I am fine with either way. What's your preference? Thanks.

Contributor

So are you saying something like: because we specialize on things being None or not, at compile time we KNOW whether or not ret2/ret3 return None, and so if they return None we don't have to pass in a pointer, because we know it's always going to be None? I guess this is the variadic thing all over again; technically, we could have just had separate overloads for whether ret2/ret3 return None or not.

I guess this is fine as is, but I'd like it if we asserted that ret2_tensor was in fact undefined if ret2 was null.

Author

> So are you saying something like: because we specialize on things being None or not, at compile time we KNOW whether or not ret2/ret3 return None, and so if they return None we don't have to pass in a pointer, because we know it's always going to be None?

Yes, this is exactly the case.
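
For concreteness, a minimal sketch of the assertion being asked for, reusing the names from the snippet above and assuming the structured-binding names (r0, r1, r2, ...) suggested earlier:

if (ret2) {
  at::Tensor* ret2_tensor = new at::Tensor(std::move(r2));
  *ret2 = tensor_pointer_to_tensor_handle(ret2_tensor);
} else {
  // The compile-time specialization says this output is None, so the
  // kernel should not have produced a defined tensor here.
  TORCH_CHECK(!r2.defined(),
      "ret2 is null but the kernel returned a defined tensor");
}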

@ezyang (Contributor) commented Oct 4, 2023

I may moot the flash attention int returns; it looks to me like they are not actually dynamic. Confirming with @drisspg. See also #110322.

@ezyang (Contributor) commented Oct 4, 2023

see #110527

@chenyang78 (Author)

> see #110527

Got it. Thanks for addressing this.

@ezyang (Contributor) commented Oct 6, 2023

We are going to leave this in, but our expectation is that for dense tensors the int returns will never actually get used.

@drisspg (Contributor) commented Oct 6, 2023

> We are going to leave this in, but our expectation is that for dense tensors the int returns will never actually get used.

Question: if we wanted to change this for dense tensors and land something like #110546 (after a rebase and some error fixes), would that be a problem FC/BC-wise?

@ezyang (Contributor) commented Oct 6, 2023

Not for AOTI, as this isn't deployed yet.

@chenyang78 (Author)

Closing this in favor of #110085

@chenyang78 closed this Nov 17, 2023
@facebook-github-bot deleted the gh/chenyang78/2/head branch December 17, 2023 15:25