forked from lancedb/lancedb
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase.py
More file actions
240 lines (205 loc) · 8.17 KB
/
base.py
File metadata and controls
240 lines (205 loc) · 8.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from packaging.version import Version
from typing import Union, List, TYPE_CHECKING
import numpy as np
import pyarrow as pa
if TYPE_CHECKING:
    # Only needed for the string annotation in `Reranker.rerank_multivector`;
    # importing it at runtime would create an import cycle.
    from ..table import LanceVectorQueryBuilder

# Parsed once at import time so rerankers can choose version-dependent
# `pa.concat_tables` keyword arguments without re-parsing the version per query.
ARROW_VERSION = Version(pa.__version__)
class Reranker(ABC):
    def __init__(self, return_score: str = "relevance"):
        """
        Interface for a reranker. A reranker is used to rerank the results from a
        vector and FTS search. This is useful for combining the results from both
        search methods.

        Parameters
        ----------
        return_score : str, default "relevance"
            Options are "relevance" or "all".
            The type of score to return. If "relevance", will return only the
            relevance score. If "all", will return all scores from the vector and
            FTS search along with the relevance score.

        Raises
        ------
        ValueError
            If `return_score` is not "relevance" or "all".
        """
        if return_score not in ["relevance", "all"]:
            raise ValueError("score must be either 'relevance' or 'all'")
        self.score = return_score
        # Set the merge args based on the arrow version here to avoid checking it at
        # each query: pyarrow <= 13 takes the boolean `promote` flag, later
        # versions take `promote_options`.
        self._concat_tables_args = {"promote_options": "default"}
        if ARROW_VERSION.major <= 13:
            self._concat_tables_args = {"promote": True}

    def rerank_vector(
        self,
        query: str,
        vector_results: pa.Table,
    ) -> pa.Table:
        """
        Rerank function receives the result from the vector search.
        This isn't mandatory to implement.

        Parameters
        ----------
        query : str
            The input query
        vector_results : pa.Table
            The results from the vector search

        Returns
        -------
        pa.Table
            The reranked results

        Raises
        ------
        NotImplementedError
            In this base implementation; subclasses override it.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement rerank_vector"
        )

    def rerank_fts(
        self,
        query: str,
        fts_results: pa.Table,
    ) -> pa.Table:
        """
        Rerank function receives the result from the FTS search.
        This isn't mandatory to implement.

        Parameters
        ----------
        query : str
            The input query
        fts_results : pa.Table
            The results from the FTS search

        Returns
        -------
        pa.Table
            The reranked results

        Raises
        ------
        NotImplementedError
            In this base implementation; subclasses override it.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement rerank_fts"
        )

    @abstractmethod
    def rerank_hybrid(
        self,
        query: str,
        vector_results: pa.Table,
        fts_results: pa.Table,
    ) -> pa.Table:
        """
        Rerank function receives the individual results from the vector and FTS search
        results. You can choose to use any of the results to generate the final results,
        allowing maximum flexibility. This is mandatory to implement.

        Parameters
        ----------
        query : str
            The input query
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search

        Returns
        -------
        pa.Table
            The reranked results
        """
        pass

    def merge_results(self, vector_results: pa.Table, fts_results: pa.Table) -> pa.Table:
        """
        Merge the results from the vector and FTS search. This is a vanilla merging
        function that just concatenates the results and removes the duplicates.

        NOTE: This doesn't take score into account. It'll keep the instance that was
        encountered first. This is designed for rerankers that don't use the score.
        In case you want to use the score, or support `return_score="all"`, you'll
        have to implement your own merging function.

        Parameters
        ----------
        vector_results : pa.Table
            The results from the vector search
        fts_results : pa.Table
            The results from the FTS search

        Returns
        -------
        pa.Table
            The concatenated results with duplicate `_rowid` values removed.
            Both inputs must therefore carry a `_rowid` column.
        """
        combined = pa.concat_tables(
            [vector_results, fts_results], **self._concat_tables_args
        )
        # deduplicate (keeps the vector-search row when an id appears in both)
        return self._deduplicate(combined)

    def rerank_multivector(
        self,
        vector_results: Union[List[pa.Table], List["LanceVectorQueryBuilder"]],
        query: Union[str, None],  # Some rerankers might not need the query
        deduplicate: bool = False,
    ) -> pa.Table:
        """
        This is a rerank function that receives the results from multiple
        vector searches. For example, this can be used to combine the
        results of two vector searches with different embeddings.

        Parameters
        ----------
        vector_results : List[pa.Table] or List[LanceVectorQueryBuilder]
            The results from the vector search. Either accepts the query builder
            if the results haven't been executed yet or the results in arrow format.
            A single (non-list) value is treated as a one-element list.
        query : str or None,
            The input query. Some rerankers might not need the query to rerank.
            In that case, it can be set to None explicitly. This is intended to
            be handled by the reranker implementations.
        deduplicate : bool, optional
            Whether to deduplicate the results based on the `_rowid` column,
            by default False. Requires `_rowid` to be present in the results.

        Returns
        -------
        pa.Table
            The reranked results

        Raises
        ------
        ValueError
            If `vector_results` is empty, mixes element types, contains
            unsupported types, or `deduplicate` is set but `_rowid` is missing.
        """
        if not isinstance(vector_results, list):
            vector_results = [vector_results]
        if not vector_results:
            raise ValueError("vector_results should not be empty")

        first_type = type(vector_results[0])
        # Make sure all elements are of the same type
        if not all(isinstance(v, first_type) for v in vector_results):
            raise ValueError(
                "All elements in vector_results should be of the same type"
            )

        # Compare by class name instead of isinstance: avoids a circular import
        if first_type.__name__ == "LanceVectorQueryBuilder":
            vector_results = [result.to_arrow() for result in vector_results]
        elif not isinstance(vector_results[0], pa.Table):
            raise ValueError(
                "vector_results should be a list of pa.Table or LanceVectorQueryBuilder"
            )

        combined = pa.concat_tables(vector_results, **self._concat_tables_args)

        # Validate up front so a missing `_rowid` fails before the (possibly
        # expensive) rerank call rather than after it.
        if deduplicate and "_rowid" not in combined.column_names:
            raise ValueError(
                "'_rowid' is required for deduplication. "
                "add _rowid to search results like this: "
                "`search().with_row_id(True)`"
            )

        reranked = self.rerank_vector(query, combined)

        # TODO: Allow custom deduplicators here.
        # currently, this'll just keep the first instance.
        if deduplicate:
            reranked = self._deduplicate(reranked)
        return reranked

    def _deduplicate(self, table: pa.Table) -> pa.Table:
        """
        Deduplicate the table based on the `_rowid` column.

        Keeps the first occurrence of each `_rowid` and preserves the original
        row order of the kept rows.
        """
        row_id = table.column("_rowid")
        # np.unique(return_index=True) yields the index of the first occurrence
        # of each distinct value; only those rows are marked as kept.
        mask = np.full(table.num_rows, False)
        _, first_indices = np.unique(np.array(row_id), return_index=True)
        mask[first_indices] = True
        return table.filter(mask=mask)

    def _keep_relevance_score(self, combined_results: pa.Table) -> pa.Table:
        """
        Drop the `_score` and `_distance` columns (whichever are present) when
        this reranker was configured with `return_score="relevance"`; otherwise
        return the table unchanged.
        """
        if self.score == "relevance":
            to_drop = [
                name
                for name in ("_score", "_distance")
                if name in combined_results.column_names
            ]
            if to_drop:
                combined_results = combined_results.drop_columns(to_drop)
        return combined_results