Add Unified Sequence Parallel attention #12693
Conversation
It would be nice to get a testing script so that we can quickly check things.
I added a basic test script with a simple forward and backward op. Is it better to have a test script with flash_attention_backward and forward?
Let us know if this is ready for a review!
Yep, ready for review! I tested it with a 4-process setup (2×2 mesh, on CPU) and everything checks out: shapes look good and gradients flow correctly. Looking forward to feedback and happy to address any issues.
sayakpaul left a comment:
Thanks for getting started on this!
```python
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
```

```diff
- return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
```
The forward function has 12 inputs (not counting ctx, the context), but the backward returns only 11 outputs. Normally the two should match. I was getting an error like this while testing: "RuntimeError: function backward returned an incorrect number of gradients (expected 12, got 11)".
Yes, it can be reproduced in this notebook (it happens only during the backward): https://colab.research.google.com/drive/1Ac4nVSVjKHrPpcSRlX0E3NzY0mDEmkMx?usp=sharing
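For reference, a minimal toy sketch of the rule at play here (a hypothetical function, not the PR's TemplatedUnifiedAttention): `torch.autograd.Function.backward` must return exactly one value per `forward` input (besides `ctx`), using `None` for arguments that don't need gradients.

```python
import torch


class ScaledMatmul(torch.autograd.Function):
    """Toy example: forward takes 3 inputs (besides ctx), so backward must return 3 values."""

    @staticmethod
    def forward(ctx, a, b, scale):
        ctx.save_for_backward(a, b)
        ctx.scale = scale
        return (a @ b) * scale

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        grad_a = (grad_out @ b.transpose(-1, -2)) * ctx.scale
        grad_b = (a.transpose(-1, -2) @ grad_out) * ctx.scale
        # One return value per forward input: grad_a, grad_b, and None for `scale`.
        # Returning only two values here would raise
        # "RuntimeError: function ... returned an incorrect number of gradients (expected 3, got 2)".
        return grad_a, grad_b, None
```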
I am trying with the following code:

```python
import torch
from torch import distributed as dist

from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig


def setup_distributed():
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    device = torch.device(f"cuda:{dist.get_rank()}")
    torch.cuda.set_device(device)
    return device


device = setup_distributed()

# Need to add parallel support for this.
# pipeline.transformer.set_attention_backend("flash_hub")
pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
).to(device)
pipeline.transformer.set_attention_backend("_native_cudnn")
pipeline.transformer.enable_parallelism(
    config=ContextParallelConfig(ulysses_degree=2, ring_degree=2)
)

prompt = """
cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
"""
generator = torch.Generator().manual_seed(42)
image = pipeline(prompt, guidance_scale=3.5, num_inference_steps=50, generator=generator).images[0]
if dist.get_rank() == 0:
    image.save("output_ua.png")

if dist.is_initialized():
    dist.destroy_process_group()
```

Run the above with … and it leads to: …
I spent quite some time investigating this issue but wasn't able to find the cause. I tried to reproduce it, but the model is too large for the small GPUs I can use, and …
Oooh, finally tracked it down and could reproduce it on CPU! The bug is in the … That …
I think that is perfect; I didn't know the specifics about torch 2.9. I will apply the diff. I will just do a final test on the lse on …
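(For context, the per-step log-sum-exp is what lets partial attention outputs from different key/value blocks be merged exactly in ring-style attention. A generic sketch of that merge, not this PR's exact code; the shapes are assumptions:)

```python
import torch


def merge_partial_attention(o1, lse1, o2, lse2):
    # Merge two partial attention results computed over disjoint key/value blocks,
    # using their log-sum-exp values.
    # Assumed shapes: o* is (batch, heads, seq_q, head_dim), lse* is (batch, heads, seq_q).
    lse = torch.logaddexp(lse1, lse2)
    o = torch.exp(lse1 - lse).unsqueeze(-1) * o1 + torch.exp(lse2 - lse).unsqueeze(-1) * o2
    return o, lse
```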
We need to add dedicated testing for CP x attention backends, anyway. So, we can skip for now. Sufficient documentation should suffice.
Sounds good!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
sayakpaul left a comment:
Looking good! Let's also add docs and remove the test file.
```python
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
    raise ValueError(
        "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
    )
if self.rotate_method != "allgather":
```
@bot /style
Style bot fixed some files and pushed the changes.
Okay, I will add the docs and then remove the test file.
Sure, feel free to!
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
```python
scatter_idx: int = 2,
gather_idx: int = 1,
```
Small nit, not a merge blocker: I don't think we need these here since they're not configurable through any of the public APIs. I think you can just hard-code the scatter and gather idx.
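For readers unfamiliar with the pattern, here is a minimal sketch of the Ulysses-style all-to-all dimension exchange these defaults describe (a hypothetical standalone helper, not the PR's exact `_all_to_all_dim_exchange`): with `scatter_idx=2` and `gather_idx=1`, a sequence-sharded `(batch, seq/P, heads, head_dim)` tensor becomes a head-sharded `(batch, seq, heads/P, head_dim)` tensor.

```python
import torch
import torch.distributed as dist


def all_to_all_dim_exchange(x, scatter_idx=2, gather_idx=1, group=None):
    # Scatter `x` along `scatter_idx` across the process group and gather along `gather_idx`.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return x
    # Split the scatter dimension into one chunk per rank and stack the chunks on a new
    # leading dimension, which all_to_all_single exchanges across ranks.
    chunks = x.chunk(world_size, dim=scatter_idx)
    inp = torch.stack(chunks, dim=0).contiguous()
    out = torch.empty_like(inp)
    dist.all_to_all_single(out, inp, group=group)
    # Re-attach the received chunks along the gather dimension.
    return torch.cat(out.unbind(0), dim=gather_idx).contiguous()
```

The reverse exchange (heads back to sequence sharding) is the same helper with the two indices swapped, which is presumably why they were kept as parameters.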
The failing test (https://github.com/huggingface/diffusers/actions/runs/20943227141/job/60181041176?pr=12693#step:7:367) is passing locally for me (both CUDA and non-CUDA). This seems like a one-off transient error. Thanks a lot for your contributions!
@Bissmella please generate your MVP certificate from https://huggingface.co/spaces/diffusers/generate-mvp-certificate. Also, let us know your HF account ID so that we can grant you some credits. Looking forward to more collaborations.
Thank you so much @sayakpaul and @DN6!
Could you try again?
Yes, that worked, got it. Thanks!
Cool, your HF profile should also have the Pro subscription and some credits to run experiments.



What does this PR do?
This is a draft implementation of the Unified SP attention approach.
- `_all_to_all_dim_exchange` with custom scatter and gather indices
- `TemplatedUnifiedAttention`
- Core implementation complete, needs: …
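As a rough sketch of how the two parallelism styles compose on a 2D device mesh (illustrative only; the mesh shape, dimension names, and launch command are assumptions rather than this PR's exact API):

```python
# Run under a 4-process launcher, e.g.: torchrun --nproc_per_node=4 <script>
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")
# 2x2 mesh like the CPU test above: one dimension for ring attention, one for Ulysses.
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("ring", "ulysses"))
ring_group = mesh.get_group("ring")
ulysses_group = mesh.get_group("ulysses")

# Unified SP, schematically:
# 1. all-to-all inside `ulysses_group`: (B, S/P, H, D) -> (B, S/P_ring, H/P_ulysses, D)
# 2. ring attention over `ring_group` on the now head-sharded tensors
# 3. reverse all-to-all inside `ulysses_group` to restore the original sequence sharding

dist.destroy_process_group()
```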