This repository was archived by the owner on Jun 5, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 91
Expand file tree
/
Copy pathmodels.py
More file actions
298 lines (228 loc) · 7.24 KB
/
models.py
File metadata and controls
298 lines (228 loc) · 7.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import datetime
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional
import numpy as np
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints
class AlertSeverity(str, Enum):
    """Severity levels an Alert can carry (str-valued so it stores/serializes as plain text)."""

    INFO = "info"
    CRITICAL = "critical"
class Alert(BaseModel):
    """A single alert raised for a prompt.

    Field names line up with the alert_* columns of
    IntermediatePromptWithOutputUsageAlerts, so this appears to mirror an
    alerts table row — TODO confirm against the schema.
    """

    id: str
    prompt_id: str  # id of the prompt this alert belongs to
    code_snippet: Optional[str]
    trigger_string: Optional[str]  # presumably the matched text; verify against the writer
    trigger_type: str
    trigger_category: AlertSeverity
    timestamp: datetime.datetime
class Output(BaseModel):
    """A model-response (output) row associated with a prompt."""

    id: Any
    prompt_id: Any  # id of the prompt that produced this output
    timestamp: Any
    output: Any
    # Usage/cost values are optional: None when not recorded for this output.
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    input_cost: Optional[float] = None
    output_cost: Optional[float] = None
class Prompt(BaseModel):
    """A prompt (request) row, optionally scoped to a workspace."""

    id: Any
    timestamp: Any
    provider: Optional[Any]  # provider name; None when unknown
    request: Any  # the raw request payload
    type: Any
    workspace_id: Optional[str]  # None when the prompt is not tied to a workspace
class TokenUsage(BaseModel):
    """Aggregated token counts and costs.

    This is not backed by its own table; the underlying data is stored in
    the outputs table.
    """

    input_tokens: int = 0
    output_tokens: int = 0
    input_cost: float = 0
    output_cost: float = 0

    @classmethod
    def from_dict(cls, usage_dict: Dict) -> "TokenUsage":
        """Build a TokenUsage from a provider usage payload.

        Accepts either OpenAI-style keys (prompt_tokens/completion_tokens)
        or the input_tokens/output_tokens variant; costs start at zero.
        """
        prompt_side = usage_dict.get("prompt_tokens", 0) or usage_dict.get("input_tokens", 0)
        completion_side = usage_dict.get("completion_tokens", 0) or usage_dict.get(
            "output_tokens", 0
        )
        return cls(
            input_tokens=prompt_side,
            output_tokens=completion_side,
            input_cost=0,
            output_cost=0,
        )

    @classmethod
    def from_db(
        cls,
        input_tokens: Optional[int],
        output_tokens: Optional[int],
        input_cost: Optional[float],
        output_cost: Optional[float],
    ) -> "TokenUsage":
        """Build a TokenUsage from DB columns, coercing NULL/falsy values to zero."""
        return cls(
            input_tokens=input_tokens or 0,
            output_tokens=output_tokens or 0,
            input_cost=input_cost or 0,
            output_cost=output_cost or 0,
        )

    def __add__(self, other: "TokenUsage") -> "TokenUsage":
        """Return a new TokenUsage with field-wise sums of the two operands."""
        return TokenUsage(
            input_tokens=self.input_tokens + other.input_tokens,
            output_tokens=self.output_tokens + other.output_tokens,
            input_cost=self.input_cost + other.input_cost,
            output_cost=self.output_cost + other.output_cost,
        )
# Constrained string type for workspace names: whitespace-stripped, lowercased,
# and restricted to alphanumerics, underscores, and hyphens.
# NOTE(review): the pattern admits A-Z, but to_lower=True normalizes input —
# presumably stored names end up all-lowercase; confirm transform/validation order.
WorkspaceNameStr = Annotated[
    str,
    StringConstraints(
        strip_whitespace=True, to_lower=True, pattern=r"^[a-zA-Z0-9_-]+$", strict=True
    ),
]
class WorkspaceRow(BaseModel):
    """A workspace row entry.

    Since our model currently includes instructions
    in the same table, this is returned as a single
    object.
    """

    id: str
    name: WorkspaceNameStr
    custom_instructions: Optional[str]  # None when the workspace has no custom instructions
class GetWorkspaceByNameConditions(BaseModel):
    """Lookup conditions for fetching a workspace by its (normalized) name."""

    name: WorkspaceNameStr

    def get_conditions(self):
        """Return the query conditions as a column-name -> value mapping."""
        conditions = dict(name=self.name)
        return conditions
class Session(BaseModel):
    """A session row; records which workspace is currently active."""

    id: str
    active_workspace_id: str  # id of the workspace active in this session
    last_update: datetime.datetime
# Models for select queries


class ProviderType(str, Enum):
    """
    Represents the different types of providers we support.
    """

    openai = "openai"
    anthropic = "anthropic"
    vllm = "vllm"
    ollama = "ollama"
    lm_studio = "lm_studio"
    llamacpp = "llamacpp"
    openrouter = "openrouter"
