forked from abetlen/llama-cpp-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__main__.py
More file actions
127 lines (106 loc) · 4.03 KB
/
__main__.py
File metadata and controls
127 lines (106 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Example FastAPI server for llama.cpp.
To run this example:
```bash
pip install fastapi uvicorn sse-starlette pydantic-settings
export MODEL=../models/7B/...
```
Then run:
```
uvicorn llama_cpp.server.app:app --reload
```
or
```
python3 -m llama_cpp.server
```
Then visit http://localhost:8000/docs to see the interactive API docs.
"""
import argparse
import os
import sys
import types
from typing import List, Literal, Union, get_args, get_origin

import uvicorn

from llama_cpp import __version__
from llama_cpp.server.app import create_app
from llama_cpp.server.model import set_llama
from llama_cpp.server.plugins import import_plugins
from llama_cpp.server.settings import Settings, ServerSettings, set_settings
# Program name passed to ``argparse.ArgumentParser``; shown in usage/help output.
EXE_NAME = 'llama_server'
def get_base_type(annotation):
    """Reduce a settings-field annotation to the scalar type argparse converts to.

    The annotation is unwrapped recursively:

    * ``Literal[...]`` -> the type of its first literal value;
    * ``Optional[X]`` / ``Union[X, ...]`` / ``X | None`` -> its first non-``None`` member;
    * ``List[X]`` / ``list[X]`` -> its element type.

    Args:
        annotation: A type annotation such as ``Optional[List[int]]``.

    Returns:
        The innermost scalar type. Returns ``str`` when no concrete type can be
        extracted (a union of only ``None``, or an unparameterized list). This
        matches argparse's default of leaving values as strings.
    """
    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is Literal:
        # Literal arguments are values, not types; argparse needs their type.
        return type(args[0])
    # ``types.UnionType`` (PEP 604 ``X | None``) only exists on Python 3.10+.
    if origin in (Union, getattr(types, "UnionType", Union)):
        non_optional_args = [arg for arg in args if arg is not type(None)]
        return get_base_type(non_optional_args[0]) if non_optional_args else str
    if origin is list:
        # get_origin() maps both typing.List[X] and list[X] to ``list``.
        return get_base_type(args[0]) if args else str
    return annotation
def contains_list_type(annotation) -> bool:
    """Report whether an annotation is, or wraps, a list type.

    Used to decide whether a CLI option should accept multiple values.
    ``Literal`` and ``Union``/``Optional`` (including PEP 604 ``X | None``)
    members are searched recursively.

    Args:
        annotation: A type annotation such as ``Optional[List[str]]``.

    Returns:
        ``True`` if a ``List[...]``/``list[...]`` appears at the top level or
        inside a Literal/Union; otherwise ``False``.
    """
    origin = get_origin(annotation)
    if origin is list:
        return True
    # ``types.UnionType`` (PEP 604 ``X | None``) only exists on Python 3.10+.
    if origin in (Literal, Union, getattr(types, "UnionType", Union)):
        return any(contains_list_type(arg) for arg in get_args(annotation))
    return False
def parse_bool_arg(arg):
    """Interpret a command-line value as a boolean.

    ``bytes`` input is decoded as UTF-8 first; any other value goes through
    ``str()``. Matching is case-insensitive and ignores surrounding whitespace.

    Accepted spellings:
        ``1``, ``on``, ``t``, ``true``, ``y``, ``yes``  -> ``True``
        ``0``, ``off``, ``f``, ``false``, ``n``, ``no`` -> ``False``

    Raises:
        ValueError: if the value matches neither set.
    """
    if isinstance(arg, bytes):
        arg = arg.decode('utf-8')
    word = str(arg).lower().strip()
    truthy = ('1', 'on', 't', 'true', 'y', 'yes')
    falsy = ('0', 'off', 'f', 'false', 'n', 'no')
    if word not in truthy + falsy:
        raise ValueError(f'Invalid boolean argument: {arg}')
    return word in truthy
def _build_parser():
    """Build the CLI parser: one ``--<field>`` option per ServerSettings/Settings field."""
    parser = argparse.ArgumentParser(
        EXE_NAME, description="🦙 Llama.cpp python server. Host your own LLMs!🚀"
    )
    # When both models define a field with the same name, the Settings entry wins
    # (right-hand side of the dict merge), so each option is registered only once.
    for name, field in (ServerSettings.model_fields | Settings.model_fields).items():
        help_text = field.description
        if field.default and help_text and not field.is_required():
            help_text += f" (default: {field.default})"
        base_type = get_base_type(field.annotation) if field.annotation is not None else str
        if base_type is bool:
            # Booleans take an explicit value (e.g. ``--flag false``) parsed by parse_bool_arg.
            parser.add_argument(
                f"--{name}",
                dest=name,
                type=parse_bool_arg,
                help=help_text,
            )
        else:
            parser.add_argument(
                f"--{name}",
                dest=name,
                # List-typed fields accept zero or more values.
                nargs="*" if contains_list_type(field.annotation) else None,
                type=base_type,
                help=help_text,
            )
    return parser


def _fields_from_args(model_cls, args):
    """Return the parsed CLI values that are fields of ``model_cls``.

    Options the user did not supply are ``None`` in ``args``; they are dropped
    so they are not passed to the model constructor.
    """
    return {
        name: value
        for name, value in vars(args).items()
        if value is not None and name in model_cls.model_fields
    }


def main():
    """Parse CLI arguments, build the FastAPI app and serve it with uvicorn.

    If the settings cannot be built or the app cannot be created, the error and
    the usage help are printed and the process exits with status 1.
    """
    parser = _build_parser()
    args = parser.parse_args()
    try:
        server_settings = ServerSettings(**_fields_from_args(ServerSettings, args))
        set_settings(server_settings)
        config_path = server_settings.config
        if config_path and os.path.exists(config_path):
            with open(config_path, 'rb') as f:
                llama_settings = Settings.model_validate_json(f.read())
        else:
            if config_path:
                # Previously ignored silently; tell the user their config was not used.
                print(
                    f"Config file {config_path} not found; using command-line settings.",
                    file=sys.stderr,
                )
            llama_settings = Settings(**_fields_from_args(Settings, args))
        set_llama(llama_settings)
        app = create_app(title="🦙 llama.cpp Python API", version=__version__)
    except Exception as e:
        # Top-level CLI boundary: report the problem, show usage, and exit non-zero.
        print(e, file=sys.stderr)
        parser.print_help()
        sys.exit(1)

    if server_settings.plugins and os.path.isdir(server_settings.plugins):
        # Each plugin may wrap or replace the app; apply them in discovery order.
        for plugin in import_plugins(server_settings.plugins):
            app = plugin().init(app)

    uvicorn.run(app, host=server_settings.host, port=server_settings.port)
if __name__ == "__main__":
    # Entry point when run as a module, e.g. ``python3 -m llama_cpp.server``.
    main()