
Fix FinchPress for Qwen models family #82

Merged
SimJeg merged 7 commits into main from fix-finch-sep on Jun 20, 2025
Conversation

@alessiodevoto
Contributor

This addresses #80 by:

  1. adding a special case for Qwen (the eos_token strategy is not suitable, so we use another special token)
  2. testing that all tokenizers work with FinchPress (i.e. whether the tokenizer has either a bos_token or a suitable special token)

I'm not sure about the tokenizer test: if a user wants to run make test, they need access to (and must download) all the possible tokenizers. I would either (a) make this test optional (if possible) or (b) remove it and just raise an error in FinchPress, as sketched below.
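
For illustration, a minimal sketch of option (b); the helper name and fallback logic here are assumptions, not the actual FinchPress code:

# Hypothetical validation helper for FinchPress (illustrative only).
def _resolve_delimiter_token(tokenizer) -> str:
    # Prefer the bos_token strategy when the tokenizer provides one.
    if tokenizer.bos_token is not None:
        return tokenizer.bos_token
    # Fall back to another registered special token (e.g. for Qwen,
    # whose tokenizer has no bos_token and whose eos_token is unsuitable).
    candidates = [t for t in tokenizer.all_special_tokens if t != tokenizer.eos_token]
    if candidates:
        return candidates[0]
    raise ValueError(
        "FinchPress requires a tokenizer with a bos_token or another "
        f"suitable special token; {tokenizer.__class__.__name__} has neither."
    )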

@maxjeblick maxjeblick self-assigned this Jun 18, 2025
@maxjeblick maxjeblick self-requested a review June 18, 2025 11:03
@maxjeblick maxjeblick removed their assignment Jun 18, 2025
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@alessiodevoto alessiodevoto force-pushed the fix-finch-sep branch 4 times, most recently from 283c2ad to 79cc835 on June 18, 2025 at 15:33
Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@alessiodevoto
Contributor Author

Now in FinchPress we add a special token and use it to (a) delimit the context and (b) check whether we are in the prefilling stage. This way we use the same approach for all models.
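
For illustration, a rough sketch of the idea (the model id, token string, and surrounding code are assumptions, not the actual FinchPress implementation):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
tokenizer.add_special_tokens({"additional_special_tokens": ["<finch_sep>"]})
sep_id = tokenizer.convert_tokens_to_ids("<finch_sep>")

context, question = "Long document ...", "What is ...?"
prompt = context + "<finch_sep>" + question  # (a) the token delimits the context
input_ids = tokenizer(prompt, return_tensors="pt").input_ids

# (b) the delimiter only appears in the prompt, so its presence in the
# current input tells us whether we are in the prefilling stage
is_prefilling = bool((input_ids == sep_id).any())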

Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@SimJeg SimJeg linked an issue Jun 19, 2025 that may be closed by this pull request
@SimJeg
Collaborator

SimJeg commented Jun 19, 2025

@alessiodevoto thanks for your work 🙏 Before we merge:

  1. @alessiodevoto could you confirm you can run evaluation with FinchPress + Qwen3?
  2. @giulio98 could you review the code and confirm it's OK for you?

@alessiodevoto
Contributor Author

@SimJeg yes, tested on Qwen2.5 and Qwen3 👍

Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@giulio98
Contributor

Hello @alessiodevoto @SimJeg, I did a first pass on the code. Please, here:

model.resize_token_embeddings(len(tokenizer))

do not resize the model embeddings, as it can produce a silent bug if the delimiter token is not correctly removed from the output embeddings.

@SimJeg
Collaborator

SimJeg commented Jun 20, 2025

> I did first pass on the code please here

please make all your comments

> do not resize model embeddings as it can produce a silent bug if the delimiter token is not correctly removed from the output embeddings.

can you elaborate: which bug? Why would the delimiter not be correctly removed?

@alessiodevoto
Contributor Author

Hi @giulio98 thanks for the review! I'm not entirely sure I get your concern, but I believe the "silent bug" you're referring to shouldn't occur here since we: 1) explicitly check for exactly one delimiter token, and 2) remove it directly from the embeddings. Please let me know if I'm missing something!
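
For illustration, a hypothetical sketch of those two safeguards (names and shapes are illustrative, not the actual FinchPress code):

import torch

def drop_delimiter(input_ids, hidden_states, delimiter_token_id):
    # (1) Require exactly one delimiter per sequence, so a mismatch between
    #     tokenization and compression cannot go unnoticed.
    mask = input_ids == delimiter_token_id
    assert (mask.sum(dim=-1) == 1).all(), "expected exactly one delimiter token"
    # (2) Drop the delimiter position from the hidden states before it can
    #     influence attention or the LM head.
    batch_size, seq_len, _ = hidden_states.shape
    return hidden_states[~mask].view(batch_size, seq_len - 1, -1)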

Signed-off-by: alessiodevoto <devoto.alessio@gmail.com>
@giulio98
Contributor

Hi @alessiodevoto @SimJeg,

I tested the current version of this PR on LongBench NarrativeQA for Qwen2.5-7B-Instruct using the following YARN scaling config:

model_kwargs.update({
    "max_position_embeddings": 131072,
    "rope_scaling": {
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
        "type": "yarn"
    }
})
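
(For context, a minimal sketch of how such kwargs might be applied when loading the model; the model id and surrounding code are assumptions:)

from transformers import AutoModelForCausalLM

model_kwargs = {"torch_dtype": "auto"}
model_kwargs.update({
    "max_position_embeddings": 131072,
    "rope_scaling": {
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
        "type": "yarn",
    },
})
# Config attributes passed as kwargs override the checkpoint config.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct", **model_kwargs)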

I got a very low score (6.84) and first thought it was due to this warning when using resize_token_embeddings:

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance...

But the implementation is actually correct, because the delimiter token embedding is correctly removed from the output. After further inspection, I realized the problem was caused by the rerotation logic, which didn't account for rope_scaling. I tried to fix it with the following code, which should now be agnostic to both LLaMA 3 and YARN scaling factors:

@staticmethod
def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
    # important_pos_batch: (B, H, L) original positions of the kept keys.
    B, H, L = important_pos_batch.shape
    device = important_pos_batch.device
    device_type = x.device.type
    dtype = x.dtype
    idx = torch.arange(0, L, device=device).unsqueeze(0)
    # Use the module's own inv_freq so rope_scaling (YARN, LLaMA 3) is respected.
    inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)
    idx = idx[:, None, :].float().expand(B, H, L)
    # Rotation angle needed to move each key from its old position to its new one.
    delta_pos = (idx - important_pos_batch).unsqueeze(2)

    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"

    with torch.autocast(device_type=device_type, enabled=False):
        freqs = delta_pos * inv_freq
        freqs = freqs.transpose(2, 3)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().contiguous()
        sin = emb.sin().contiguous()
    return cos.to(dtype=dtype), sin.to(dtype=dtype)

I then replaced this block (starting at indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)):

indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
# Rerotate keys
if self.rerotate_keys:
    cos, sin = kwargs["position_embeddings"]
    keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
    keys = keys.gather(2, indices).contiguous()
    cos, sin = cos[:, : indices.shape[2]], sin[:, : indices.shape[2]]
    keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
else:
    keys = keys.gather(2, indices).contiguous()

with this version:

# Rerotate keys
if self.rerotate_keys:
    new_cos, new_sin = self._rerotate_cos_sin(keys, module.rotary_emb.inv_freq, indices)
    indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
    keys = keys.gather(2, indices).contiguous()
    keys = (keys * new_cos) + (rotate_half(keys) * new_sin)
else:
    indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
    keys = keys.gather(2, indices).contiguous()
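
As a sanity check on the underlying math, here is a standalone sketch (a toy reimplementation for illustration, not the kvpress code) verifying that rerotating by the position delta is equivalent to rotating at the target positions directly, since RoPE rotations compose additively:

import torch

def rotate_half(t):
    # Standard RoPE helper: swap and negate the two halves of the hidden dim.
    x1, x2 = t.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope(t, pos, inv_freq):
    # Apply RoPE to t of shape (B, H, L, D) at positions pos of shape (B, H, L).
    freqs = pos[..., None].float() * inv_freq  # (B, H, L, D/2)
    emb = torch.cat((freqs, freqs), dim=-1)    # (B, H, L, D)
    return t * emb.cos() + rotate_half(t) * emb.sin()

B, H, L, D = 1, 2, 8, 16
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, D, 2).float() / D))  # (D/2,)
k = torch.randn(B, H, L, D)
old_pos = torch.randint(0, 100, (B, H, L)).float()  # positions the keys were rotated at
new_pos = torch.arange(L).float().expand(B, H, L)   # contiguous target positions

k_old = rope(k, old_pos, inv_freq)
# Rotating by (new - old) moves a key rotated at old_pos to the rotation it
# would have had at new_pos, because R(a) followed by R(b) equals R(a + b).
k_rerot = rope(k_old, new_pos - old_pos, inv_freq)
torch.testing.assert_close(k_rerot, rope(k, new_pos, inv_freq), rtol=1e-4, atol=1e-4)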

Results

| Model | Before Fix | After Fix | Δ |
| --- | --- | --- | --- |
| Llama3.1-8B-Instruct | 30.47 | 30.59 | +0.12 |
| Qwen2.5-7B-Instruct (no YARN) | 28.96 | 28.84 | −0.12 |
| Qwen2.5-7B-Instruct (YARN) | 6.84 | 27.78 | +20.94 |

Let me know if we should open a separate PR for this fix, or if you'd prefer integrating it here.
P.S.: I also observed degradation in expected_attention when using Qwen with YARN.

@SimJeg
Collaborator

SimJeg commented Jun 20, 2025

@giulio98, interesting finding, thanks! Please open a new issue and/or a PR to fix FinchPress, ExpectedAttentionPress and KeyRerotationPress. I will merge this one.

@SimJeg SimJeg merged commit 97408ee into main Jun 20, 2025
3 checks passed
@SimJeg SimJeg deleted the fix-finch-sep branch June 20, 2025 15:50
maxjeblick pushed a commit that referenced this pull request Aug 12, 2025
Signed-off-by: Max Jeblick <maximilianjeblick@gmail.com>


Development

Successfully merging this pull request may close these issues: FinchPress not working on Qwen Model Family.
