Commit e1b5b9b

Update fastapi server example
1 parent 6de2f24 commit e1b5b9b

File tree: 1 file changed (+87 −6 lines changed)

examples/high_level_api/fastapi_server.py

Lines changed: 87 additions & 6 deletions
@@ -13,7 +13,8 @@
 """
 import os
 import json
-from typing import List, Optional, Literal, Union, Iterator
+from typing import List, Optional, Literal, Union, Iterator, Dict
+from typing_extensions import TypedDict
 
 import llama_cpp
 
@@ -64,13 +65,24 @@ class CreateCompletionRequest(BaseModel):
     max_tokens: int = 16
     temperature: float = 0.8
     top_p: float = 0.95
-    logprobs: Optional[int] = Field(None)
     echo: bool = False
     stop: List[str] = []
-    repeat_penalty: float = 1.1
-    top_k: int = 40
     stream: bool = False
 
+    # ignored or currently unsupported
+    model: Optional[str] = Field(None)
+    n: Optional[int] = 1
+    logprobs: Optional[int] = Field(None)
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    best_of: Optional[int] = 1
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
+    top_k: int = 40
+    repeat_penalty: float = 1.1
+
     class Config:
         schema_extra = {
             "example": {
@@ -91,7 +103,20 @@ def create_completion(request: CreateCompletionRequest):
     if request.stream:
         chunks: Iterator[llama_cpp.CompletionChunk] = llama(**request.dict()) # type: ignore
         return EventSourceResponse(dict(data=json.dumps(chunk)) for chunk in chunks)
-    return llama(**request.dict())
+    return llama(
+        **request.dict(
+            exclude={
+                "model",
+                "n",
+                "logprobs",
+                "frequency_penalty",
+                "presence_penalty",
+                "best_of",
+                "logit_bias",
+                "user",
+            }
+        )
+    )
 
 
 class CreateEmbeddingRequest(BaseModel):
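
The exclude set above relies on Pydantic's BaseModel.dict(exclude=...), which drops the named keys before the remaining fields are splatted into the llama call as keyword arguments. A standalone sketch of that behaviour, using a hypothetical model rather than the one in this file:

from typing import Optional
from pydantic import BaseModel, Field

class ExampleRequest(BaseModel):       # hypothetical stand-in for CreateCompletionRequest
    prompt: str = ""
    max_tokens: int = 16
    user: Optional[str] = Field(None)  # accepted by the API but not understood by llama

req = ExampleRequest(prompt="hello", user="someone")
print(req.dict())                  # {'prompt': 'hello', 'max_tokens': 16, 'user': 'someone'}
print(req.dict(exclude={"user"}))  # {'prompt': 'hello', 'max_tokens': 16}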
@@ -132,6 +157,16 @@ class CreateChatCompletionRequest(BaseModel):
     stream: bool = False
     stop: List[str] = []
     max_tokens: int = 128
+
+    # ignored or currently unsupported
+    model: Optional[str] = Field(None)
+    n: Optional[int] = 1
+    presence_penalty: Optional[float] = 0
+    frequency_penalty: Optional[float] = 0
+    logit_bias: Optional[Dict[str, float]] = Field(None)
+    user: Optional[str] = Field(None)
+
+    # llama.cpp specific parameters
     repeat_penalty: float = 1.1
 
     class Config:
@@ -160,7 +195,16 @@ async def create_chat_completion(
     request: CreateChatCompletionRequest,
 ) -> Union[llama_cpp.ChatCompletion, EventSourceResponse]:
     completion_or_chunks = llama.create_chat_completion(
-        **request.dict(exclude={"model"}),
+        **request.dict(
+            exclude={
+                "model",
+                "n",
+                "presence_penalty",
+                "frequency_penalty",
+                "logit_bias",
+                "user",
+            }
+        ),
     )
 
     if request.stream:
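
The chat route now strips the same kind of unsupported OpenAI fields before llama.create_chat_completion is called, so a client can include them without tripping a TypeError. A hedged client sketch, with the route and host again assumed from this example's defaults rather than taken from this diff:

import requests  # assumed to be installed

# Hypothetical chat request against a locally running instance of this server.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",  # route and port assumed, not part of this diff
    json={
        "messages": [
            {"role": "user", "content": "Say hello in one short sentence."},
        ],
        "max_tokens": 128,
        "repeat_penalty": 1.1,   # llama.cpp specific, forwarded to the model
        "presence_penalty": 0,   # accepted but excluded server-side
        "user": "example-user",  # accepted but excluded server-side
    },
)
print(response.json()["choices"][0]["message"]["content"])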
@@ -179,3 +223,40 @@ async def server_sent_events(
         )
     completion: llama_cpp.ChatCompletion = completion_or_chunks # type: ignore
     return completion
+
+
+class ModelData(TypedDict):
+    id: str
+    object: Literal["model"]
+    owned_by: str
+    permissions: List[str]
+
+
+class ModelList(TypedDict):
+    object: Literal["list"]
+    data: List[ModelData]
+
+
+GetModelResponse = create_model_from_typeddict(ModelList)
+
+
+@app.get("/v1/models", response_model=GetModelResponse)
+def get_models() -> ModelList:
+    return {
+        "object": "list",
+        "data": [
+            {
+                "id": llama.model_path,
+                "object": "model",
+                "owned_by": "me",
+                "permissions": [],
+            }
+        ],
+    }
+
+
+if __name__ == "__main__":
+    import os
+    import uvicorn
+
+    uvicorn.run(app, host=os.getenv("HOST", "localhost"), port=os.getenv("PORT", 8000))
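
Finally, the new /v1/models route reports the loaded model path, and the __main__ block honours the HOST and PORT environment variables (defaulting to localhost:8000). A small sketch of listing the models from a client once the server is up; the URL is assumed from those defaults:

import requests  # assumed to be installed

models = requests.get("http://localhost:8000/v1/models").json()
for model in models["data"]:
    # Prints the model path the server loaded (used as the id), its object type, and the owner.
    print(model["id"], model["object"], model["owned_by"])

Because GetModelResponse is generated from the ModelList TypedDict, FastAPI documents and validates the same shape the handler returns.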
