7 | 7 | # LICENSE: https://huggingface.co/Motif-Technologies/Motif-2.6B/blob/main/LICENSE |
8 | 8 | """Inference-only Motif model compatible with HuggingFace weights.""" |
9 | 9 | import math |
10 | | -from collections.abc import Iterable |
11 | | -from typing import Any, Optional, Union |
| 10 | +from typing import Any, Optional |
12 | 11 |
|
13 | 12 | import torch |
14 | 13 | from torch import nn |
15 | 14 | from transformers import PretrainedConfig |
16 | 15 |
|
17 | 16 | from vllm.attention import Attention, AttentionType |
18 | 17 | from vllm.attention.selector import _Backend |
19 | | -from vllm.compilation.decorators import support_torch_compile |
20 | 18 | from vllm.config import CacheConfig, VllmConfig |
21 | | -from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size |
| 19 | +from vllm.distributed import get_tensor_model_parallel_world_size |
22 | 20 | from vllm.model_executor.layers.layernorm import PolyNorm, RMSNorm |
23 | 21 | from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, |
24 | 22 | QKVParallelLinear, |
25 | 23 | RowParallelLinear) |
26 | | -from vllm.model_executor.layers.logits_processor import LogitsProcessor |
27 | 24 | from vllm.model_executor.layers.quantization import QuantizationConfig |
28 | 25 | from vllm.model_executor.layers.rotary_embedding import get_rope |
29 | | -from vllm.model_executor.layers.vocab_parallel_embedding import ( |
30 | | - ParallelLMHead, VocabParallelEmbedding) |
31 | | -from vllm.model_executor.model_loader.weight_utils import ( |
32 | | - default_weight_loader, maybe_remap_kv_scale_name) |
33 | | -from vllm.model_executor.sampling_metadata import SamplingMetadata |
34 | | -from vllm.sequence import IntermediateTensors |
| 26 | +from vllm.model_executor.models.llama import LlamaForCausalLM, LlamaModel |
35 | 27 |
|
36 | 28 | from .adapters import as_seq_cls_model |
37 | | -from .interfaces import SupportsLoRA, SupportsPP |
38 | | -from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, |
39 | | - is_pp_missing_parameter, |
40 | | - make_empty_intermediate_tensors_factory, make_layers, |
41 | | - maybe_prefix) |
| 29 | +from .interfaces import SupportsV0Only |
| 30 | +from .utils import extract_layer_index |
42 | 31 |
|
43 | 32 |
|
44 | 33 | class MotifMLP(nn.Module): |
@@ -332,227 +321,31 @@ def forward( |
332 | 321 | return hidden_states, residual |
333 | 322 |
|
334 | 323 |
|
335 | | -@support_torch_compile( |
336 | | - dynamic_arg_dims={ |
337 | | - "input_ids": 0, |
338 | | - # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, |
339 | | - # otherwise (seq_len, ). |
340 | | - "positions": -1, |
341 | | - "intermediate_tensors": 0, |
342 | | - "inputs_embeds": 0, |
343 | | - }) |
344 | | -class MotifModel(nn.Module): |
| 324 | +class MotifModel(LlamaModel): |
345 | 325 |
|
346 | 326 | def __init__(self, |
347 | 327 | *, |
348 | 328 | vllm_config: VllmConfig, |
349 | 329 | prefix: str = "", |
350 | 330 | decoder_layer_type: type[nn.Module] = MotifDecoderLayer): |
351 | | - super().__init__() |
352 | | - |
353 | | - config = vllm_config.model_config.hf_config |
354 | | - cache_config = vllm_config.cache_config |
355 | | - quant_config = vllm_config.quant_config |
| 331 | + super().__init__(vllm_config=vllm_config, |
| 332 | + prefix=prefix, |
| 333 | +                         layer_type=decoder_layer_type) |
356 | 334 |
|
357 | | - self.config = config |
358 | | - self.quant_config = quant_config |
359 | | - self.vocab_size = config.vocab_size |
360 | | - |
361 | | - if get_pp_group().is_first_rank or (config.tie_word_embeddings |
362 | | - and get_pp_group().is_last_rank): |
363 | | - self.embed_tokens = VocabParallelEmbedding( |
364 | | - config.vocab_size, |
365 | | - config.hidden_size, |
366 | | - quant_config=quant_config, |
367 | | - prefix=f"{prefix}.embed_tokens", |
368 | | - ) |
369 | | - else: |
370 | | - self.embed_tokens = PPMissingLayer() |
371 | | - |
372 | | - # Use the provided decoder layer type or default to MotifDecoderLayer |
373 | | - decoder_layer_type = decoder_layer_type or MotifDecoderLayer |
374 | | - self.start_layer, self.end_layer, self.layers = make_layers( |
375 | | - config.num_hidden_layers, |
376 | | - lambda prefix: decoder_layer_type(config=config, |
377 | | - cache_config=cache_config, |
378 | | - quant_config=quant_config, |
379 | | - prefix=prefix), |
380 | | - prefix=f"{prefix}.layers", |
381 | | - ) |
382 | 335 |
|
383 | | - self.make_empty_intermediate_tensors = ( |
384 | | - make_empty_intermediate_tensors_factory( |
385 | | - ["hidden_states", "residual"], config.hidden_size)) |
386 | | - if get_pp_group().is_last_rank: |
387 | | - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) |
388 | | - else: |
389 | | - self.norm = PPMissingLayer() |
| 336 | +# The Motif model uses differential attention, which is only |
| 337 | +# supported in V0 (chunked prefill is not supported). |
| 338 | +class MotifForCausalLM(LlamaForCausalLM, SupportsV0Only): |
390 | 339 |
|
391 | | - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
392 | | - return self.embed_tokens(input_ids) |
393 | | - |
394 | | - def forward( |
395 | | - self, |
396 | | - input_ids: torch.Tensor, |
397 | | - positions: torch.Tensor, |
398 | | - intermediate_tensors: Optional[IntermediateTensors] = None, |
399 | | - inputs_embeds: Optional[torch.Tensor] = None, |
400 | | - ) -> Union[torch.Tensor, IntermediateTensors]: |
401 | | - if get_pp_group().is_first_rank: |
402 | | - if inputs_embeds is not None: |
403 | | - hidden_states = inputs_embeds |
404 | | - else: |
405 | | - hidden_states = self.get_input_embeddings(input_ids) |
406 | | - residual = None |
407 | | - else: |
408 | | - assert intermediate_tensors is not None |
409 | | - hidden_states = intermediate_tensors["hidden_states"] |
410 | | - residual = intermediate_tensors["residual"] |
411 | | - for layer in self.layers[self.start_layer:self.end_layer]: |
412 | | - hidden_states, residual = layer( |
413 | | - positions, |
414 | | - hidden_states, |
415 | | - residual, |
416 | | - ) |
417 | | - if not get_pp_group().is_last_rank: |
418 | | - return IntermediateTensors({ |
419 | | - "hidden_states": hidden_states, |
420 | | - "residual": residual |
421 | | - }) |
422 | | - hidden_states, _ = self.norm(hidden_states, residual) |
423 | | - return hidden_states |
424 | | - |
425 | | - def load_weights(self, weights: Iterable[tuple[str, |
426 | | - torch.Tensor]]) -> set[str]: |
427 | | - stacked_params_mapping = [ |
428 | | - # (param_name, shard_name, shard_id) |
429 | | - ("qkv_proj", "q_proj", "q"), |
430 | | - ("qkv_proj", "k_proj", "k"), |
431 | | - ("qkv_proj", "v_proj", "v"), |
432 | | - ("gate_up_proj", "gate_proj", 0), |
433 | | - ("gate_up_proj", "up_proj", 1), |
434 | | - ] |
435 | | - params_dict = dict(self.named_parameters(remove_duplicate=False)) |
436 | | - loaded_params: set[str] = set() |
437 | | - for name, loaded_weight in weights: |
438 | | - if "rotary_emb.inv_freq" in name: |
439 | | - continue |
440 | | - if (self.quant_config is not None and |
441 | | - (scale_name := self.quant_config.get_cache_scale(name))): |
442 | | - # Loading kv cache quantization scales |
443 | | - param = params_dict[scale_name] |
444 | | - weight_loader = getattr(param, "weight_loader", |
445 | | - default_weight_loader) |
446 | | - loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else |
447 | | - loaded_weight[0]) |
448 | | - weight_loader(param, loaded_weight) |
449 | | - loaded_params.add(scale_name) |
450 | | - continue |
451 | | - for (param_name, weight_name, shard_id) in stacked_params_mapping: |
452 | | - if weight_name not in name: |
453 | | - continue |
454 | | - name = name.replace(weight_name, param_name) |
455 | | - # Skip loading extra bias for GPTQ models. |
456 | | - if name.endswith(".bias") and name not in params_dict: |
457 | | - continue |
458 | | - if is_pp_missing_parameter(name, self): |
459 | | - continue |
460 | | - param = params_dict[name] |
461 | | - weight_loader = param.weight_loader |
462 | | - weight_loader(param, loaded_weight, shard_id) |
463 | | - break |
464 | | - else: |
465 | | - # Skip loading extra bias for GPTQ models. |
466 | | - if name.endswith(".bias") and name not in params_dict: |
467 | | - continue |
468 | | - # Remapping the name of FP8 kv-scale. |
469 | | - name = maybe_remap_kv_scale_name(name, params_dict) |
470 | | - if name is None: |
471 | | - continue |
472 | | - if is_pp_missing_parameter(name, self): |
473 | | - continue |
474 | | - param = params_dict[name] |
475 | | - weight_loader = getattr(param, "weight_loader", |
476 | | - default_weight_loader) |
477 | | - weight_loader(param, loaded_weight) |
478 | | - loaded_params.add(name) |
479 | | - return loaded_params |
480 | | - |
481 | | - |
482 | | -class MotifForCausalLM(nn.Module, SupportsLoRA, SupportsPP): |
483 | | - packed_modules_mapping = { |
484 | | - "qkv_proj": [ |
485 | | - "q_proj", |
486 | | - "k_proj", |
487 | | - "v_proj", |
488 | | - ], |
489 | | - "gate_up_proj": [ |
490 | | - "gate_proj", |
491 | | - "up_proj", |
492 | | - ], |
493 | | - } |
494 | | - |
495 | | - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
496 | | - super().__init__() |
497 | | - config = vllm_config.model_config.hf_config |
498 | | - quant_config = vllm_config.quant_config |
499 | | - lora_config = vllm_config.lora_config |
500 | | - |
501 | | - self.config = config |
502 | | - self.lora_config = lora_config |
503 | | - |
504 | | - self.quant_config = quant_config |
505 | | - self.model = MotifModel(vllm_config=vllm_config, |
506 | | - prefix=maybe_prefix(prefix, "model")) |
507 | | - |
508 | | - if get_pp_group().is_last_rank: |
509 | | - if config.tie_word_embeddings: |
510 | | - self.lm_head = self.model.embed_tokens |
511 | | - else: |
512 | | - self.lm_head = ParallelLMHead(config.vocab_size, |
513 | | - config.hidden_size, |
514 | | - quant_config=quant_config, |
515 | | - prefix=maybe_prefix( |
516 | | - prefix, "lm_head")) |
517 | | - else: |
518 | | - self.lm_head = PPMissingLayer() |
519 | | - |
520 | | - self.logits_processor = LogitsProcessor(config.vocab_size) |
521 | | - |
522 | | - self.make_empty_intermediate_tensors = ( |
523 | | - self.model.make_empty_intermediate_tensors) |
524 | | - |
525 | | - def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: |
526 | | - return self.model.get_input_embeddings(input_ids) |
| 340 | + def __init__(self, |
| 341 | + *, |
| 342 | + vllm_config: VllmConfig, |
| 343 | + prefix: str = "", |
| 344 | + layer_type: type[nn.Module] = MotifDecoderLayer): |
527 | 345 |
|
528 | | - def forward( |
529 | | - self, |
530 | | - input_ids: torch.Tensor, |
531 | | - positions: torch.Tensor, |
532 | | - intermediate_tensors: Optional[IntermediateTensors] = None, |
533 | | - inputs_embeds: Optional[torch.Tensor] = None, |
534 | | - ) -> Union[torch.Tensor, IntermediateTensors]: |
535 | | - hidden_states = self.model(input_ids, positions, intermediate_tensors, |
536 | | - inputs_embeds) |
537 | | - return hidden_states |
538 | | - |
539 | | - def compute_logits( |
540 | | - self, |
541 | | - hidden_states: torch.Tensor, |
542 | | - sampling_metadata: SamplingMetadata, |
543 | | - ) -> Optional[torch.Tensor]: |
544 | | - logits = self.logits_processor(self.lm_head, hidden_states, |
545 | | - sampling_metadata) |
546 | | - return logits |
547 | | - |
548 | | - def load_weights(self, weights: Iterable[tuple[str, |
549 | | - torch.Tensor]]) -> set[str]: |
550 | | - loader = AutoWeightsLoader( |
551 | | - self, |
552 | | - skip_prefixes=(["lm_head."] |
553 | | - if self.config.tie_word_embeddings else None), |
554 | | - ) |
555 | | - return loader.load_weights(weights) |
| 346 | + super().__init__(vllm_config=vllm_config, |
| 347 | + prefix=prefix, |
| 348 | + layer_type=layer_type) |
556 | 349 |
|
557 | 350 |
|
558 | 351 | MotifForSequenceClassification = as_seq_cls_model(MotifForCausalLM) |
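
For context, a minimal sketch of the subclassing pattern this diff adopts, not part of the change itself: embeddings, the LM head, logits processing, and weight loading are all inherited from `LlamaForCausalLM`, and a derived model only swaps in its own decoder layer via the `layer_type` argument (assuming vLLM's Llama classes accept `layer_type`, as the diff's `super().__init__` calls do). `MyDecoderLayer` and `MyForCausalLM` are hypothetical names used for illustration; the diff itself uses `MotifDecoderLayer` and `MotifForCausalLM`.

```python
# Sketch only: illustrates the delegation pattern used by the refactor above.
# Assumes LlamaForCausalLM accepts a `layer_type` argument, as in the diff.
from vllm.config import VllmConfig
from vllm.model_executor.models.llama import (LlamaDecoderLayer,
                                              LlamaForCausalLM)


class MyDecoderLayer(LlamaDecoderLayer):
    """Hypothetical custom layer: override attention/MLP here."""


class MyForCausalLM(LlamaForCausalLM):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Embedding, lm_head, logits computation, and weight loading come
        # from the base class; only the decoder layer type differs.
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         layer_type=MyDecoderLayer)
```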