|
1 | | -from django.contrib.postgres.operations import CreateExtension |
2 | | -from django.contrib.postgres.indexes import PostgresIndex |
3 | | -from django.db.models import Field, FloatField, Func, Value |
4 | | -import numpy as np |
5 | | -from .forms import VectorFormField |
6 | | -from ..utils import from_db, to_db |
7 | | - |
8 | | -__all__ = ['VectorExtension', 'VectorField', 'IvfflatIndex', 'HnswIndex', 'L2Distance', 'MaxInnerProduct', 'CosineDistance'] |
9 | | - |
10 | | - |
11 | | -class VectorExtension(CreateExtension): |
12 | | - def __init__(self): |
13 | | - self.name = 'vector' |
14 | | - |
15 | | - |
16 | | -# https://docs.djangoproject.com/en/4.2/howto/custom-model-fields/ |
17 | | -class VectorField(Field): |
18 | | - description = 'Vector' |
19 | | - empty_strings_allowed = False |
20 | | - |
21 | | - def __init__(self, *args, dimensions=None, **kwargs): |
22 | | - self.dimensions = dimensions |
23 | | - super().__init__(*args, **kwargs) |
24 | | - |
25 | | - def deconstruct(self): |
26 | | - name, path, args, kwargs = super().deconstruct() |
27 | | - if self.dimensions is not None: |
28 | | - kwargs['dimensions'] = self.dimensions |
29 | | - return name, path, args, kwargs |
30 | | - |
31 | | - def db_type(self, connection): |
32 | | - if self.dimensions is None: |
33 | | - return 'vector' |
34 | | - return 'vector(%d)' % self.dimensions |
35 | | - |
36 | | - def from_db_value(self, value, expression, connection): |
37 | | - return from_db(value) |
38 | | - |
39 | | - def to_python(self, value): |
40 | | - if isinstance(value, list): |
41 | | - return np.array(value, dtype=np.float32) |
42 | | - return from_db(value) |
43 | | - |
44 | | - def get_prep_value(self, value): |
45 | | - return to_db(value) |
46 | | - |
47 | | - def value_to_string(self, obj): |
48 | | - return self.get_prep_value(self.value_from_object(obj)) |
49 | | - |
50 | | - def validate(self, value, model_instance): |
51 | | - if isinstance(value, np.ndarray): |
52 | | - value = value.tolist() |
53 | | - super().validate(value, model_instance) |
54 | | - |
55 | | - def run_validators(self, value): |
56 | | - if isinstance(value, np.ndarray): |
57 | | - value = value.tolist() |
58 | | - super().run_validators(value) |
59 | | - |
60 | | - def formfield(self, **kwargs): |
61 | | - return super().formfield(form_class=VectorFormField, **kwargs) |
62 | | - |
63 | | - |
64 | | -class IvfflatIndex(PostgresIndex): |
65 | | - suffix = 'ivfflat' |
66 | | - |
67 | | - def __init__(self, *expressions, lists=None, **kwargs): |
68 | | - self.lists = lists |
69 | | - super().__init__(*expressions, **kwargs) |
70 | | - |
71 | | - def deconstruct(self): |
72 | | - path, args, kwargs = super().deconstruct() |
73 | | - if self.lists is not None: |
74 | | - kwargs['lists'] = self.lists |
75 | | - return path, args, kwargs |
76 | | - |
77 | | - def get_with_params(self): |
78 | | - with_params = [] |
79 | | - if self.lists is not None: |
80 | | - with_params.append('lists = %d' % self.lists) |
81 | | - return with_params |
82 | | - |
83 | | - |
84 | | -class HnswIndex(PostgresIndex): |
85 | | - suffix = 'hnsw' |
86 | | - |
87 | | - def __init__(self, *expressions, m=None, ef_construction=None, **kwargs): |
88 | | - self.m = m |
89 | | - self.ef_construction = ef_construction |
90 | | - super().__init__(*expressions, **kwargs) |
91 | | - |
92 | | - def deconstruct(self): |
93 | | - path, args, kwargs = super().deconstruct() |
94 | | - if self.m is not None: |
95 | | - kwargs['m'] = self.m |
96 | | - if self.ef_construction is not None: |
97 | | - kwargs['ef_construction'] = self.ef_construction |
98 | | - return path, args, kwargs |
99 | | - |
100 | | - def get_with_params(self): |
101 | | - with_params = [] |
102 | | - if self.m is not None: |
103 | | - with_params.append('m = %d' % self.m) |
104 | | - if self.ef_construction is not None: |
105 | | - with_params.append('ef_construction = %d' % self.ef_construction) |
106 | | - return with_params |
107 | | - |
108 | | - |
109 | | -class DistanceBase(Func): |
110 | | - output_field = FloatField() |
111 | | - |
112 | | - def __init__(self, expression, vector, **extra): |
113 | | - if not hasattr(vector, 'resolve_expression'): |
114 | | - vector = Value(to_db(vector)) |
115 | | - super().__init__(expression, vector, **extra) |
116 | | - |
117 | | - |
118 | | -class L2Distance(DistanceBase): |
119 | | - function = '' |
120 | | - arg_joiner = ' <-> ' |
121 | | - |
122 | | - |
123 | | -class MaxInnerProduct(DistanceBase): |
124 | | - function = '' |
125 | | - arg_joiner = ' <#> ' |
126 | | - |
127 | | - |
128 | | -class CosineDistance(DistanceBase): |
129 | | - function = '' |
130 | | - arg_joiner = ' <=> ' |
| 1 | +from .extensions import VectorExtension |
| 2 | +from .functions import L2Distance, MaxInnerProduct, CosineDistance, L1Distance |
| 3 | +from .halfvec import HalfvecField |
| 4 | +from .indexes import IvfflatIndex, HnswIndex |
| 5 | +from .sparsevec import SparsevecField |
| 6 | +from .vector import VectorField |
| 7 | +from ..utils import SparseVec |
| 8 | + |
| 9 | +__all__ = ['VectorExtension', 'VectorField', 'HalfvecField', 'SparsevecField', 'IvfflatIndex', 'HnswIndex', 'L2Distance', 'MaxInnerProduct', 'CosineDistance', 'L1Distance'] |
0 commit comments