forked from lancedb/lancedb
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcohere.py
More file actions
146 lines (119 loc) · 5.2 KB
/
cohere.py
File metadata and controls
146 lines (119 loc) · 5.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright (c) 2023. LanceDB Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import ClassVar, List, Union
import numpy as np
from ..util import attempt_import_or_raise
from .base import TextEmbeddingFunction
from .registry import register
from .utils import api_key_not_found_help, TEXT
@register("cohere")
class CohereEmbeddingFunction(TextEmbeddingFunction):
    """
    An embedding function that uses the Cohere API

    https://docs.cohere.com/docs/multilingual-language-models

    The API key is read from the ``COHERE_API_KEY`` environment variable the
    first time embeddings are generated. The resulting client is stored on
    the class, so it is shared by every instance.

    Parameters
    ----------
    name: str, default "embed-multilingual-v2.0"
        The name of the model to use. List of acceptable models:

            * embed-english-v3.0
            * embed-multilingual-v3.0
            * embed-english-light-v3.0
            * embed-multilingual-light-v3.0
            * embed-english-v2.0
            * embed-english-light-v2.0
            * embed-multilingual-v2.0

    source_input_type: str, default "search_document"
        The input type sent to Cohere when embedding the source column
    query_input_type: str, default "search_query"
        The input type sent to Cohere when embedding a query

    Cohere supports following input types:

    | Input Type              | Description                           |
    |-------------------------|---------------------------------------|
    | "`search_document`"     | Used for embeddings stored in a vector|
    |                         | database for search use-cases.        |
    | "`search_query`"        | Used for embeddings of search queries |
    |                         | run against a vector DB               |
    | "`semantic_similarity`" | Specifies the given text will be used |
    |                         | for Semantic Textual Similarity (STS) |
    | "`classification`"      | Used for embeddings passed through a  |
    |                         | text classifier.                      |
    | "`clustering`"          | Used for the embeddings run through a |
    |                         | clustering algorithm                  |

    Examples
    --------
        import lancedb
        from lancedb.pydantic import LanceModel, Vector
        from lancedb.embeddings import EmbeddingFunctionRegistry

        cohere = (
            EmbeddingFunctionRegistry.get_instance()
            .get("cohere")
            .create(name="embed-multilingual-v2.0")
        )

        class TextModel(LanceModel):
            text: str = cohere.SourceField()
            vector: Vector(cohere.ndims()) = cohere.VectorField()

        data = [{"text": "hello world"}, {"text": "goodbye world"}]

        db = lancedb.connect("~/.lancedb")
        tbl = db.create_table("test", schema=TextModel, mode="overwrite")
        tbl.add(data)
    """

    name: str = "embed-multilingual-v2.0"
    source_input_type: str = "search_document"
    query_input_type: str = "search_query"
    # Lazily created Cohere client, stored on the class (see `_init_client`).
    client: ClassVar = None

    def ndims(self):
        """
        Return the embedding dimension of the configured model.

        Raises
        ------
        ValueError
            If `self.name` is not one of the models listed below.
        """
        # TODO: fix hardcoding
        dims_by_model = {
            "embed-english-v3.0": 1024,
            "embed-multilingual-v3.0": 1024,
            "embed-english-light-v2.0": 1024,
            "embed-english-light-v3.0": 384,
            "embed-multilingual-light-v3.0": 384,
            "embed-english-v2.0": 4096,
            "embed-multilingual-v2.0": 768,
        }
        dims = dims_by_model.get(self.name)
        if dims is None:
            raise ValueError(f"Model {self.name} not supported")
        return dims

    def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
        """
        Embed a query using `query_input_type`.

        Extra positional and keyword arguments are accepted but not
        forwarded; the input type is always `self.query_input_type`.
        """
        return self.compute_source_embeddings(query, input_type=self.query_input_type)

    def compute_source_embeddings(self, texts: TEXT, *args, **kwargs) -> List[np.array]:
        """
        Embed source texts.

        Parameters
        ----------
        texts: TEXT
            The texts to embed; passed through the inherited
            `sanitize_input` before being sent to Cohere.
        **kwargs
            Only `input_type` is read. When it is missing or falsy,
            `self.source_input_type` is used instead.
        """
        texts = self.sanitize_input(texts)
        # `compute_query_embeddings` passes its own input type; any other
        # caller is assumed to be embedding source data.
        input_type = kwargs.get("input_type") or self.source_input_type
        return self.generate_embeddings(texts, input_type=input_type)

    def generate_embeddings(
        self, texts: Union[List[str], np.ndarray], *args, **kwargs
    ) -> List[np.array]:
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        **kwargs
            Forwarded unchanged to the Cohere client's `embed` call
            (e.g. `input_type`). Positional `*args` are ignored.

        Returns
        -------
        A list holding the items of the response's `embeddings` attribute.
        """
        self._init_client()
        rs = CohereEmbeddingFunction.client.embed(
            texts=texts, model=self.name, **kwargs
        )
        return list(rs.embeddings)

    def _init_client(self):
        """
        Create the shared Cohere client on first use.

        Does nothing when the class-level `client` is already set. Otherwise
        the `cohere` package is imported via `attempt_import_or_raise` and a
        client is built from the `COHERE_API_KEY` environment variable. When
        the variable is unset, `api_key_not_found_help("cohere")` is called
        first (presumably it raises with setup instructions; confirm in
        `.utils`).
        """
        if CohereEmbeddingFunction.client is not None:
            return
        # Only resolve the optional dependency when a client must be built.
        cohere = attempt_import_or_raise("cohere")
        if "COHERE_API_KEY" not in os.environ:
            api_key_not_found_help("cohere")
        CohereEmbeddingFunction.client = cohere.Client(os.environ["COHERE_API_KEY"])