Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions docarray/documents/audio.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union

import numpy as np

from pydantic import Field

from docarray.base_doc import BaseDoc
Expand All @@ -10,6 +9,10 @@
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.audio.audio_tensor import AudioTensor
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand Down Expand Up @@ -121,17 +124,30 @@ class MultiModalDoc(BaseDoc):
)

@classmethod
def _validate(cls, value) -> Any:
    """Turn shorthand inputs into a field mapping before model validation.

    :param value: the raw value handed to the validator.
    :return: ``{'url': value}`` for a string, ``{'tensor': value}`` for an
        ``AbstractTensor``, ``np.ndarray``, ``torch.Tensor`` or ``tf.Tensor``;
        any other value is returned unchanged, hence the ``Any`` return type.
    """
    if isinstance(value, str):
        return {'url': value}
    # torch / tf may be None here, so guard before touching their attributes.
    if (
        isinstance(value, (AbstractTensor, np.ndarray))
        or (torch is not None and isinstance(value, torch.Tensor))
        or (tf is not None and isinstance(value, tf.Tensor))
    ):
        return {'tensor': value}
    return value

if is_pydantic_v2:

    @model_validator(mode='before')
    @classmethod
    def validate_model_before(cls, value):
        """Pydantic v2 'before' hook: normalise the raw input via ``_validate``."""
        return cls._validate(value)

else:

    @classmethod
    def validate(
        cls: Type[T],
        value: Union[str, AbstractTensor, Any],
    ) -> T:
        """Normalise the raw input via ``_validate``, then defer to the parent."""
        return super().validate(cls._validate(value))
35 changes: 25 additions & 10 deletions docarray/documents/image.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union

import numpy as np

from pydantic import Field

from docarray.base_doc import BaseDoc
from docarray.typing import AnyEmbedding, ImageBytes, ImageUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.typing.tensor.image.image_tensor import ImageTensor
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand Down Expand Up @@ -115,19 +117,32 @@ class MultiModalDoc(BaseDoc):
)

@classmethod
def _validate(cls, value) -> Any:
    """Turn shorthand inputs into a field mapping before model validation.

    :param value: the raw value handed to the validator.
    :return: ``{'url': value}`` for a string, ``{'tensor': value}`` for an
        ``AbstractTensor``, ``np.ndarray``, ``torch.Tensor`` or ``tf.Tensor``,
        ``{'byte': value}`` for ``bytes``; any other value is returned
        unchanged, hence the ``Any`` return type.
    """
    if isinstance(value, str):
        return {'url': value}
    # torch / tf may be None here, so guard before touching their attributes.
    if (
        isinstance(value, (AbstractTensor, np.ndarray))
        or (torch is not None and isinstance(value, torch.Tensor))
        or (tf is not None and isinstance(value, tf.Tensor))
    ):
        return {'tensor': value}
    if isinstance(value, bytes):
        return {'byte': value}
    return value

if is_pydantic_v2:

    @model_validator(mode='before')
    @classmethod
    def validate_model_before(cls, value):
        """Pydantic v2 'before' hook: normalise the raw input via ``_validate``."""
        return cls._validate(value)

else:

    @classmethod
    def validate(
        cls: Type[T],
        value: Union[str, AbstractTensor, Any],
    ) -> T:
        """Normalise the raw input via ``_validate``, then defer to the parent."""
        return super().validate(cls._validate(value))
30 changes: 22 additions & 8 deletions docarray/documents/mesh/mesh_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from docarray.documents.mesh.vertices_and_faces import VerticesAndFaces
from docarray.typing.tensor.embedding import AnyEmbedding
from docarray.typing.url.url_3d.mesh_url import Mesh3DUrl
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

T = TypeVar('T', bound='Mesh3D')

Expand Down Expand Up @@ -125,11 +128,22 @@ class MultiModalDoc(BaseDoc):
default=None,
)

