
Commit cd30961

Authored by sahil-kabir (Sahil Kabir), with yonigozlan and github-actions[bot]

Integrate colqwen2.5 using colqwen2 modelling code (#40600)

* adding option for 2.5
* minor - arg in conversion script
* getting started on modelling.py
* minor - should've been using modular
* addressing comments + fixing datatype/device _get method
* minor
* committing suggestion (Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>)
* docs + first test
* ruff fix
* minor fix
* ruff fix
* model fix
* model fix
* fine-grained check, with a hardcoded score from the original Hf implementation
* minor ruff
* update test values with CI hardware
* adding 2.5 to conversion script
* Apply style fixes

Co-authored-by: Sahil Kabir <sahilkabir@Sahils-MacBook-Pro.local>
Co-authored-by: Yoni Gozlan <74535834+yonigozlan@users.noreply.github.com>
Co-authored-by: yonigozlan <yoni.gozlan@huggingface.co>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent dd8f231 commit cd30961

File tree

5 files changed (+89, -8 lines)

docs/source/en/model_doc/colqwen2.md

Lines changed: 18 additions & 0 deletions

@@ -158,6 +158,24 @@ print("Retrieval scores (query x image):")
 print(scores)
 ```
 
+You can also use checkpoints for `ColQwen2.5` that are **compatible with the ColQwen2 architecture**. This version of the model uses [Qwen2_5_VL](./qwen2_5_vl) as the backbone.
+
+```python
+import torch
+from transformers import ColQwen2ForRetrieval, ColQwen2Processor
+from transformers.utils.import_utils import is_flash_attn_2_available
+
+model_name = "Sahil-Kabir/colqwen2.5-v0.2-hf" # An existing compatible checkpoint
+
+model = ColQwen2ForRetrieval.from_pretrained(
+    model_name,
+    dtype=torch.bfloat16,
+    device_map="auto",
+    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else "sdpa"
+)
+processor = ColQwen2Processor.from_pretrained(model_name)
+```
+
 ## Notes
 
 - [`~ColQwen2Processor.score_retrieval`] returns a 2D tensor where the first dimension is the number of queries and the second dimension is the number of images. A higher score indicates more similarity between the query and image.
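For reference, here is a minimal end-to-end sketch of how the checkpoint loaded in the new docs snippet can be used for retrieval scoring. It reuses the `model` and `processor` objects created above; the query string and image paths are placeholders (not part of the commit), and the calls mirror the integration test added further down.

```python
# Placeholder inputs; substitute your own document page images and queries.
import torch
from PIL import Image

queries = ["Which page mentions quarterly revenue?"]
images = [Image.open("page_0.png"), Image.open("page_1.png")]

batch_images = processor(images=images).to(model.device)
batch_queries = processor(text=queries).to(model.device)

with torch.inference_mode():
    image_embeddings = model(**batch_images).embeddings
    query_embeddings = model(**batch_queries).embeddings

# score_retrieval returns a (num_queries, num_images) tensor of similarity scores.
scores = processor.score_retrieval(
    query_embeddings=query_embeddings,
    passage_embeddings=image_embeddings,
)
print(scores)
```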

src/transformers/models/colqwen2/convert_colqwen2_weights_to_hf.py

Lines changed: 19 additions & 3 deletions

@@ -39,9 +39,10 @@
 
 import torch
 from huggingface_hub import snapshot_download
+from peft import PeftModel
 from safetensors import safe_open
 
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoModel
 from transformers.models.colqwen2 import ColQwen2ForRetrieval
 from transformers.models.colqwen2.configuration_colqwen2 import ColQwen2Config
 from transformers.utils import logging
@@ -69,7 +70,7 @@ def load_original_state_dict(model_id: str, revision: Optional[str] = None) -> d
                 original_state_dict[key] = f.get_tensor(key)
 
     # Some weights are tied, so `lm.head`` is not saved. Let's clone to load state dict.
-    if "lm_head.weight" not in original_state_dict:
+    if "lm_head.weight" not in original_state_dict and "model.embed_tokens.weight" in original_state_dict:
         original_state_dict["lm_head.weight"] = original_state_dict["model.embed_tokens.weight"].clone()
 
     return original_state_dict
@@ -124,7 +125,21 @@ def convert_colqwen2_weights_to_hf(
     config.is_composition = False
 
     # Load the untrained model
-    model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
+    vlm_name_or_path = getattr(config.vlm_config, "_name_or_path", None)
+    if vlm_name_or_path and "2.5" in str(vlm_name_or_path):
+        print(
+            "Detected colqwen2.5 adapters in vlm_config; loading base model %s and merging PEFT weights."
+            % vlm_name_or_path
+        )
+        base_model = AutoModel.from_pretrained(
+            vlm_name_or_path,
+            device_map="cpu",
+            trust_remote_code=True,
+        )
+        peft_model = PeftModel.from_pretrained(base_model, model_id)
+        model = peft_model.merge_and_unload()
+    else:
+        model = ColQwen2ForRetrieval(config=config).to("cpu").eval()
     print("Created model with new config and randomly initialized weights")
 
     # NOTE: The new model was initialized with float32 weights. We need to convert it to the desired precision.
@@ -201,6 +216,7 @@
         help="Name or path of the original VLM backbone model",
         default=None,
     )
+
     args = parser.parse_args()
 
     convert_colqwen2_weights_to_hf(
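The conversion script now detects a ColQwen2.5 adapter checkpoint (via the `2.5` marker in `config.vlm_config._name_or_path`), loads the VLM backbone, and merges the PEFT weights before converting. Below is a standalone sketch of that merge step; the repository ids are hypothetical placeholders, whereas in the script itself the backbone comes from the config and the adapters from the `model_id` CLI argument.

```python
# Sketch of the adapter-merge step added above. Repo ids are placeholders.
from peft import PeftModel
from transformers import AutoModel

base_model_id = "Qwen/Qwen2.5-VL-3B-Instruct"  # hypothetical backbone repo
adapter_model_id = "vidore/colqwen2.5-v0.2"    # hypothetical LoRA adapter repo

# Load the backbone on CPU, attach the LoRA adapters, then fold them into the
# base weights so the result can be saved as plain ColQwen2-compatible tensors.
base_model = AutoModel.from_pretrained(base_model_id, device_map="cpu")
peft_model = PeftModel.from_pretrained(base_model, adapter_model_id)
merged_model = peft_model.merge_and_unload()
print(type(merged_model).__name__)
```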

src/transformers/models/colqwen2/modeling_colqwen2.py

Lines changed: 0 additions & 1 deletion

@@ -172,7 +172,6 @@ def forward(
             inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
 
         if pixel_values is not None:
-            pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
             image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
             image_mask = (
                 (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)

src/transformers/models/colqwen2/modular_colqwen2.py

Lines changed: 0 additions & 1 deletion

@@ -359,7 +359,6 @@ def forward(
             inputs_embeds = self.vlm.language_model.embed_tokens(input_ids)
 
         if pixel_values is not None:
-            pixel_values = pixel_values.type(self.vlm.visual.get_dtype())
             image_embeds = self.vlm.visual(pixel_values, grid_thw=image_grid_thw)
             image_mask = (
                 (input_ids == self.config.vlm_config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
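Both modeling files drop the explicit `pixel_values` cast to `self.vlm.visual.get_dtype()` before the vision tower is called. A hedged sanity check one could run locally (not part of the commit; it assumes the compatible checkpoint from the docs snippet and that the vision tower handles input dtype on its own) is to confirm embeddings still come out in the dtype the model was loaded with:

```python
# Hypothetical check, not part of this commit: embeddings should match the load
# dtype even without the removed explicit pixel_values cast.
import torch
from PIL import Image
from transformers import ColQwen2ForRetrieval, ColQwen2Processor

model_name = "Sahil-Kabir/colqwen2.5-v0.2-hf"  # checkpoint from the docs example
model = ColQwen2ForRetrieval.from_pretrained(
    model_name, dtype=torch.bfloat16, device_map="auto"
).eval()
processor = ColQwen2Processor.from_pretrained(model_name)

dummy_page = Image.new("RGB", (448, 448), color="white")  # placeholder image
batch = processor(images=[dummy_page]).to(model.device)

with torch.inference_mode():
    embeddings = model(**batch).embeddings

print(embeddings.dtype)  # expected: torch.bfloat16
```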

tests/models/colqwen2/test_modeling_colqwen2.py

Lines changed: 52 additions & 3 deletions

@@ -335,12 +335,61 @@ def test_model_integration_test(self):
                     [15.6562, 12.2656, 20.2969],
                 ],
                 ("cuda", 8): [
-                    [15.0703, 8.7422, 15.0312],
-                    [9.5078, 16.8906, 10.6250],
-                    [15.6484, 12.3984, 20.4688],
+                    [16.2812, 8.3672, 14.5703],
+                    [9.4922, 17.1875, 10.3281],
+                    [15.0312, 11.3984, 20.1719],
                 ],
             }
         )
         expected_scores = torch.tensor(expectations.get_expectation(), dtype=scores.dtype)
 
         assert torch.allclose(scores, expected_scores, atol=1e-3), f"Expected scores {expected_scores}, got {scores}"
+
+    @slow
+    def test_model_integration_test_2(self):
+        """
+        Test if the model is able to retrieve the correct pages for a small and easy dataset.
+        This test uses a ColQwen2.5 checkpoint that is compatible with the ColQwen2 architecture.
+        """
+        model = ColQwen2ForRetrieval.from_pretrained(
+            "Sahil-Kabir/colqwen2.5-v0.2-hf",
+            device_map=torch_device,
+            dtype=torch.bfloat16,
+        ).eval()
+        processor = ColQwen2Processor.from_pretrained("Sahil-Kabir/colqwen2.5-v0.2-hf", trust_remote_code=True)
+
+        # Load the test dataset
+        ds = load_dataset("hf-internal-testing/document-visual-retrieval-test", split="test")
+
+        # Preprocess the examples
+        batch_images = processor(images=list(ds["image"])).to(torch_device)
+        batch_queries = processor(text=list(ds["query"])).to(torch_device)
+
+        with torch.inference_mode():
+            image_embeddings = model(**batch_images).embeddings
+            query_embeddings = model(**batch_queries).embeddings
+
+        # Compute retrieval scores
+        scores = processor.score_retrieval(
+            query_embeddings=query_embeddings,
+            passage_embeddings=image_embeddings,
+        )
+
+        assert scores.ndim == 2, f"Expected 2D tensor, got {scores.ndim}"
+        assert scores.shape == (len(ds), len(ds)), f"Expected shape {(len(ds), len(ds))}, got {scores.shape}"
+
+        # Check if the maximum scores per row are in the diagonal of the matrix score
+        self.assertTrue((scores.argmax(axis=1) == torch.arange(len(ds), device=scores.device)).all())
+        # Further validation: fine-grained check, with a hardcoded score from the original Hf implementation.
+        expectations = Expectations(
+            {
+                ("cuda", 8): [
+                    [16.3750, 10.9375, 14.7500],
+                    [11.3750, 16.8750, 12.0625],
+                    [15.3125, 13.1250, 21.5000],
+                ]
+            }
+        )
+        expected_scores = torch.tensor(expectations.get_expectation(), dtype=scores.dtype)
+
+        assert torch.allclose(scores, expected_scores, atol=0.15), f"Expected scores {expected_scores}, got {scores}"
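The hardcoded scores in both integration tests are keyed by device through the `Expectations` test helper, and the new ColQwen2.5 test uses a looser `atol=0.15` (presumably to tolerate bfloat16 and cross-hardware variance). A simplified illustration of the device-keyed lookup idea, with placeholder values and not the library's actual implementation:

```python
# Simplified illustration of device-keyed expected values; the real helper is
# the Expectations class used in the test above. Values below are placeholders.
import torch

def pick_expectation(expectations: dict):
    if torch.cuda.is_available():
        major, _ = torch.cuda.get_device_capability()
        key = ("cuda", major)
    else:
        key = ("cpu", None)
    # Fall back to an arbitrary entry when the exact device key is missing.
    return expectations.get(key, next(iter(expectations.values())))

expected = pick_expectation(
    {
        ("cuda", 8): [[1.0, 2.0, 3.0]],
        ("cuda", 7): [[1.1, 2.1, 3.1]],
    }
)
print(expected)
```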
