Commit 2e2000f

[Model] Add LFM2 architecture (#22845)
Signed-off-by: Paul Pak <paulpak58@gmail.com>
1 parent 3128240

File tree: 11 files changed, +960 -8 lines
docs/models/supported_models.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -373,6 +373,7 @@ th {
 | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
 | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
+| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
 | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
```
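
For reference, a minimal offline-inference sketch against one of the newly listed LFM2 checkpoints. The model name comes from the table above; the prompt and sampling settings are arbitrary, and since the test changes below treat LFM2 as V1-only, the sketch assumes the V1 engine is active (e.g. `VLLM_USE_V1=1`).

```python
# Minimal sketch (not part of this commit): offline generation with the newly
# supported Lfm2ForCausalLM architecture. Assumes the V1 engine is enabled,
# e.g. via VLLM_USE_V1=1, since the tests below mark LFM2 as V0-unsupported.
from vllm import LLM, SamplingParams

llm = LLM(model="LiquidAI/LFM2-1.2B")  # checkpoint listed in the table above
params = SamplingParams(temperature=0.0, max_tokens=64)  # greedy decoding

for out in llm.generate(["What is a short convolution layer?"], params):
    print(out.outputs[0].text)
```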

tests/models/language/generation/test_hybrid.py

Lines changed: 25 additions & 8 deletions

```diff
@@ -31,6 +31,7 @@
     "hmellor/tiny-random-BambaForCausalLM",
     "ibm-granite/granite-4.0-tiny-preview",
     "tiiuae/Falcon-H1-0.5B-Base",
+    "LiquidAI/LFM2-1.2B",
 ]
 
 HF_UNSUPPORTED_MODELS = [
@@ -52,13 +53,18 @@
     "hmellor/tiny-random-BambaForCausalLM",
     "ibm-granite/granite-4.0-tiny-preview",
     "tiiuae/Falcon-H1-0.5B-Base",
+    "LiquidAI/LFM2-1.2B",
 ]
 
 FULL_CUDA_GRAPH_MODELS = [
     "ai21labs/Jamba-tiny-dev",
     "Zyphra/Zamba2-1.2B-instruct",
 ]
 
+V0_UNSUPPORTED_MODELS = [
+    "LiquidAI/LFM2-1.2B",
+]
+
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
@@ -94,9 +100,12 @@ def test_models(
     else:
         hf_outputs = None
 
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+    if model not in V0_UNSUPPORTED_MODELS:
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+    else:
+        vllm_v0_outputs = None
 
     if model in V1_SUPPORTED_MODELS:
         with monkeypatch.context() as m:
@@ -112,7 +121,7 @@ def test_models(
     else:
         vllm_v1_outputs = None
 
-    if hf_outputs is not None:
+    if hf_outputs is not None and vllm_v0_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_v0_outputs,
@@ -122,6 +131,7 @@ def test_models(
 
     if model in V1_SUPPORTED_MODELS:
         ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+        assert ref_outputs is not None
         check_logprobs_close(
             outputs_0_lst=ref_outputs,
             outputs_1_lst=vllm_v1_outputs,
@@ -140,6 +150,9 @@ def test_batching(
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
+    if model in V0_UNSUPPORTED_MODELS:
+        pytest.skip(
+            f"Unsupported V0 Engine. Skipping `test_batching` on {model}.")
 
     try:
         model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
@@ -392,9 +405,12 @@ def test_full_cuda_graph(
     else:
         hf_outputs = None
 
-    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
-        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
-            example_prompts, max_tokens, num_logprobs)
+    if model not in V0_UNSUPPORTED_MODELS:
+        with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+            vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+    else:
+        vllm_v0_outputs = None
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
@@ -408,7 +424,7 @@ def test_full_cuda_graph(
             vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
 
-    if hf_outputs is not None:
+    if hf_outputs is not None and vllm_v0_outputs is not None:
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_v0_outputs,
@@ -417,6 +433,7 @@ def test_full_cuda_graph(
     )
 
     ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    assert ref_outputs is not None
     check_logprobs_close(
         outputs_0_lst=ref_outputs,
         outputs_1_lst=vllm_v1_outputs,
```
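
To exercise just the new model in this test module, one hypothetical invocation is sketched below; the `-k` filter relies on the assumption that the parametrized test IDs contain the model name.

```python
# Hypothetical: run only the LFM2 cases of the hybrid-model tests from Python.
# Equivalent to `pytest tests/models/language/generation/test_hybrid.py -k LFM2 -v`.
import pytest

pytest.main([
    "tests/models/language/generation/test_hybrid.py",
    "-k", "LFM2",  # assumes the model name appears in the parametrized IDs
    "-v",
])
```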

tests/models/registry.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -230,6 +230,8 @@ def check_available_online(
                                             "tiny": "ai21labs/Jamba-tiny-dev",
                                             "random": "ai21labs/Jamba-tiny-random",  # noqa: E501
                                         }),
+    "Lfm2ForCausalLM": _HfExamplesInfo("LiquidAI/LFM2-1.2B",
+                                       min_transformers_version="4.54"),
     "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
                                         extras={"guard": "meta-llama/Llama-Guard-3-1B",  # noqa: E501
                                                 "hermes": "NousResearch/Hermes-3-Llama-3.1-8B",  # noqa: E501
```

tests/models/test_initialization.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -95,6 +95,8 @@ def _initialize_kv_caches_v1(self, vllm_config):
 
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
 def test_can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch):
+    if model_arch == "Lfm2ForCausalLM":
+        pytest.skip("Skipping until test supports V1-only models")
     can_initialize(model_arch, monkeypatch, HF_EXAMPLE_MODELS)
 
 
```
vllm/config/compilation.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -337,6 +337,7 @@ class CompilationConfig:
         "vllm.unified_attention_with_output",
         "vllm.mamba_mixer2",
         "vllm.mamba_mixer",
+        "vllm.short_conv",
     ]
 
     def compute_hash(self) -> str:
```
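
`splitting_ops` names the custom ops at which the compiler splits the model graph for piecewise CUDA-graph capture; registering `vllm.short_conv` lets LFM2's short-convolution mixer act as such a split point. Below is a hedged sketch of overriding the list from user code. The field name follows the diff above, but whether a given vLLM version expects a `CompilationConfig` object or a dict for `compilation_config` may differ, so treat it as illustrative.

```python
# Illustrative sketch: extend the splitting ops when building a CompilationConfig.
# Only the ops visible in the diff above are listed; the real default list may
# contain additional entries.
from vllm.config import CompilationConfig

compilation_config = CompilationConfig(
    splitting_ops=[
        "vllm.unified_attention_with_output",
        "vllm.mamba_mixer2",
        "vllm.mamba_mixer",
        "vllm.short_conv",  # added in this commit for LFM2's conv mixer
    ],
)
# e.g. LLM(model="LiquidAI/LFM2-1.2B", compilation_config=compilation_config)
```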

vllm/model_executor/layers/mamba/mamba_utils.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -54,6 +54,16 @@ def mamba2_state_dtype(
 
         return (conv_state_dtype, temporal_state_dtype)
 
+    @classmethod
+    def short_conv_state_dtype(
+        cls,
+        model_dtype: Union[ModelDType, torch.dtype],
+        mamba_cache_dtype: MambaDType,
+    ) -> tuple[torch.dtype, ...]:
+        conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
+                                                    model_dtype)
+        return (conv_state_dtype, )
+
 
 class MambaStateShapeCalculator:
 
@@ -122,6 +132,20 @@ def mamba2_state_shape(
                                        tp_world_size), head_dim, state_size)
         return conv_state_shape, temporal_state_shape
 
+    @classmethod
+    def short_conv_state_shape(
+        cls,
+        tp_world_size: int,
+        intermediate_size: int,
+        conv_kernel: int,
+        use_v1: bool = True,
+    ) -> tuple[tuple[int, int]]:
+        conv_dim = divide(intermediate_size, tp_world_size)
+        conv_state_shape = (conv_kernel - 1, conv_dim)
+        if not use_v1:
+            conv_state_shape = conv_state_shape[1], conv_state_shape[0]
+        return (conv_state_shape, )
+
     @classmethod
     def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
         """Compute the increase in group numbers to account for
```
