
Commit 2d07e36

inherit from LlamaModel
Signed-off-by: ca1207 <ca1207zzz@gmail.com>
1 parent 4b20a28 commit 2d07e36
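
In short, the commit reworks motif.py so that MotifModel and MotifForCausalLM reuse the Llama implementations and only supply Motif's own decoder layer. A minimal sketch of that pattern, assuming (as the new code in the diff does) that LlamaForCausalLM.__init__ accepts a layer_type keyword; the class name MyLlamaVariantForCausalLM is illustrative and not part of the commit:

from torch import nn

from vllm.config import VllmConfig
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)


class MyLlamaVariantForCausalLM(LlamaForCausalLM):
    """Illustrative Llama variant: only the decoder layer class is swapped."""

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 layer_type: type[nn.Module] = LlamaDecoderLayer):
        # Embeddings, forward pass, logits and weight loading are all
        # inherited from LlamaForCausalLM / LlamaModel.
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=layer_type)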

File tree

1 file changed: +20 -227 lines changed

vllm/model_executor/models/motif.py

Lines changed: 20 additions & 227 deletions
@@ -7,38 +7,27 @@
 # LICENSE: https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/LICENSE
 """Inference-only Motif model compatible with HuggingFace weights."""
 import math
-from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, Optional

 import torch
 from torch import nn
 from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionType
 from vllm.attention.selector import _Backend
-from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel

 from .adapters import as_seq_cls_model
-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
-                    is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers,
-                    maybe_prefix)
+from .interfaces import SupportsV0Only
+from .utils import extract_layer_index


 class MotifMLP(nn.Module):
@@ -332,227 +321,31 @@ def forward(
         return hidden_states, residual


-@support_torch_compile(
-    dynamic_arg_dims={
-        "input_ids": 0,
-        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
-        # otherwise (seq_len, ).
-        "positions": -1,
-        "intermediate_tensors": 0,
-        "inputs_embeds": 0,
-    })
-class MotifModel(nn.Module):
+class MotifModel(LlamaModel):

     def __init__(self,
                  *,
                  vllm_config: VllmConfig,
                  prefix: str = "",
                  decoder_layer_type: type[nn.Module] = MotifDecoderLayer):
-        super().__init__()
-
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         layer_type=decoder_layer_type)

-        self.config = config
-        self.quant_config = quant_config
-        self.vocab_size = config.vocab_size
-
-        if get_pp_group().is_first_rank or (config.tie_word_embeddings
-                                            and get_pp_group().is_last_rank):
-            self.embed_tokens = VocabParallelEmbedding(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=f"{prefix}.embed_tokens",
-            )
-        else:
-            self.embed_tokens = PPMissingLayer()
-
-        # Use the provided decoder layer type or default to MotifDecoderLayer
-        decoder_layer_type = decoder_layer_type or MotifDecoderLayer
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: decoder_layer_type(config=config,
-                                              cache_config=cache_config,
-                                              quant_config=quant_config,
-                                              prefix=prefix),
-            prefix=f"{prefix}.layers",
-        )

-        self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
-        if get_pp_group().is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        else:
-            self.norm = PPMissingLayer()
+# Motif model uses differential attention
+# Only supported in v0 (no chunked prefill support)
+class MotifForCausalLM(LlamaForCausalLM, SupportsV0Only):

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.embed_tokens(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        if get_pp_group().is_first_rank:
-            if inputs_embeds is not None:
-                hidden_states = inputs_embeds
-            else:
-                hidden_states = self.get_input_embeddings(input_ids)
-            residual = None
-        else:
-            assert intermediate_tensors is not None
-            hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
-        for layer in self.layers[self.start_layer:self.end_layer]:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
-        if not get_pp_group().is_last_rank:
-            return IntermediateTensors({
-                "hidden_states": hidden_states,
-                "residual": residual
-            })
-        hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if (self.quant_config is not None and
-                (scale_name := self.quant_config.get_cache_scale(name))):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
-                                 loaded_weight[0])
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-                loaded_params.add(name)
-        return loaded_params
-
-
-class MotifForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        lora_config = vllm_config.lora_config
-
-        self.config = config
-        self.lora_config = lora_config
-
-        self.quant_config = quant_config
-        self.model = MotifModel(vllm_config=vllm_config,
-                                prefix=maybe_prefix(prefix, "model"))
-
-        if get_pp_group().is_last_rank:
-            if config.tie_word_embeddings:
-                self.lm_head = self.model.embed_tokens
-            else:
-                self.lm_head = ParallelLMHead(config.vocab_size,
-                                              config.hidden_size,
-                                              quant_config=quant_config,
-                                              prefix=maybe_prefix(
-                                                  prefix, "lm_head"))
-        else:
-            self.lm_head = PPMissingLayer()
-
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 layer_type: type[nn.Module] = MotifDecoderLayer):

-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights)
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         layer_type=layer_type)


 MotifForSequenceClassification = as_seq_cls_model(MotifForCausalLM)
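
For reference, a usage sketch of the refactored model through vLLM's high-level API. The checkpoint name matches the license header above, the sampling values are illustrative, and forcing the V0 engine via VLLM_USE_V1=0 is an assumption drawn from the SupportsV0Only marker (no chunked prefill support):

import os

# Assumption: select the V0 engine, since MotifForCausalLM is marked
# SupportsV0Only in this commit.
os.environ["VLLM_USE_V1"] = "0"

from vllm import LLM, SamplingParams

# Load the Motif-2.6B checkpoint; trust_remote_code pulls in the HF config.
llm = LLM(model="Motif-Technologies/Motif-2.6B", trust_remote_code=True)
params = SamplingParams(temperature=0.8, max_tokens=32)

outputs = llm.generate(["The Motif model is"], params)
print(outputs[0].outputs[0].text)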

0 commit comments