Conversation

@vasqu vasqu commented Sep 4, 2025

The issue at hand

  • Sliding window in eager/sdpa/flex used a window size of roughly (SW-1, SW // 2) instead of the intended (SW // 2, SW // 2)
  • Sliding window in flash attention ignored the bidirectionality and used a window size of (SW-1, SW-1)

The fix

  • Indicate bidirectionality at init with the new config flag
  • Fix the window size (at config init and in the mask overlay); see the sketch below for the intended semantics
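
To make the window sizes concrete, here is a minimal sketch (illustrative only, not the actual masking_utils implementation; sliding_window_mask is a made-up helper) of the attention patterns those tuples describe, where a window of (left, right) lets query position i attend to key positions i - left through i + right:

import torch

def sliding_window_mask(seq_len: int, left: int, right: int) -> torch.Tensor:
    # True where query position i may attend to key position j,
    # i.e. i - left <= j <= i + right
    i = torch.arange(seq_len)[:, None]
    j = torch.arange(seq_len)[None, :]
    return (j >= i - left) & (j <= i + right)

SW = 6
intended = sliding_window_mask(10, SW // 2, SW // 2)  # symmetric band, the intended behaviour
eager_bug = sliding_window_mask(10, SW - 1, SW // 2)  # too much left context (eager/sdpa/flex)
fa_bug = sliding_window_mask(10, SW - 1, SW - 1)      # window used by flash attention before the fix
print(intended.int())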

Scripts

For a sanity check, use the following script:

import torch
from sentence_transformers import SentenceTransformer


similarities_per_attn = []
rankings_per_attn = []
for attention in ["eager", "sdpa", "flash_attention_2"]:
    model = SentenceTransformer(
        "google/embeddinggemma-300m",
        model_kwargs={"attn_implementation": attention, "dtype": torch.bfloat16},
    )

    query = "Which planet is known as the Red Planet?" * 30
    documents = [
        "Venus is often called Earth's twin because of its similar size and proximity." * 30,
        "Mars, known for its reddish appearance, is often referred to as the Red Planet." * 30,
        "Jupiter, the largest planet in our solar system, has a prominent red spot." * 30,
        "Saturn, famous for its rings, is sometimes mistaken for the Red Planet." * 30
    ]
    query_embeddings = model.encode_query(query)
    document_embeddings = model.encode_document(documents)

    # Compute similarities to determine a ranking
    similarities = model.similarity(query_embeddings, document_embeddings)
    similarities_per_attn.append(similarities)

    # Convert similarities to a ranking
    ranking = similarities.argsort(descending=True)[0]
    rankings_per_attn.append(ranking)

    print(f"{attention}: ")
    print(similarities)
    print(ranking)
    print()


# Check that every implementation matches the first one (eager)
similarity, rank = similarities_per_attn[0], rankings_per_attn[0]
for i in range(1, len(similarities_per_attn)):
    if not (torch.allclose(similarity, similarities_per_attn[i], atol=3e-2, rtol=3e-2) and torch.equal(rank, rankings_per_attn[i])):
        raise AssertionError("Similarities or rankings differ across attention implementations")
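
With the fix in place, all three attention implementations should produce similarities within the tolerance above and identical rankings, so the script finishes without raising.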

For visualization:

from transformers import AutoConfig, AutoModel, AutoTokenizer


tokenizer = AutoTokenizer.from_pretrained("google/embeddinggemma-300m")
config = AutoConfig.from_pretrained("google/embeddinggemma-300m")
config.sliding_window = 6
model = AutoModel.from_pretrained("google/embeddinggemma-300m", config=config)

query = "Which planet is known as the Red Planet?" * 2
inputs = tokenizer(query, return_tensors="pt")

model(**inputs)

To debug, insert the following right after the masks are created in the model:

from transformers.masking_utils import tensor_to_mask_visual

print(tensor_to_mask_visual(causal_mask_mapping["sliding_attention"][0][0], grid_size=(30, 50)))

Before fix: (mask visualization image)
After fix: (mask visualization image)
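(With sliding_window = 6, the fixed mask should show a symmetric band of 3 tokens on either side of the diagonal, whereas the buggy eager/sdpa mask extends roughly 5 tokens to the left, matching the window sizes described above.)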

cc @Cyrilvallez @tomaarsen

@tomaarsen tomaarsen self-requested a review September 4, 2025 17:22

vasqu commented Sep 4, 2025

Alternatively, we could change the is_causal flag during the attention init based on the config's bidirectional flag. I'm open to both approaches.

Edit: See https://github.com/vasqu/transformers/pull/1/files
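
For reference, a rough sketch of that alternative (the class name and config flag name below are placeholders, see the linked diff for the actual change), where the attention module derives is_causal from the config at init so the backends only need to apply the sliding-window overlay:

import torch.nn as nn

class GemmaEmbeddingAttention(nn.Module):  # placeholder name, not the actual class
    def __init__(self, config, layer_idx: int):
        super().__init__()
        # ... projections / rotary setup elided ...
        # A bidirectional (encoder-style) model is not causal, so sdpa/flash attention
        # only need to apply the sliding-window overlay on top.
        self.is_causal = not getattr(config, "use_bidirectional_attention", False)  # flag name is a placeholder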

@Cyrilvallez

Nice! Thanks for the fix! Indeed, I'd rather change the flag in the init if you don't mind, so we don't have to change the sdpa integration! 🤗

@vasqu vasqu changed the title from [Gemma Embedding] Fix Flash Attention usage to [Gemma Embedding] Fix SWA on Sep 5, 2025

github-actions bot commented Sep 5, 2025

[For maintainers] Suggested jobs to run (before merge)

run-slow: gemma3, gemma3n


@Cyrilvallez Cyrilvallez left a comment

LGTM, thanks a lot!

@Cyrilvallez Cyrilvallez merged commit 948bc0f into huggingface:main Sep 5, 2025
17 checks passed
@vasqu vasqu deleted the fix-gemma-embedding-fa branch September 5, 2025 15:12
Cyrilvallez pushed a commit that referenced this pull request Sep 5, 2025
* fix gemma embedding flash attention

* fix sdpa

* fix atttempt number 2

* alternative gemma fix

* fix modular