16 changes: 16 additions & 0 deletions llama_cpp/llama.py
@@ -2078,3 +2078,19 @@ def __call__(
        self, input_ids: npt.NDArray[np.intc], logits: npt.NDArray[np.single]
    ) -> bool:
        return any([stopping_criteria(input_ids, logits) for stopping_criteria in self])


class MinTokensLogitsProcessor(LogitsProcessor):
    def __init__(self, min_tokens: int, token_eos: int):
        self.min_tokens = min_tokens
        self.token_eos = token_eos
        self.prompt_tokens = None

    def __call__(
        self, input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single]
    ) -> npt.NDArray[np.single]:
        # Record the prompt length on the first call; later calls can then
        # tell how many tokens have been generated since the prompt.
        if self.prompt_tokens is None:
            self.prompt_tokens = len(input_ids)
        # Suppress EOS until at least min_tokens tokens have been generated.
        if len(input_ids) - self.prompt_tokens < self.min_tokens:
            scores[self.token_eos] = -np.inf
        return scores
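The processor pins the EOS logit to -inf until min_tokens completions have been sampled; all other logits pass through untouched. A minimal usage sketch against the existing Llama API (the model path and prompt are placeholders):

import llama_cpp

llm = llama_cpp.Llama(model_path="./models/model.gguf")  # placeholder path
processors = llama_cpp.LogitsProcessorList(
    [llama_cpp.MinTokensLogitsProcessor(32, llm.token_eos())]
)
out = llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=64,
    logits_processor=processors,
)
# At least 32 tokens are generated unless max_tokens or a stop
# sequence ends the completion first.
print(out["choices"][0]["text"])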
20 changes: 20 additions & 0 deletions llama_cpp/server/app.py
@@ -275,6 +275,7 @@ async def create_completion(
        "best_of",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)

@@ -288,6 +289,15 @@ async def create_completion(
    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    if body.min_tokens > 0:
        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" not in kwargs:
            kwargs["logits_processor"] = _min_tokens_logits_processor
        else:
            kwargs["logits_processor"].extend(_min_tokens_logits_processor)

    iterator_or_completion: Union[
        llama_cpp.CreateCompletionResponse,
        Iterator[llama_cpp.CreateCompletionStreamResponse],
@@ -445,6 +455,7 @@ async def create_chat_completion(
        "n",
        "logit_bias_type",
        "user",
        "min_tokens",
    }
    kwargs = body.model_dump(exclude=exclude)
    llama = llama_proxy(body.model)
@@ -458,6 +469,15 @@ async def create_chat_completion(
    if body.grammar is not None:
        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)

    if body.min_tokens > 0:
        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
        )
        if "logits_processor" not in kwargs:
            kwargs["logits_processor"] = _min_tokens_logits_processor
        else:
            kwargs["logits_processor"].extend(_min_tokens_logits_processor)

    iterator_or_completion: Union[
        llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
    ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
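Both endpoints follow the same pattern: min_tokens is excluded from the raw kwargs, and when it is positive a MinTokensLogitsProcessor is either installed as a fresh LogitsProcessorList or appended to one already present. A sketch of exercising the new field over HTTP, assuming a server running at its default localhost:8000 and the third-party requests package:

import requests  # assumed installed

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "Write a haiku about the sea.",
        "max_tokens": 64,
        "min_tokens": 16,  # EOS suppressed for the first 16 generated tokens
    },
)
print(resp.json()["choices"][0]["text"])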
8 changes: 8 additions & 0 deletions llama_cpp/server/types.py
@@ -16,6 +16,12 @@
    default=16, ge=1, description="The maximum number of tokens to generate."
)

min_tokens_field = Field(
    default=0,
    ge=0,
    description="The minimum number of tokens to generate. Generation may still return fewer tokens if another stopping condition is met first (e.g. max_tokens or a stop sequence).",
)

temperature_field = Field(
    default=0.8,
    description="Adjust the randomness of the generated text.\n\n"
@@ -111,6 +117,7 @@ class CreateCompletionRequest(BaseModel):
    max_tokens: Optional[int] = Field(
        default=16, ge=0, description="The maximum number of tokens to generate."
    )
    min_tokens: int = min_tokens_field
    temperature: float = temperature_field
    top_p: float = top_p_field
    min_p: float = min_p_field
@@ -206,6 +213,7 @@ class CreateChatCompletionRequest(BaseModel):
        default=None,
        description="The maximum number of tokens to generate. Defaults to inf",
    )
    min_tokens: int = min_tokens_field
    logprobs: Optional[bool] = Field(
        default=False,
        description="Whether to output the logprobs or not. Default is False"
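Because min_tokens_field declares ge=0, Pydantic rejects negative values before a request ever reaches the sampler. A quick illustrative check; this assumes Pydantic v2 (implied by the model_dump calls above) and that the server extras such as pydantic-settings are installed so llama_cpp.server.types imports cleanly:

from pydantic import ValidationError
from llama_cpp.server.types import CreateCompletionRequest

req = CreateCompletionRequest(prompt="Hello", min_tokens=16)
print(req.min_tokens)  # 16

try:
    CreateCompletionRequest(prompt="Hello", min_tokens=-1)
except ValidationError as err:
    print(err.errors()[0]["type"])  # "greater_than_equal"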