Conversation
Now in FinchPress we add a special token and use it to (a) delimit the context and (b) check whether we are in the prefilling stage. This way we use the same approach for all models.
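A minimal sketch of the idea (hypothetical helper names, not the actual FinchPress implementation), assuming the delimiter has already been added to the tokenizer as a single special token:

```python
import torch

def is_prefilling(input_ids: torch.Tensor, delimiter_token_id: int) -> bool:
    # Prefilling passes contain the full prompt (and hence the delimiter);
    # single-token decoding steps do not.
    return bool((input_ids == delimiter_token_id).any())

def split_on_delimiter(input_ids: torch.Tensor, delimiter_token_id: int):
    # Split a 1D sequence into (context, question) around the single delimiter.
    positions = (input_ids == delimiter_token_id).nonzero(as_tuple=True)[0]
    assert positions.numel() == 1, "expected exactly one delimiter token"
    idx = positions.item()
    return input_ids[:idx], input_ids[idx + 1 :]
```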
@alessiodevoto thanks for your work 🙏 Before we merge:
@SimJeg yes, tested on Qwen2.5 and Qwen3 👍
Hello @alessiodevoto @SimJeg, I did a first pass on the code. Please, here (kvpress/kvpress/presses/finch_press.py, line 140 in 3b3b842), do not resize the model embeddings, as it can produce a silent bug if the delimiter token is not correctly removed from the output embeddings.
Please make all your comments. Can you elaborate: which bug? Why would the delimiter not be correctly removed?
Hi @giulio98, thanks for the review! I'm not entirely sure I get your concern, but I believe the "silent bug" you're referring to shouldn't occur here since we: 1) explicitly check for exactly one delimiter token, and 2) remove it directly from the embeddings. Please let me know if I'm missing something!
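A minimal sketch of the two checks described above (hypothetical names, not the actual FinchPress code): assert there is exactly one delimiter, then drop its position from the embedded sequence so it never reaches the output.

```python
import torch

def strip_delimiter(inputs_embeds: torch.Tensor, input_ids: torch.Tensor, delimiter_token_id: int):
    # inputs_embeds: (1, seq_len, hidden_size); input_ids: (1, seq_len)
    mask = input_ids[0] == delimiter_token_id
    assert mask.sum().item() == 1, "expected exactly one delimiter token"
    keep = ~mask
    return inputs_embeds[:, keep, :], input_ids[:, keep]
```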
I tested the current version of this PR on LongBench NarrativeQA for Qwen2.5-7B-Instruct using the following YARN scaling config:

```python
model_kwargs.update({
    "max_position_embeddings": 131072,
    "rope_scaling": {
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
        "type": "yarn"
    }
})
```

I got a very low score (6.84) and at first thought it was due to this warning. But the implementation is in fact correct, since the delimiter token embedding is correctly removed from the output. After further inspection, I realized the problem was caused by the rerotation logic, which didn't account for `rope_scaling`. I tried to fix it with the following code, which should now be agnostic to both LLaMA 3 and YARN scaling factors:

```python
@staticmethod
def _rerotate_cos_sin(x, inv_freq, important_pos_batch):
    B, H, L = important_pos_batch.shape
    device = important_pos_batch.device
    device_type = x.device.type
    dtype = x.dtype
    idx = torch.arange(0, L, device=device).unsqueeze(0)
    inv_freq = inv_freq[None, None, :, None].float().expand(B, H, -1, 1)
    idx = idx[:, None, :].float().expand(B, H, L)
    delta_pos = (idx - important_pos_batch).unsqueeze(2)
    device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):
        freqs = delta_pos * inv_freq
        freqs = freqs.transpose(2, 3)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos().contiguous()
        sin = emb.sin().contiguous()
    return cos.to(dtype=dtype), sin.to(dtype=dtype)
```

I then replaced this block (kvpress/kvpress/presses/finch_press.py, line 98 in a65cf49):

```python
indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)

# Rerotate keys
if self.rerotate_keys:
    cos, sin = kwargs["position_embeddings"]
    keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * (-sin.unsqueeze(1)))
    keys = keys.gather(2, indices).contiguous()
    cos, sin = cos[:, : indices.shape[2]], sin[:, : indices.shape[2]]
    keys = (keys * cos.unsqueeze(1)) + (rotate_half(keys) * sin.unsqueeze(1))
else:
    keys = keys.gather(2, indices).contiguous()
```

with this version:

```python
# Rerotate keys
if self.rerotate_keys:
    new_cos, new_sin = self._rerotate_cos_sin(keys, module.rotary_emb.inv_freq, indices)
    indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
    keys = keys.gather(2, indices).contiguous()
    keys = (keys * new_cos) + (rotate_half(keys) * new_sin)
else:
    indices = indices.unsqueeze(-1).expand(-1, -1, -1, module.head_dim)
    keys = keys.gather(2, indices).contiguous()
```

Results
Let me know if we should open a separate PR for this fix, or if you'd prefer integrating it here.
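For readers who want to reproduce a setup like the one above, a hypothetical loading sketch for the YARN config (this is not the benchmark script used in the comment; the model name and extra kwargs are assumptions based on the text):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-7B-Instruct"
model_kwargs = {"torch_dtype": "auto", "device_map": "auto"}
model_kwargs.update({
    "max_position_embeddings": 131072,
    "rope_scaling": {
        "factor": 4.0,
        "original_max_position_embeddings": 32768,
        "type": "yarn",
    },
})

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
```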
@giulio98, interesting finding, thanks! Please open a new issue and/or a PR to fix FinchPress, ExpectedAttentionPress, and KeyRerotationPress. I will merge this one.
This addresses #80 by:

I'm not sure about the tokenizer testing: if a user wants to run `make test`, they need to have access to and download all the possible tokenizers. I would either (a) make this test optional (if possible) or (b) remove it and just raise an error in FinchPress.
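A sketch of option (a), gating the tokenizer test behind an opt-in environment variable so `make test` does not need to download every tokenizer (hypothetical test and variable names, not the existing kvpress test suite):

```python
import os
import pytest

requires_tokenizers = pytest.mark.skipif(
    os.getenv("KVPRESS_TEST_TOKENIZERS") != "1",
    reason="set KVPRESS_TEST_TOKENIZERS=1 to download tokenizers and run this test",
)

@requires_tokenizers
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-7B-Instruct"])
def test_delimiter_is_a_single_token(model_name):
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    delimiter = "<|finch_sep|>"  # hypothetical delimiter string
    tokenizer.add_special_tokens({"additional_special_tokens": [delimiter]})
    assert len(tokenizer.encode(delimiter, add_special_tokens=False)) == 1
```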