Conversation
Cc: @yiyixuxu, could you review the changes made to the attention processor?
@entrpn can you use a custom attention instead? (without updating our default attention processor) |
Hi @yiyixuxu, we wrapped the flash attention kernel call under the `XLA_AVAILABLE` condition.
I'm just wondering if it makes sense for flash attention to have its own attention processor, since this one is meant for SDPA. cc @DN6 here too.
Hi @yiyixuxu, how about we create another attention processor with flash attention, in parallel with `AttnProcessor2_0`?
@zpcore this way the user can explicitly choose to use flash attention if they want to.
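For illustration, a rough sketch of how a user could opt in explicitly, assuming the new processor ends up being called `XLAFlashAttnProcessor2_0` and is installed through the existing `set_attn_processor` API (the model id below is a placeholder):

```python
# Hypothetical opt-in sketch: explicitly install the XLA flash attention processor
# instead of relying on an automatic swap inside the library.
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import XLAFlashAttnProcessor2_0

unet = UNet2DConditionModel.from_pretrained("some/model", subfolder="unet")  # placeholder model id
unet.set_attn_processor(XLAFlashAttnProcessor2_0())
```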
@yiyixuxu - to better understand: can you help me explain why wrapping the flash attention kernel call under the `XLA_AVAILABLE` condition is not preferred?
Is it not possible that XLA is available but the user does not want to use flash attention?
Thanks for the review feedback. We split out the XLA flash attention processor from `AttnProcessor2_0` as requested in the review. PTAL.
```python
if len(args) > 0 or kwargs.get("scale", None) is not None:
    deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
    deprecate("scale", "1.0.0", deprecation_message)
```
Since this is a new attention processor, I think we can safely remove this.
```python
        return hidden_states


class XLAFlashAttnProcessor2_0:
```
So, this will be automatically used when using the compatible models under an XLA environment, right?
Yes, `AttnProcessor2_0` will be replaced with `XLAFlashAttnProcessor2_0` if the XLA version condition is satisfied.
```python
if is_torch_xla_available():
    from torch_xla.experimental.custom_kernel import flash_attention
```
Does this need to go through any version check guard too, i.e., a minimum version known to have `flash_attention`?
Introduced the version check function `is_torch_xla_version` in `import_utils.py`. Added the version check for torch_xla here.
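For reference, a minimal sketch of what such a helper could look like, assuming it mirrors the existing `is_torch_version` helper in `import_utils.py` (`compare_versions` and a module-level `_torch_xla_version` are assumed to be available there):

```python
from packaging.version import parse

def is_torch_xla_version(operation: str, version: str):
    """
    Compare the installed torch_xla version against a reference version.
    `operation` is a comparison string such as ">", ">=", "<", "<=", "==" or "!=".
    """
    # `_torch_xla_version` is assumed to be recorded at import time, like `_torch_version`.
    return compare_versions(parse(_torch_xla_version), operation, version)
```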
```python
    AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
)
if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
    if is_torch_xla_available:
```
Same here too. Does this need to be guarded with a version check too?
Added the version check for torch_xla here too.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I think @yiyixuxu's point here is valid:
IMO it's better to use a similar API to xformers to enable the XLA processor.
OK, now I get it! We have added functions like `enable_xla_flash_attention` / `disable_xla_flash_attention`, similar to the xformers API, so users can explicitly enable the XLA processor.
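Roughly, the intended usage then mirrors the xformers toggle. The method names below follow the discussion above and the model id is a placeholder; treat this as a sketch rather than the final API:

```python
# Hypothetical usage sketch, analogous to enable_xformers_memory_efficient_attention().
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/model-id")  # placeholder model id
pipe.unet.enable_xla_flash_attention()   # swap in the XLA flash attention processors
# ... run inference on the XLA device ...
pipe.unet.disable_xla_flash_attention()  # restore the default attention processors
```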
yiyixuxu left a comment:
Thanks! I think we can merge this soon!
```python
if (
    use_xla_flash_attention
    and is_torch_xla_available
    and is_torch_xla_version('>', '2.2')
    and (not is_spmd() or is_torch_xla_version('>', '2.3'))
):
    processor = XLAFlashAttnProcessor2_0(partition_spec)
```
Suggested change:

```python
if use_xla_flash_attention:
    if is_torch_xla_version("<", "2.3"):
        raise ...
    elif is_spmd() and is_torch_xla_version("<", "2.4"):
        raise ...
    else:
        processor = XLAFlashAttnProcessor2_0(partition_spec)
```
If the user explicitly sets xla_flash_attention, we want to give a very explicit warning/error message when the condition isn't met so they can take action accordingly - we don't want to silently switch to something else just because torch_xla isn't installed.
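Concretely, the guard could read something like this (the error message wording is illustrative, not copied from the PR):

```python
if use_xla_flash_attention:
    if is_torch_xla_version("<", "2.3"):
        raise ValueError("XLA flash attention requires torch_xla version 2.3 or newer.")
    elif is_spmd() and is_torch_xla_version("<", "2.4"):
        raise ValueError("SPMD support for XLA flash attention requires torch_xla version 2.4 or newer.")
    else:
        processor = XLAFlashAttnProcessor2_0(partition_spec)
```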
```python
    partition_spec = self.partition_spec if is_spmd() else None
    hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
else:
    hidden_states = F.scaled_dot_product_attention(
```
We don't need to support SDPA in this XLA flash attention processor - we can remove all the logic related to it!
There is a constraint when using the pallas kernel: we need `all(tensor.shape[2] >= 4096 for tensor in [query, key, value])`, or XLA will error out.
However, we added a new message when it falls back to scaled_dot_product_attention, to avoid silently skipping the kernel.
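In other words, the processor only dispatches to the pallas kernel when the shape constraint holds, roughly like this. This is a simplified sketch of the logic described above; `logger`, `attention_mask`, and the warning text are assumptions for illustration:

```python
# Simplified sketch: call the pallas flash attention kernel only when every tensor
# has a sequence length of at least 4096; otherwise warn and fall back to SDPA.
if all(tensor.shape[2] >= 4096 for tensor in [query, key, value]):
    partition_spec = self.partition_spec if is_spmd() else None
    hidden_states = flash_attention(query, key, value, causal=False, partition_spec=partition_spec)
else:
    logger.warning(
        "Falling back to scaled_dot_product_attention: the pallas flash attention kernel "
        "requires query, key and value to have a sequence length of at least 4096."
    )
    hidden_states = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )
```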
ok thank you for explaining to me!
```python
if not hasattr(F, "scaled_dot_product_attention"):
    raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
```
Suggested change: remove the two lines above.
I think we don't need to support SDPA in the XLA flash attention processor! Let's remove all the logic related to that to simplify things a bit!
Please check my comment above for why we still keep it here. Thanks
```python
    return compare_versions(parse(_torch_version), operation, version)


def is_torch_xla_version(operation: str, version: str):
```
Can we make sure `is_torch_xla_version()` can be called when torch_xla is not installed? Currently, I think you have to run it together with `is_torch_xla_available()`, because `_torch_xla_version` is not defined otherwise.
We can do something like this:
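For example, something along these lines (a sketch, not necessarily the exact code suggested in the review):

```python
def is_torch_xla_version(operation: str, version: str):
    """Compare the installed torch_xla version against `version`; safe to call when torch_xla is absent."""
    if not is_torch_xla_available():
        return False
    return compare_versions(parse(_torch_xla_version), operation, version)
```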
Thank you all!
* update ptxla example

---------

Co-authored-by: Juan Acevedo <jfacevedo@google.com>
Co-authored-by: Pei Zhang <zpcore@gmail.com>
Co-authored-by: Pei Zhang <piz@google.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Pei Zhang <pei@Peis-MacBook-Pro.local>
Co-authored-by: hlky <hlky@hlky.ac>
@sayakpaul can you please review? This new PR supersedes the other one I had opened a while back, which I just closed. Thank you.