1313"""
1414import os
1515import json
16- from typing import List , Optional , Literal , Union , Iterator
16+ from typing import List , Optional , Literal , Union , Iterator , Dict
17+ from typing_extensions import TypedDict
1718
1819import llama_cpp
1920
@@ -64,13 +65,24 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = 16
     temperature: float = 0.8
     top_p: float = 0.95
-    logprobs: Optional[int] = Field(None)
     echo: bool = False
     stop: List[str] = []
-    repeat_penalty: float = 1.1
-    top_k: int = 40
     stream: bool = False
 
+    # ignored or currently unsupported
+    model: Optional[str] = Field(None)
+    n: Optional[int] = 1
+    logprobs: Optional[int] = Field(None)
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    best_of: Optional[int] = 1
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
+    top_k: int = 40
+    repeat_penalty: float = 1.1
+
     class Config:
         schema_extra = {
             "example": {
@@ -91,7 +103,20 @@ def create_completion(request: CreateCompletionRequest):
     if request.stream:
         chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict())  # type: ignore
         return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(**request.dict())
+    return llama(
+        **request.dict(
+            exclude={
+                "model",
+                "n",
+                "logprobs",
+                "frequency_penalty",
+                "presence_penalty",
+                "best_of",
+                "logit_bias",
+                "user",
+            }
+        )
+    )
 
 
 class CreateEmbeddingRequest(BaseModel):
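The `exclude` set relies on pydantic v1's `BaseModel.dict()`, which drops the named fields before the remaining keys are splatted into `llama()`, so the accepted-but-unsupported parameters never reach llama.cpp. A minimal standalone illustration (the `Req` model is a hypothetical stand-in, not from the diff):

```python
from pydantic import BaseModel

class Req(BaseModel):  # hypothetical stand-in for CreateCompletionRequest
    prompt: str = ""
    user: str = "ignored"   # accepted for API compatibility only
    top_k: int = 40         # actually consumed by llama()

req = Req(prompt="hi")
# "user" is stripped, so **kwargs only carries fields llama() understands
print(req.dict(exclude={"user"}))  # {'prompt': 'hi', 'top_k': 40}
```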
@@ -132,6 +157,16 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = False
     stop: List[str] = []
     max_tokens: int = 128
+
+    # ignored or currently unsupported
+    model: Optional[str] = Field(None)
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
     repeat_penalty: float = 1.1
 
     class Config:
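The chat schema mirrors the completion schema: OpenAI-compatible fields are accepted and silently dropped, while `repeat_penalty` passes through to llama.cpp. A sketch of a chat request, assuming the route is mounted at the OpenAI-style `/v1/chat/completions` path (the decorator is not shown in this diff; message content is illustrative):

```python
import requests

response = requests.post(
    "http://localhost:8000/v1/chat/completions",  # assumes default host/port
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 128,
        "repeat_penalty": 1.1,  # llama.cpp specific, passed through
        "n": 1,                 # accepted but ignored by the server
    },
)
print(response.json()["choices"][0]["message"])
```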
@@ -160,7 +195,16 @@ async def create_chat_completion(
     request: CreateChatCompletionRequest,
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
-        **request.dict(exclude={"model"}),
+        **request.dict(
+            exclude={
+                "model",
+                "n",
+                "presence_penalty",
+                "frequency_penalty",
+                "logit_bias",
+                "user",
+            }
+        ),
     )
 
     if request.stream:
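When `stream` is true, the endpoint returns an `EventSourceResponse`, i.e. standard server-sent events over HTTP. A minimal consumer sketch using plain `requests` (URL and payload illustrative; the `[DONE]` sentinel is an assumption about the stream's terminator, not shown in this diff):

```python
import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "Hi"}], "stream": True},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # SSE frames arrive as lines of the form b"data: <json>"
        if line.startswith(b"data: "):
            payload = line[len(b"data: "):]
            if payload == b"[DONE]":  # assumed end-of-stream marker
                break
            print(json.loads(payload), flush=True)
```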
@@ -179,3 +223,40 @@ async def server_sent_events(
         )
     completion: llama_cpp.ChatCompletion = completion_or_chunks  # type: ignore
     return completion
+
+
+class ModelData(TypedDict):
+    id: str
+    object: Literal["model"]
+    owned_by: str
+    permissions: List[str]
+
+
+class ModelList(TypedDict):
+    object: Literal["list"]
+    data: List[ModelData]
+
+
+GetModelResponse = create_model_from_typeddict(ModelList)
+
+
+@app.get("/v1/models", response_model=GetModelResponse)
+def get_models() -> ModelList:
+    return {
+        "object": "list",
+        "data": [
+            {
+                "id": llama.model_path,
+                "object": "model",
+                "owned_by": "me",
+                "permissions": [],
+            }
+        ],
+    }
+
+
+if __name__ == "__main__":
+    import os
+    import uvicorn
+
+    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=int(os.getenv("PORT", 8000)))
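`create_model_from_typeddict` is pydantic v1's helper (available since 1.9) for building a `BaseModel` subclass from a `TypedDict`, which lets the handler's plain-dict return value double as a documented OpenAPI response schema. A standalone sketch with a hypothetical example type:

```python
from typing_extensions import TypedDict
from pydantic import create_model_from_typeddict

class Point(TypedDict):  # hypothetical example type
    x: int
    y: int

PointModel = create_model_from_typeddict(Point)
print(PointModel(x=1, y=2).dict())  # {'x': 1, 'y': 2}
```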