-
Notifications
You must be signed in to change notification settings - Fork 31
Expand file tree
/
Copy pathapi_utils.py
More file actions
166 lines (142 loc) · 4.95 KB
/
api_utils.py
File metadata and controls
166 lines (142 loc) · 4.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import base64
import os
import requests
import traceback
import time
import json
import numpy as np
from openai import OpenAI
from openai import AuthenticationError
from config import api_config
from utils.constants import *
def _normalize_base_url(raw_url: str, default_url: str, service_name: str) -> str:
url = (raw_url or "").strip() or default_url
if not url:
raise ValueError(f"{service_name} API URL 为空,请在 Settings 中配置。")
if not url.startswith(("http://", "https://")):
url = "https://" + url
return url
def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as base64 text.

    The file is read in binary mode. The base64 output is decoded to a ``str``
    (UTF-8, which is equivalent to ASCII for base64) so it can be embedded
    directly in a ``data:`` URI.
    """
    with open(image_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")
def vlm_generate(
    prompt,
    image,
    key=VLM_API_KEY,
    url=VLM_API_URL,
    model=VLM_MODEL_TYPE,
    temperature=0.5,
):
    """Send one text + image prompt to an OpenAI-compatible vision endpoint.

    Args:
        prompt: Text instruction sent alongside the image.
        image: Either a path to an existing local file, or a string passed through
            unchanged as the ``image_url`` (e.g. an http(s) URL or a ``data:`` URI).
            A local file is inlined as a base64 ``data:`` URI.
        key: API key. A None/blank value falls back to ``VLM_API_KEY``.
        url: Base URL, normalized by ``_normalize_base_url``. A None/blank value
            falls back to ``VLM_API_URL``.
        model: Model name. A None/blank value falls back to ``VLM_MODEL_TYPE``.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The ``message.content`` of the first choice in the completion.

    Raises:
        ValueError: if no API URL is configured.
        Any exception raised by the OpenAI client propagates unchanged. This
        function does not retry.
    """
    import mimetypes

    key = (key or "").strip() or VLM_API_KEY
    model = (model or "").strip() or VLM_MODEL_TYPE
    url = _normalize_base_url(url, VLM_API_URL, "VLM")

    if os.path.exists(image):
        # Label the data URI with the file's real image type. Every file used to
        # be tagged image/jpeg, which mislabels PNG/WebP/GIF inputs. Fall back to
        # jpeg (the previous behaviour) when the type cannot be inferred.
        mime_type, _ = mimetypes.guess_type(image)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"
        image = f"data:{mime_type};base64,{encode_image(image)}"

    client = OpenAI(api_key=key, base_url=url)
    chat_response = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image_url", "image_url": {"url": image}},
                ],
            }
        ],
        temperature=temperature,
    )
    return chat_response.choices[0].message.content
def llm_generate(
    prompt,
    key=LLM_API_KEY,
    url=LLM_API_URL,
    model=LLM_MODEL_TYPE,
    max_tokens=8192,
    temperature=0.5,
    max_retries=20,
    retry_delay=0.1,
):
    """Send a single-turn chat prompt to an OpenAI-compatible LLM endpoint.

    A fixed system message ("You are a helpful assistant skilled in handling
    tabular data.") is prepended to ``prompt``.

    Args:
        prompt: User message text.
        key: API key. A None/blank value falls back to ``LLM_API_KEY``.
        url: Base URL, normalized by ``_normalize_base_url``. A None/blank value
            falls back to ``LLM_API_URL``.
        model: Model name. A None/blank value falls back to ``LLM_MODEL_TYPE``.
        max_tokens: Completion token limit forwarded to the API.
        temperature: Sampling temperature forwarded to the API.
        max_retries: Maximum number of request attempts for non-authentication
            errors. Defaults to the previously hard-coded 20.
        retry_delay: Seconds to sleep after each failed attempt. Defaults to the
            previously hard-coded 0.1.

    Returns:
        The ``message.content`` of the first choice. Returns the literal string
        ``"None"`` (not the ``None`` object) if authentication fails or every
        attempt fails. Callers may rely on that sentinel, so it is kept.

    Raises:
        ValueError: if no API URL is configured.
    """
    key = (key or "").strip() or LLM_API_KEY
    model = (model or "").strip() or LLM_MODEL_TYPE
    url = _normalize_base_url(url, LLM_API_URL, "LLM")
    client = OpenAI(api_key=key, base_url=url)

    for attempt in range(max_retries):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are a helpful assistant skilled in handling tabular data."},
                    {"role": "user", "content": prompt},
                ],
                max_tokens=max_tokens,
                temperature=temperature,
                stream=False,
            )
            # Extraction stays inside the try so a malformed response is retried,
            # exactly as before.
            return response.choices[0].message.content
        except AuthenticationError as e:
            # 401: retrying with the same key cannot succeed, so give up immediately.
            print("LLM API Authentication Failed! Invalid API key.")
            print(f"Error: {str(e)}")
            return "None"
        except Exception:
            # Any other error is treated as transient and retried.
            print(f"LLM API Request Failed! Retry {attempt}!")
            traceback.print_exc()
            time.sleep(retry_delay)

    # Make exhaustion visible instead of silently returning the sentinel.
    print(f"LLM API Request Failed after {max_retries} attempts; returning 'None'.")
    return "None"
def embedding_generate(
    input_texts: list,
    key=None,
    url=None,
    model=None,
    dimensions=1024,
    batch_size=10,
    max_retries=20,
    retry_delay=0.1,
):
    """Embed ``input_texts`` in batches through an OpenAI-compatible endpoint.

    Settings are resolved in priority order: explicit argument, then
    ``api_config`` ("embedding_api_key" / "embedding_api_url" /
    "embedding_model"), then the ``EMBEDDING_*`` constants.

    Args:
        input_texts: Texts to embed, sent ``batch_size`` at a time.
        key, url, model: Optional overrides (see resolution order above).
        dimensions: Requested embedding size, forwarded to the API.
        batch_size: Texts per request. Defaults to the previously hard-coded 10.
        max_retries: Attempts per batch for non-authentication errors.
        retry_delay: Seconds to sleep after each failed attempt.

    Returns:
        ``np.array`` of the collected embedding vectors, in input order. If a
        batch fails (authentication error or retries exhausted), processing
        stops there. Row ``i`` of the result therefore always belongs to
        ``input_texts[i]``, and a result shorter than ``input_texts`` signals
        failure. (Previously a failed batch was skipped while later batches were
        still appended, which silently misaligned rows and texts.)

    Raises:
        ValueError: if no embedding API URL is configured.
    """
    key = key or api_config.get("embedding_api_key") or EMBEDDING_API_KEY
    raw_url = url or api_config.get("embedding_api_url") or EMBEDDING_API_URL
    url = _normalize_base_url(raw_url, EMBEDDING_API_URL, "Embedding")
    model = model or api_config.get("embedding_model") or EMBEDDING_MODEL_TYPE
    client = OpenAI(api_key=key, base_url=url)

    embeddings = []
    for start in range(0, len(input_texts), batch_size):
        batch = input_texts[start : start + batch_size]
        vectors = None
        for attempt in range(max_retries):
            try:
                response = client.embeddings.create(
                    model=model,
                    input=batch,
                    dimensions=dimensions,
                )
                vectors = [item.embedding for item in response.data]
                break
            except AuthenticationError as e:
                # 401: retrying (or moving on to later batches) with the same key cannot succeed.
                print("EMBEDDING API Authentication Failed! Invalid API key.")
                print(f"Error: {str(e)}")
                break
            except Exception:
                # Any other error is treated as transient and retried.
                print(f"EMBEDDING API Request Failed! Retry {attempt}!")
                traceback.print_exc()
                time.sleep(retry_delay)

        if vectors is None:
            # Stop rather than skip, so returned rows stay aligned with input_texts.
            print(
                f"EMBEDDING API: giving up at text index {start}; "
                f"returning {len(embeddings)} of {len(input_texts)} embeddings."
            )
            break
        embeddings.extend(vectors)

    return np.array(embeddings)
def main():
    """Smoke-test the embedding endpoint by embedding three short sample strings."""
    sample_texts = ["123", "345", "678"]
    vectors = embedding_generate(sample_texts)
    print(vectors)


# Run the smoke test only when this module is executed directly, not on import.
if __name__ == "__main__":
    main()