@classmethod
def validate(
cls: Type[T],
value: Union[str, Any],
) -> T:
if isinstance(value, str):
value = cls(url=value)
return super().validate(value)
if is_pydantic_v2:

    @model_validator(mode='before')
    @classmethod
    def validate_model_before(cls, value):
        """Pydantic v2 'before' hook: map a bare string onto the ``url`` field."""
        return {'url': value} if isinstance(value, str) else value

else:

    @classmethod
    def validate(
        cls: Type[T],
        value: Union[str, Any],
    ) -> T:
        """Build the document from a bare string URL, then defer to the parent."""
        candidate = cls(url=value) if isinstance(value, str) else value
        return super().validate(candidate)
32 changes: 24 additions & 8 deletions docarray/documents/point_cloud/point_cloud_3d.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union

import numpy as np

from pydantic import Field

from docarray.base_doc import BaseDoc
from docarray.documents.point_cloud.points_and_colors import PointsAndColors
from docarray.typing import AnyEmbedding, PointCloud3DUrl
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand Down Expand Up @@ -130,17 +133,30 @@ class MultiModalDoc(BaseDoc):
)

@classmethod
def _validate(cls, value: Union[str, AbstractTensor, Any]) -> Any:
    """Turn shorthand inputs into a field mapping before model validation.

    :param value: the raw value handed to the validator.
    :return: ``{'url': value}`` for a string; for an ``AbstractTensor``,
        ``np.ndarray``, ``torch.Tensor`` or ``tf.Tensor`` a mapping whose
        ``tensors`` entry is ``PointsAndColors(points=value)``; any other
        value is returned unchanged.
    """
    # First parameter is named ``cls`` because this is a classmethod.
    if isinstance(value, str):
        return {'url': value}
    # torch / tf may be None here, so guard before touching their attributes.
    if (
        isinstance(value, (AbstractTensor, np.ndarray))
        or (torch is not None and isinstance(value, torch.Tensor))
        or (tf is not None and isinstance(value, tf.Tensor))
    ):
        return {'tensors': PointsAndColors(points=value)}
    return value

if is_pydantic_v2:

    @model_validator(mode='before')
    @classmethod
    def validate_model_before(cls, value):
        """Pydantic v2 'before' hook: normalise the raw input via ``_validate``."""
        return cls._validate(value)

else:

    @classmethod
    def validate(
        cls: Type[T],
        value: Union[str, AbstractTensor, Any],
    ) -> T:
        """Normalise the raw input via ``_validate``, then defer to the parent."""
        return super().validate(cls._validate(value))
32 changes: 24 additions & 8 deletions docarray/documents/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
from docarray.base_doc import BaseDoc
from docarray.typing import TextUrl
from docarray.typing.tensor.embedding import AnyEmbedding
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

T = TypeVar('T', bound='TextDoc')

Expand Down Expand Up @@ -129,14 +133,26 @@ def __init__(self, text: Optional[str] = None, **kwargs):
kwargs['text'] = text
super().__init__(**kwargs)

@classmethod
def validate(
cls: Type[T],
value: Union[str, Any],
) -> T:
if isinstance(value, str):
value = cls(text=value)
return super().validate(value)
if is_pydantic_v2:

    @model_validator(mode='before')
    @classmethod
    def validate_model_before(cls, values):
        """Pydantic v2 'before' hook: map a bare string onto the ``text`` field."""
        return {'text': values} if isinstance(values, str) else values

else:

    @classmethod
    def validate(
        cls: Type[T],
        value: Union[str, Any],
    ) -> T:
        """Build the document from a bare string, then defer to the parent."""
        candidate = cls(text=value) if isinstance(value, str) else value
        return super().validate(candidate)

def __eq__(self, other: Any) -> bool:
if isinstance(other, str):
Expand Down
34 changes: 25 additions & 9 deletions docarray/documents/video.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, TypeVar, Union

import numpy as np

from pydantic import Field

