forked from lancedb/lancedb
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregistry.py
More file actions
185 lines (156 loc) · 5.99 KB
/
registry.py
File metadata and controls
185 lines (156 loc) · 5.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The LanceDB Authors
import json
from typing import Dict, Optional
from .base import EmbeddingFunction, EmbeddingFunctionConfig
class EmbeddingFunctionRegistry:
    """
    This is a singleton class used to register embedding functions
    and fetch them by name. It also handles serializing and deserializing.

    You can implement your own embedding function by subclassing EmbeddingFunction
    or TextEmbeddingFunction and registering it with the registry.

    NOTE: Here TEXT is a type alias for Union[str, List[str], pa.Array,
          pa.ChunkedArray, np.ndarray]

    Examples
    --------
    >>> registry = EmbeddingFunctionRegistry.get_instance()
    >>> @registry.register("my-embedding-function")
    ... class MyEmbeddingFunction(EmbeddingFunction):
    ...     def ndims(self) -> int:
    ...         return 128
    ...
    ...     def compute_query_embeddings(self, query: str, *args, **kwargs):
    ...         return self.compute_source_embeddings(query, *args, **kwargs)
    ...
    ...     def compute_source_embeddings(self, texts, *args, **kwargs):
    ...         return [np.random.rand(self.ndims()) for _ in range(len(texts))]
    ...
    >>> registry.get("my-embedding-function")
    <class 'lancedb.embeddings.registry.MyEmbeddingFunction'>
    """

    @classmethod
    def get_instance(cls):
        """
        Return the module-level registry instance (``__REGISTRY__``).
        """
        return __REGISTRY__

    def __init__(self):
        # Maps a registration key (the alias, or the class name when no
        # alias was given) to the registered class.
        self._functions = {}

    def register(self, alias: Optional[str] = None):
        """
        This creates a decorator that can be used to register
        an EmbeddingFunction.

        Parameters
        ----------
        alias : Optional[str]
            a human friendly name for the embedding function. If not
            provided, the class name will be used.

        Returns
        -------
        decorator : callable
            A class decorator that stores the class in the registry and
            returns the class unchanged (apart from recording the alias on it).

        Raises
        ------
        TypeError
            (from the decorator) if the decorated class is not a subclass
            of EmbeddingFunction.
        KeyError
            (from the decorator) if the class name is already used as a
            key in the registry.
        """

        def decorator(cls):
            if not issubclass(cls, EmbeddingFunction):
                raise TypeError("Must be a subclass of EmbeddingFunction")
            # NOTE(review): only the class name is checked here, not the
            # alias; registering under an alias that already exists replaces
            # the earlier entry. Kept as-is for backward compatibility.
            if cls.__name__ in self._functions:
                raise KeyError(f"{cls.__name__} was already registered")
            key = alias or cls.__name__
            self._functions[key] = cls
            # Remember the alias on the class so function_to_metadata can
            # serialize the name this class was registered under.
            cls.__embedding_function_registry_alias__ = alias
            return cls

        return decorator

    def reset(self):
        """
        Reset the registry to its initial state
        """
        self._functions = {}

    def get(self, name: str):
        """
        Fetch an embedding function class by name

        Parameters
        ----------
        name : str
            The name of the embedding function to fetch
            Either the alias or the class name if no alias was provided
            during registration

        Raises
        ------
        KeyError
            If nothing is registered under ``name``.
        """
        return self._functions[name]

    def parse_functions(
        self, metadata: Optional[Dict[bytes, bytes]]
    ) -> Dict[str, "EmbeddingFunctionConfig"]:
        """
        Parse the metadata from an arrow table and
        return a mapping of the vector column to the
        embedding function and source column

        Parameters
        ----------
        metadata : Optional[Dict[bytes, bytes]]
            The metadata from an arrow table. Note that
            the keys and values are bytes (pyarrow api)

        Returns
        -------
        functions : dict
            A mapping of vector column name to embedding function.
            An empty dict is returned if input is None or does not
            contain b"embedding_functions".
        """
        if metadata is None:
            return {}
        # Look at both bytes and string keys, since we might use either
        serialized = metadata.get(
            b"embedding_functions", metadata.get("embedding_functions")
        )
        if serialized is None:
            return {}
        # The value may be bytes (pyarrow) or already a str; only bytes-like
        # values need decoding before being handed to the JSON parser.
        if isinstance(serialized, (bytes, bytearray)):
            serialized = serialized.decode("utf-8")
        raw_list = json.loads(serialized)
        return {
            obj["vector_column"]: EmbeddingFunctionConfig(
                vector_column=obj["vector_column"],
                source_column=obj["source_column"],
                # Look up the registered class and instantiate it from the
                # serialized model parameters.
                function=self.get(obj["name"])(**obj["model"]),
            )
            for obj in raw_list
        }

    def function_to_metadata(self, conf: "EmbeddingFunctionConfig"):
        """
        Convert the given embedding function and source / vector column configs
        into a config dictionary that can be serialized into arrow metadata

        Parameters
        ----------
        conf : EmbeddingFunctionConfig
            Provides ``function``, ``source_column`` and ``vector_column``.

        Returns
        -------
        dict
            A dict with the keys "name", "model", "source_column" and
            "vector_column".
        """
        func = conf.function
        # register() stores the alias on the class even when it is None, so a
        # plain getattr default would never kick in. Fall back to the class
        # name whenever the alias is missing or empty, mirroring the key that
        # register() uses (``alias or cls.__name__``).
        alias = getattr(func, "__embedding_function_registry_alias__", None)
        name = alias or func.__class__.__name__
        json_data = func.safe_model_dump()
        return {
            "name": name,
            "model": json_data,
            "source_column": conf.source_column,
            "vector_column": conf.vector_column,
        }

    def get_table_metadata(self, func_list):
        """
        Convert a list of embedding functions and source / vector configs
        into a config dictionary that can be serialized into arrow metadata

        Returns None when ``func_list`` is None or empty; otherwise a dict
        with the single key "embedding_functions" holding UTF-8 encoded JSON.
        """
        if func_list is None or len(func_list) == 0:
            return None
        json_data = [self.function_to_metadata(func) for func in func_list]
        # Note that metadata dictionary values must be bytes
        # so we need to json dump then utf8 encode
        metadata = json.dumps(json_data, indent=2).encode("utf-8")
        return {"embedding_functions": metadata}
# Global registry instance; this is the object that
# EmbeddingFunctionRegistry.get_instance() returns.
__REGISTRY__ = EmbeddingFunctionRegistry()
# @EmbeddingFunctionRegistry.get_instance().register(name) doesn't work in 3.8
def register(name):
    """
    Build a class decorator that registers an embedding function
    with the global registry under the given name.

    Parameters
    ----------
    name : str
        Passed through as the alias to the registry's ``register`` method.
    """
    registry = __REGISTRY__.get_instance()
    return registry.register(name)
def get_registry() -> EmbeddingFunctionRegistry:
    """
    Convenience accessor for the global embedding function registry.

    Returns
    -------
    EmbeddingFunctionRegistry
        The global registry instance

    Examples
    --------
    from lancedb.embeddings import get_registry
    registry = get_registry()
    openai = registry.get("openai").create()
    """
    registry = __REGISTRY__.get_instance()
    return registry