Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions docarray/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, List, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

if TYPE_CHECKING:
from docarray import BaseDocument
Expand All @@ -8,21 +8,9 @@ def _is_access_path_valid(doc_type: Type['BaseDocument'], access_path: str) -> b
"""
Check if a given access path ("__"-separated) is a valid path for a given Document class.
"""
from docarray import BaseDocument

field, _, remaining = access_path.partition('__')
if len(remaining) == 0:
return access_path in doc_type.__fields__.keys()
else:
valid_field = field in doc_type.__fields__.keys()
if not valid_field:
return False
else:
d = doc_type._get_field_type(field)
if not issubclass(d, BaseDocument):
return False
else:
return _is_access_path_valid(d, remaining)
field_type = _get_field_type_by_access_path(doc_type, access_path)
return field_type is not None


def _all_access_paths_valid(
Expand Down Expand Up @@ -121,3 +109,32 @@ def _update_nested_dicts(
to_update[k] = v
else:
_update_nested_dicts(to_update[k], update_with[k])


def _get_field_type_by_access_path(
    doc_type: Type['BaseDocument'], access_path: str
) -> Optional[Type]:
    """
    Get field type by "__"-separated access path.

    The path is resolved one segment at a time: the first segment must be a
    field of `doc_type`; if more segments follow, the field's type must be a
    nested Document (recursed into directly) or a DocumentArray (recursed into
    via its `document_type`).

    :param doc_type: type of document
    :param access_path: "__"-separated access path
    :return: field type of accessed attribute. If access path is invalid, return None.
    """
    field, _, remaining = access_path.partition('__')
    if field not in doc_type.__fields__:
        return None

    field_type = doc_type._get_field_type(field)
    if not remaining:
        # last segment of the path: this is the type the caller asked for
        return field_type

    # More segments follow, so `field_type` has to be something we can descend
    # into. Non-class annotations (e.g. Union[...] or List[...]) would make
    # `issubclass` raise a TypeError; such a path cannot be traversed any
    # further, so report it as invalid instead of crashing.
    if not isinstance(field_type, type):
        return None

    # imported lazily, and only once actually needed for the nested lookup
    from docarray import BaseDocument, DocumentArray

    if issubclass(field_type, DocumentArray):
        return _get_field_type_by_access_path(field_type.document_type, remaining)
    if issubclass(field_type, BaseDocument):
        return _get_field_type_by_access_path(field_type, remaining)
    return None
42 changes: 22 additions & 20 deletions docarray/utils/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from docarray.array.array.array import DocumentArray
from docarray.array.stacked.array_stacked import DocumentArrayStacked
from docarray.base_document import BaseDocument
from docarray.helper import _get_field_type_by_access_path
from docarray.typing import AnyTensor
from docarray.typing.tensor.abstract_tensor import AbstractTensor

Expand Down Expand Up @@ -192,8 +193,8 @@ class MyDocument(BaseDocument):
comp_backend = embedding_type.get_comp_backend()

# extract embeddings from query and index
index_embeddings = _extraxt_embeddings(index, embedding_field, embedding_type)
query_embeddings = _extraxt_embeddings(query, embedding_field, embedding_type)
index_embeddings = _extract_embeddings(index, embedding_field, embedding_type)
query_embeddings = _extract_embeddings(query, embedding_field, embedding_type)

# compute distances and return top results
metric_fn = getattr(comp_backend.Metrics, metric)
Expand Down Expand Up @@ -225,7 +226,7 @@ def _extract_embedding_single(
:return: the embeddings
"""
if isinstance(data, BaseDocument):
emb = getattr(data, embedding_field)
emb = next(AnyDocumentArray._traverse(data, embedding_field))
else: # treat data as tensor
emb = data
if len(emb.shape) == 1:
Expand All @@ -235,7 +236,7 @@ def _extract_embedding_single(
return emb


def _extraxt_embeddings(
def _extract_embeddings(
data: Union[AnyDocumentArray, BaseDocument, AnyTensor],
embedding_field: str,
embedding_type: Type,
Expand All @@ -247,40 +248,41 @@ def _extraxt_embeddings(
:param embedding_type: type of the embedding: torch.Tensor, numpy.ndarray etc.
:return: the embeddings
"""

emb: AnyTensor
if isinstance(data, DocumentArray):
emb = getattr(data, embedding_field)
emb = embedding_type._docarray_stack(emb)
elif isinstance(data, DocumentArrayStacked):
emb = getattr(data, embedding_field)
elif isinstance(data, BaseDocument):
emb = getattr(data, embedding_field)
emb_list = list(AnyDocumentArray._traverse(data, embedding_field))
emb = embedding_type._docarray_stack(emb_list)
elif isinstance(data, (DocumentArrayStacked, BaseDocument)):
emb = next(AnyDocumentArray._traverse(data, embedding_field))
else: # treat data as tensor
emb = data
emb = cast(AnyTensor, data)

if len(emb.shape) == 1:
# all currently supported frameworks provide `.reshape()`. Once this is not true
# anymore, we need to add a `.reshape()` method to the computational backend
emb = emb.reshape(1, -1)
emb = emb.get_comp_backend().reshape(array=emb, shape=(1, -1))
return emb


def _da_attr_type(da: AnyDocumentArray, attr: str) -> Type[AnyTensor]:
def _da_attr_type(da: AnyDocumentArray, access_path: str) -> Type[AnyTensor]:
"""Get the type of the attribute according to the Document type
(schema) of the DocumentArray.

:param da: the DocumentArray
:param attr: the attribute name
:param access_path: the "__"-separated access path
:return: the type of the attribute
"""
field_type = da.document_type._get_field_type(attr)
field_type: Optional[Type] = _get_field_type_by_access_path(
da.document_type, access_path
)
if field_type is None:
raise ValueError(f"Access path is not valid: {access_path}")

if is_union_type(field_type):
# determine type based on the first element
field_type = type(getattr(da[0], attr))
field_type = type(next(AnyDocumentArray._traverse(da[0], access_path)))

if not issubclass(field_type, AbstractTensor):
raise ValueError(
f'attribute {attr} is not a tensor-like type, '
f'attribute {access_path} is not a tensor-like type, '
f'but {field_type.__class__.__name__}'
)

Expand Down
10 changes: 8 additions & 2 deletions tests/units/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from docarray import BaseDocument
from docarray import BaseDocument, DocumentArray
from docarray.documents import Image
from docarray.helper import (
_access_path_dict_to_nested_dict,
Expand All @@ -25,8 +25,13 @@ class Middle(BaseDocument):
class Outer(BaseDocument):
img: Optional[Image]
middle: Optional[Middle]
da: DocumentArray[Inner]

doc = Outer(img=Image(), middle=Middle(img=Image(), inner=Inner(img=Image())))
doc = Outer(
img=Image(),
middle=Middle(img=Image(), inner=Inner(img=Image())),
da=DocumentArray[Inner]([Inner(img=Image(url='test.png'))]),
)
return doc


Expand All @@ -35,6 +40,7 @@ def test_is_access_path_valid(nested_doc):
assert _is_access_path_valid(nested_doc.__class__, 'middle__img')
assert _is_access_path_valid(nested_doc.__class__, 'middle__inner__img')
assert _is_access_path_valid(nested_doc.__class__, 'middle')
assert _is_access_path_valid(nested_doc.__class__, 'da__img__url')


def test_is_access_path_not_valid(nested_doc):
Expand Down
30 changes: 30 additions & 0 deletions tests/units/util/test_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,36 @@ class MyDoc(BaseDocument):
assert (torch.stack(sorted(scores, reverse=True)) == scores).all()


@pytest.mark.parametrize('stack', [False, True])
def test_find_nested(stack):
    """
    Run `find` with a "__"-separated embedding path ('inner__embedding') that
    points into a nested Document, on both an unstacked and a stacked index,
    and check that 7 results come back with scores in descending order.
    """

    class InnerDoc(BaseDocument):
        title: str
        embedding: TorchTensor

    class MyDoc(BaseDocument):
        inner: InnerDoc

    query = MyDoc(inner=InnerDoc(title='query', embedding=torch.rand(2)))

    # ten index documents, each carrying its embedding one level down
    index_docs = []
    for i in range(10):
        nested = InnerDoc(title=f'doc {i}', embedding=torch.rand(2))
        index_docs.append(MyDoc(inner=nested))
    index = DocumentArray[MyDoc](index_docs)
    if stack:
        index = index.stack()

    top_k, scores = find(index, query, embedding_field='inner__embedding', limit=7)

    assert len(top_k) == 7
    assert len(scores) == 7
    # scores must already be ordered from highest to lowest
    descending = torch.stack(sorted(scores, reverse=True))
    assert (descending == scores).all()


def test_find_nested_union_optional():
class MyDoc(BaseDocument):
embedding: Union[Optional[TorchTensor], Optional[NdArray]]
Expand Down