Skip to content

Commit 10a2d32

Browse files
committed
Handle individual per-model config JSON files
1 parent 3c4b526 commit 10a2d32

File tree

2 files changed

+29
-17
lines changed

2 files changed

+29
-17
lines changed

llama_cpp/server/model.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from threading import Lock
44
import logging
55
import llama_cpp
6-
from llama_cpp.server.settings import Settings, get_settings
6+
from llama_cpp.server.settings import Settings, ModelSettings, get_settings
77

88
FILE_EXT = ".gguf"
99
MODEL_ENV_ARG = "MODEL"
@@ -29,22 +29,30 @@ def __init__(self, settings: Settings) -> None:
2929
if os.path.isfile(settings.model):
3030
self(settings.model.split(os.path.sep)[-1].split(FILE_EXT)[0])
3131

32-
def __call__(self, model: str, **kwargs: Any) -> llama_cpp.Llama:
32+
def __call__(self, model: Optional[str] = None) -> llama_cpp.Llama:
33+
# Backward compatibility: the model parameter is optional; if the lookup fails, the preloaded model (if any) is returned
3334
try:
3435
model_path = self._models[model]
3536
except KeyError:
3637
if self._model:
37-
if self._settings.verbose: logger.info(f"Model file for {model} NOT found! Using preloaded")
38+
if self._settings.verbose: logger.warn(f"Model file for {model} NOT found! Using preloaded")
3839
return self._model
3940
else: raise Exception(404, f"Model file for {model} NOT found")
40-
4141

4242
if self._model:
4343
if self._model.model_path == model_path:
4444
return self._model
4545
del self._model
4646

47-
settings = self._settings
47+
settings_path = os.path.join(os.path.dirname(model_path),
48+
model_path.split(os.path.sep)[-1].split(FILE_EXT)[0] + ".json")
49+
try:
50+
with open(settings_path, 'rb') as f:
51+
settings = ModelSettings.model_validate_json(f.read())
52+
except Exception as e:
53+
if self._settings.verbose: logger.warn(f"Loading settings for {model} FAILED! Using default")
54+
settings = self._settings
55+
4856
self._model = llama_cpp.Llama(
4957
model_path=model_path,
5058
# Model Params
@@ -88,14 +96,13 @@ def __call__(self, model: str, **kwargs: Any) -> llama_cpp.Llama:
8896
cache_size=settings.cache_size,
8997
# Misc
9098
verbose=settings.verbose,
91-
**kwargs
9299
)
93100
return self._model
94101

95-
def __getitem__(self, model):
102+
def __getitem__(self, model: str) -> str:
96103
return self._models[model]
97104

98-
def __setitem__(self, model, path):
105+
def __setitem__(self, model: str, path: str):
99106
self._models[model] = path
100107

101108
def __iter__(self):

llama_cpp/server/settings.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
11
import multiprocessing
22
from typing import Optional, List, Literal
33
from pydantic import Field
4-
from pydantic_settings import BaseSettings
4+
from pydantic_settings import BaseSettings, SettingsConfigDict
55
import llama_cpp
66

77
# Disable warning for model and model_alias settings
88
BaseSettings.model_config['protected_namespaces'] = ()
99

10-
class Settings(BaseSettings):
11-
model: str = Field(
12-
description="The path to the model to use for generating completions."
13-
)
14-
model_alias: Optional[str] = Field(
15-
default=None,
16-
description="The alias of the model to use for generating completions.",
17-
)
10+
class ModelSettings(BaseSettings):
1811
# Model Params
1912
n_gpu_layers: int = Field(
2013
default=0,
@@ -133,6 +126,9 @@ class Settings(BaseSettings):
133126
verbose: bool = Field(
134127
default=True, description="Whether to print debug information."
135128
)
129+
130+
class ServerSettings(BaseSettings):
131+
model_config = SettingsConfigDict(env_file='.env')
136132
# Server Params
137133
host: str = Field(default="localhost", description="Listen address")
138134
port: int = Field(default=8000, description="Listen port")
@@ -141,6 +137,15 @@ class Settings(BaseSettings):
141137
description="Whether to interrupt requests when a new request is received.",
142138
)
143139

140+
class Settings(ModelSettings, ServerSettings):
141+
model: str = Field(
142+
description="The path to the model to use for generating completions."
143+
)
144+
model_alias: Optional[str] = Field(
145+
default=None,
146+
description="The alias of the model to use for generating completions.",
147+
)
148+
144149
SETTINGS: Optional[Settings] = None
145150

146151
def set_settings(settings: Settings):

0 commit comments

Comments
 (0)