forked from stacklok/codegate
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstorage_engine.py
More file actions
132 lines (115 loc) · 4.7 KB
/
storage_engine.py
File metadata and controls
132 lines (115 loc) · 4.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import structlog
import weaviate
from weaviate.classes.config import DataType
from weaviate.classes.query import MetadataQuery
from weaviate.embedded import EmbeddedOptions
from codegate.config import Config
from codegate.inference.inference_engine import LlamaCppInferenceEngine
# Module-wide structured logger registered under the "codegate" name.
logger = structlog.get_logger("codegate")

# Weaviate collection definitions. Each entry names a collection and lists its
# properties; every "Package" property is a plain text field.
schema_config = [
    {
        "name": "Package",
        "properties": [
            {"name": field_name, "data_type": DataType.TEXT}
            for field_name in ("name", "type", "status", "description")
        ],
    },
]
class StorageEngine:
    """Vector store for package metadata backed by an embedded Weaviate instance.

    A fresh Weaviate client is created, connected and closed for every
    operation (schema setup in ``__init__`` and each ``search`` call); no
    client is kept open between calls. Query embeddings are produced by a
    ``LlamaCppInferenceEngine`` using the model at ``self.model_path``.
    """

    def get_client(self, data_path: str) -> "weaviate.WeaviateClient | None":
        """Build an (unconnected) embedded Weaviate client.

        Weaviate's own logging is configured through environment variables
        derived from the current codegate ``Config`` (log format and level).

        Args:
            data_path: Directory passed to Weaviate as its persistence path.

        Returns:
            The new client, or ``None`` if anything went wrong while reading
            the config or constructing the client (the error is logged).
        """
        try:
            # Get current config
            config = Config.get_config()
            log_format = config.log_format.value.lower()

            # Configure Weaviate logging
            additional_env_vars = {
                # Basic logging configuration
                "LOG_FORMAT": log_format,
                "LOG_LEVEL": config.log_level.value.lower(),
                # Disable colored output
                "LOG_FORCE_COLOR": "false",
                # Configure JSON format
                "LOG_JSON_FIELDS": "timestamp, level,message",
                # Configure text format
                "LOG_METHOD": log_format,
                "LOG_LEVEL_IN_UPPER": "false",  # Keep level lowercase like codegate format
                # Disable additional fields
                "LOG_GIT_HASH": "false",
                "LOG_VERSION": "false",
                "LOG_BUILD_INFO": "false",
            }

            return weaviate.WeaviateClient(
                embedded_options=EmbeddedOptions(
                    persistence_data_path=data_path,
                    additional_env_vars=additional_env_vars,
                ),
            )
        except Exception as e:
            # Best effort: callers treat None as "no client available".
            logger.error(f"Error during client creation: {e}")
            return None

    def __init__(self, data_path="./weaviate_data"):
        """Create the engine and make sure the Weaviate schema exists.

        Failures to create, connect to, or configure Weaviate are logged and
        do not raise, so construction succeeds even if the schema could not
        be set up.

        Args:
            data_path: Directory used by the embedded Weaviate instance.
        """
        self.data_path = data_path
        self.inference_engine = LlamaCppInferenceEngine()
        self.model_path = "./models/all-minilm-L6-v2-q5_k_m.gguf"
        self.schema_config = schema_config

        # setup schema for weaviate
        weaviate_client = self.get_client(self.data_path)
        if weaviate_client is None:
            logger.error("Could not find client, skipping schema setup.")
            return

        try:
            weaviate_client.connect()
            self.setup_schema(weaviate_client)
        except Exception as e:
            logger.error(f"Failed to connect or setup schema: {e}")
        finally:
            self._close_client(weaviate_client)

    @staticmethod
    def _close_client(client) -> None:
        """Close *client*, logging (rather than raising) if the close fails."""
        try:
            client.close()
        except Exception as e:
            logger.error(f"Failed to close client: {e}")

    def setup_schema(self, client):
        """Create every collection in ``self.schema_config`` that is missing.

        Args:
            client: A connected Weaviate client.
        """
        for class_config in self.schema_config:
            collection_name = class_config["name"]
            if not client.collections.exists(collection_name):
                client.collections.create(
                    collection_name, properties=class_config["properties"]
                )
            # Logged whether the collection was just created or already existed.
            logger.info(f"Weaviate schema for class {class_config['name']} setup complete.")

    async def search(self, query: str, limit=5, distance=0.3) -> list[object]:
        """
        Search the 'Package' collection based on a query string.

        Args:
            query (str): The text query for which to search.
            limit (int): The number of results to return.
            distance (float): Distance threshold forwarded to Weaviate's
                ``near_vector`` query.

        Returns:
            list: A list of matching results with their properties and distances.
            An empty list is returned when no client could be created, the
            query yields no response, or the search raises.
        """
        # Generate the vector for the query. This runs outside the try block,
        # so embedding errors propagate to the caller.
        query_vector = await self.inference_engine.embed(self.model_path, [query])

        # Perform the vector search
        weaviate_client = self.get_client(self.data_path)
        if weaviate_client is None:
            logger.error("Could not find client, not returning results.")
            return []

        try:
            weaviate_client.connect()
            collection = weaviate_client.collections.get("Package")
            response = collection.query.near_vector(
                query_vector[0],
                limit=limit,
                distance=distance,
                return_metadata=MetadataQuery(distance=True),
            )
            if not response:
                return []
            return response.objects
        except Exception as e:
            logger.error(f"Error during search: {e}")
            return []
        finally:
            # Single close point for success and failure alike; the client is
            # no longer closed a second time on the success path.
            self._close_client(weaviate_client)