class IntermediatePromptWithOutputUsageAlerts(BaseModel):
    """
    An intermediate model to represent the result of a query
    for a prompt and related outputs, usage stats & alerts.

    All output_* and alert_* fields are Optional — presumably the NULLable
    side of an outer join when a prompt has no output/alert; verify against
    the query that produces these rows.
    """

    # Prompt columns
    prompt_id: Any
    prompt_timestamp: Any
    provider: Optional[Any]
    request: Any
    type: Any
    # Output columns
    output_id: Optional[Any]
    output: Optional[Any]
    output_timestamp: Optional[Any]
    input_tokens: Optional[int]
    output_tokens: Optional[int]
    input_cost: Optional[float]
    output_cost: Optional[float]
    # Alert columns
    alert_id: Optional[Any]
    code_snippet: Optional[Any]
    trigger_string: Optional[Any]
    trigger_type: Optional[Any]
    trigger_category: Optional[Any]
    alert_timestamp: Optional[Any]
class GetPromptWithOutputsRow(BaseModel):
    """A prompt together with its (optional) output, usage stats, and alerts."""

    id: Any
    timestamp: Any
    provider: Optional[Any]
    request: Any
    type: Any
    # Output-side fields; None when the prompt has no recorded output.
    output_id: Optional[Any]
    output: Optional[Any]
    output_timestamp: Optional[Any]
    input_tokens: Optional[int]
    output_tokens: Optional[int]
    input_cost: Optional[float]
    output_cost: Optional[float]
    # NOTE(review): pydantic copies mutable field defaults per instance,
    # so the shared [] literal is safe here (unlike a plain function default).
    alerts: List[Alert] = []
class WorkspaceWithSessionInfo(BaseModel):
    """Returns a workspace ID with an optional
    session ID. If the session ID is None, then
    the workspace is not active.
    """

    id: str
    name: WorkspaceNameStr
    session_id: Optional[str]  # None => workspace is not the active one
class WorkspaceWithModel(BaseModel):
    """Returns a workspace ID with model name"""

    id: str
    name: WorkspaceNameStr
    provider_model_name: str  # name of the provider model tied to this workspace
class ActiveWorkspace(BaseModel):
    """Returns a full active workspace object with
    the session information.
    """

    id: str
    name: WorkspaceNameStr
    custom_instructions: Optional[str]
    session_id: str  # non-optional: an active workspace always has a session
    last_update: datetime.datetime
class ProviderEndpoint(BaseModel):
    """A configured provider endpoint (e.g. an OpenAI-compatible API base)."""

    id: str
    name: str
    description: str
    provider_type: str  # presumably one of ProviderType's values — confirm at write sites
    endpoint: str  # base URL of the endpoint
    auth_type: str
class ProviderAuthMaterial(BaseModel):
    """Authentication material for a provider endpoint."""

    provider_endpoint_id: str
    auth_type: str
    auth_blob: str  # opaque credential payload; format depends on auth_type
class ProviderModel(BaseModel):
    """A model offered by a provider endpoint."""

    provider_endpoint_id: str
    provider_endpoint_name: Optional[str] = None  # denormalized convenience field
    name: str
class MuxRule(BaseModel):
    """A routing (mux) rule mapping matched requests in a workspace to a provider model."""

    id: str
    provider_endpoint_id: str
    provider_model_name: str
    workspace_id: str
    matcher_type: str
    matcher_blob: str  # opaque matcher payload; interpretation depends on matcher_type
    priority: int  # rule ordering; relative precedence semantics — confirm at the evaluator
    created_at: Optional[datetime.datetime] = None
    updated_at: Optional[datetime.datetime] = None
def nd_array_custom_before_validator(value):
    """Pydantic before-validator for NdArray.

    Currently a pass-through; serves as the hook where raw input could be
    coerced into an np.ndarray before validation.
    """
    return value
def nd_array_custom_serializer(value):
    """Pydantic serializer for NdArray: render the array via its str() form."""
    return str(value)
# Pydantic doesn't support numpy arrays out of the box, hence we need to construct
# a custom type. There are two things necessary for a Pydantic custom type: a
# Validator and a Serializer. The lines below build our custom type.
# Docs: https://docs.pydantic.dev/latest/concepts/types/#adding-validation-and-serialization
# Open Pydantic issue for numpy support: https://github.com/pydantic/pydantic/issues/7017
NdArray = Annotated[
    np.ndarray,
    BeforeValidator(nd_array_custom_before_validator),
    PlainSerializer(nd_array_custom_serializer, return_type=str),
]
class Persona(BaseModel):
    """
    Represents a persona object.
    """

    id: str
    name: str
    description: str
class PersonaEmbedding(Persona):
    """
    Represents a persona object with an embedding.
    """

    # Vector embedding of `description`; serialized via NdArray's str() serializer.
    description_embedding: NdArray

    # Part of the workaround to allow numpy arrays in pydantic models
    model_config = ConfigDict(arbitrary_types_allowed=True)
class PersonaDistance(Persona):
    """
    Result of an SQL query to get the distance between the query and the persona description.

    A vector similarity search is performed to get the distance. Distance values range [0, 2].
    0 means the vectors are identical, 2 means they are orthogonal.
    See [sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)
    """

    distance: float