from docarray.base_doc import BaseDoc
Expand All @@ -11,6 +10,10 @@
from docarray.typing.tensor.video.video_tensor import VideoTensor
from docarray.typing.url.video_url import VideoUrl
from docarray.utils._internal.misc import import_library
from docarray.utils._internal.pydantic import is_pydantic_v2

if is_pydantic_v2:
from pydantic import model_validator

if TYPE_CHECKING:
import tensorflow as tf # type: ignore
Expand Down Expand Up @@ -131,17 +134,30 @@ class MultiModalDoc(BaseDoc):
)

@classmethod
def _validate(cls, value) -> Any:
    """Turn shorthand inputs into a field mapping before model validation.

    :param value: the raw value handed to the validator.
    :return: ``{'url': value}`` for a string, ``{'tensor': value}`` for an
        ``AbstractTensor``, ``np.ndarray``, ``torch.Tensor`` or ``tf.Tensor``;
        any other value is returned unchanged, hence the ``Any`` return type.
    """
    if isinstance(value, str):
        return {'url': value}
    # torch / tf may be None here, so guard before touching their attributes.
    if (
        isinstance(value, (AbstractTensor, np.ndarray))
        or (torch is not None and isinstance(value, torch.Tensor))
        or (tf is not None and isinstance(value, tf.Tensor))
    ):
        return {'tensor': value}
    return value

if is_pydantic_v2:

    @model_validator(mode='before')
    @classmethod
    def validate_model_before(cls, value):
        """Pydantic v2 'before' hook: normalise the raw input via ``_validate``."""
        return cls._validate(value)

else:

    @classmethod
    def validate(
        cls: Type[T],
        value: Union[str, AbstractTensor, Any],
    ) -> T:
        """Normalise the raw input via ``_validate``, then defer to the parent."""
        return super().validate(cls._validate(value))
1 change: 0 additions & 1 deletion docarray/typing/tensor/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def _docarray_validate(
return cls._docarray_from_native(arr)
except Exception:
pass # handled below
breakpoint()
raise ValueError(f'Expected a numpy.ndarray compatible type, got {type(value)}')

@classmethod
Expand Down
6 changes: 0 additions & 6 deletions tests/integrations/predefined_document/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from docarray.typing import AudioUrl
from docarray.typing.tensor.audio import AudioNdArray, AudioTorchTensor
from docarray.utils._internal.misc import is_tf_available
from docarray.utils._internal.pydantic import is_pydantic_v2
from tests import TOYDATA_DIR

tf_available = is_tf_available()
Expand Down Expand Up @@ -184,32 +183,27 @@ class MyAudio(AudioDoc):


# Validating predefined docs against a URL or tensor is not yet working with pydantic v2
@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_audio_np():
    """Parsing a NumPy array as ``AudioDoc`` should expose it as ``tensor``."""
    shape = (10, 10, 3)
    doc = parse_obj_as(AudioDoc, np.zeros(shape))
    assert (doc.tensor == np.zeros(shape)).all()


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_audio_torch():
    """Parsing a torch tensor as ``AudioDoc`` should expose it as ``tensor``."""
    shape = (10, 10, 3)
    doc = parse_obj_as(AudioDoc, torch.zeros(*shape))
    assert (doc.tensor == torch.zeros(*shape)).all()


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
@pytest.mark.tensorflow
def test_audio_tensorflow():
    """Parsing a TensorFlow tensor as ``AudioDoc`` should keep its values."""
    shape = (10, 10, 3)
    doc = parse_obj_as(AudioDoc, tf.zeros(shape))
    assert tnp.allclose(doc.tensor.tensor, tf.zeros(shape))


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_audio_bytes():
    """Smoke test: the parsed tensor's bytes can be assigned to ``bytes_``."""
    doc = parse_obj_as(AudioDoc, torch.zeros(10, 10, 3))
    raw = doc.tensor.to_bytes()
    doc.bytes_ = raw


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_audio_shortcut_doc():
class MyDoc(BaseDoc):
audio: AudioDoc
Expand Down
Loading