Skip to content
8 changes: 8 additions & 0 deletions llama_cpp/_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ def token_eos(self) -> int:
assert self.model is not None
return llama_cpp.llama_token_eos(self.model)

def token_cls(self) -> int:
    """Return the model's CLS (classification) token id.

    Requires that the underlying llama.cpp model handle has been loaded.
    """
    model = self.model
    assert model is not None
    return llama_cpp.llama_token_cls(model)

def token_sep(self) -> int:
    """Return the model's SEP (separator) token id.

    Requires that the underlying llama.cpp model handle has been loaded.
    """
    model = self.model
    assert model is not None
    return llama_cpp.llama_token_sep(model)

def token_nl(self) -> int:
    """Return the model's newline token id.

    Requires that the underlying llama.cpp model handle has been loaded.
    """
    model = self.model
    assert model is not None
    return llama_cpp.llama_token_nl(model)
Expand Down
7 changes: 7 additions & 0 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ctypes
import typing
import fnmatch
import warnings
import multiprocessing

from typing import (
Expand Down Expand Up @@ -1019,6 +1020,12 @@ def _create_completion(
)
model_name: str = model if model is not None else self.model_path

if prompt_tokens[:2] == [self.token_bos()] * 2:
warnings.warn(
f'Detected duplicate leading "{self._model.token_get_text(self.token_bos())}" in prompt, this will likely reduce response quality, consider removing it...',
RuntimeWarning,
)

# NOTE: This likely doesn't work correctly for the first token in the prompt
# because of the extra space added to the start of the prompt_tokens
if logit_bias is not None:
Expand Down
17 changes: 8 additions & 9 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ class ChatFormatterResponse:
prompt: str
stop: Optional[Union[str, List[str]]] = None
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
added_special: bool = False


class ChatFormatter(Protocol):
Expand Down Expand Up @@ -232,7 +233,7 @@ def stop_on_last_token(
return tokens[-1] in self.stop_token_ids
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])

return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria, added_special=True)

def to_chat_handler(self) -> LlamaChatCompletionHandler:
return chat_formatter_to_chat_completion_handler(self)
Expand Down Expand Up @@ -548,7 +549,7 @@ def chat_completion_handler(
tools=tools,
tool_choice=tool_choice,
)
prompt = result.prompt
prompt = llama.tokenize(result.prompt.encode("utf-8"), add_bos=not result.added_special, special=True)
if result.stop is not None:
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
Expand Down Expand Up @@ -655,7 +656,7 @@ def format_autotokenizer(
prompt: str = tokenizer.apply_chat_template(messages, tokenize=False) # type: ignore
assert isinstance(prompt, str)
# Return formatted prompt and eos token by default
return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token)
return ChatFormatterResponse(prompt=prompt, stop=tokenizer.eos_token, added_special=True)

return format_autotokenizer

Expand Down Expand Up @@ -708,7 +709,7 @@ def format_tokenizer_config(
bos_token=bos_token,
eos_token=eos_token,
)
return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token])
return ChatFormatterResponse(prompt=prompt, stop=[eos_token, bos_token], added_special=True)

return format_tokenizer_config

Expand Down Expand Up @@ -918,7 +919,7 @@ def format_llama2(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
_system_template = "<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>"
_system_template = "[INST] <<SYS>>\n{system_message}\n<</SYS>>"
_roles = dict(user="<s>[INST]", assistant="[/INST]")
_messages = _map_roles(messages, _roles)
system_message = _get_system_message(messages)
Expand All @@ -940,11 +941,10 @@ def format_llama3(
user="<|start_header_id|>user<|end_header_id|>\n\n",
assistant="<|start_header_id|>assistant<|end_header_id|>\n\n",
)
_begin_token = "<|begin_of_text|>"
_sep = "<|eot_id|>"
_messages = _map_roles(messages, _roles)
_messages.append((_roles["assistant"], None))
_prompt = _format_no_colon_single(_begin_token, _messages, _sep)
_prompt = _format_no_colon_single("", _messages, _sep)
return ChatFormatterResponse(prompt=_prompt, stop=_sep)


Expand Down Expand Up @@ -1229,10 +1229,9 @@ def format_mistral_instruct(
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
bos = "<s>"
eos = "</s>"
stop = eos
prompt = bos
prompt = ""
for message in messages:
if (
message["role"] == "user"
Expand Down
3 changes: 2 additions & 1 deletion tests/test_llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ def test_mistral_instruct():
response = llama_chat_format.format_mistral_instruct(
messages=messages,
)
prompt = ("" if response.added_special else "<s>") + response.prompt
reference = chat_formatter.render(
messages=messages,
bos_token="<s>",
eos_token="</s>",
)
assert response.prompt == reference
assert prompt == reference


mistral_7b_tokenizer_config = """{
Expand Down