Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions docarray/base_doc/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
'set[int] | set[str] | dict[int, Any] | dict[str, Any] | None'
)

from pydantic import ConfigDict


_console: Console = Console()

Expand All @@ -71,10 +73,14 @@ class BaseDocWithoutId(BaseModel, IOMixin, UpdateMixin, BaseNode):

if is_pydantic_v2:

class Config:
validate_assignment = True
_load_extra_fields_from_protobuf = False
json_encoders = {AbstractTensor: lambda x: x}
class ConfigDocArray(ConfigDict):
_load_extra_fields_from_protobuf: bool

model_config = ConfigDocArray(
validate_assignment=True,
_load_extra_fields_from_protobuf=False,
json_encoders={AbstractTensor: lambda x: x},
)

else:

Expand Down
8 changes: 6 additions & 2 deletions docarray/base_doc/mixins/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,14 @@ def from_protobuf(cls: Type[T], pb_msg: 'DocProto') -> T:
"""

fields: Dict[str, Any] = {}

load_extra_field = (
cls.model_config['_load_extra_fields_from_protobuf']
if is_pydantic_v2
else cls.Config._load_extra_fields_from_protobuf
)
for field_name in pb_msg.data:
if (
not (cls.Config._load_extra_fields_from_protobuf)
not (load_extra_field)
and field_name not in cls._docarray_fields().keys()
):
continue # optimization we don't even load the data if the key does not
Expand Down
31 changes: 23 additions & 8 deletions docs/user_guide/representing/first_step.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,22 +119,37 @@ This representation can be used to [send](../sending/first_step.md) or [store](.

## Setting a Pydantic `Config` class

Documents support setting a `Config` [like any other Pydantic `BaseModel`](https://docs.pydantic.dev/latest/usage/model_config/).
Documents support setting a custom `configuration` [like any other Pydantic `BaseModel`](https://docs.pydantic.dev/latest/api/config/).

However, if you set a config, you should inherit from the `BaseDoc` config class:
Here is an example to extend the Config of a Document dependong on which version of Pydantic you are using.

```python
from docarray import BaseDoc


class MyDoc(BaseDoc):
class Config(BaseDoc.Config):
arbitrary_types_allowed = True # just an example setting
```
=== "Pydantic v1"
```python
from docarray import BaseDoc


class MyDoc(BaseDoc):
class Config(BaseDoc.Config):
arbitrary_types_allowed = True # just an example setting
```

=== "Pydantic v2"
```python
from docarray import BaseDoc


class MyDoc(BaseDoc):
model_config = BaseDoc.ConfigDocArray.ConfigDict(
arbitrary_types_allowed=True
) # just an example setting
```

See also:

* The [next part](./array.md) of the representing section
* API reference for the [BaseDoc][docarray.base_doc.doc.BaseDoc] class
* The [Storing](../storing/first_step.md) section on how to store your data
* The [Sending](../sending/first_step.md) section on how to send your data

22 changes: 0 additions & 22 deletions tests/units/document/test_any_document.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,10 @@

import numpy as np
import pytest
from orjson import orjson

from docarray import DocList
from docarray.base_doc import AnyDoc, BaseDoc
from docarray.base_doc.io.json import orjson_dumps_and_decode
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.pydantic import is_pydantic_v2


def test_any_doc():
Expand Down Expand Up @@ -94,21 +90,3 @@ class DocTest(BaseDoc):
assert isinstance(d.ld[0], dict)
assert d.ld[0]['text'] == 'I am inner'
assert d.ld[0]['t'] == {'a': 'b'}


@pytest.mark.skipif(is_pydantic_v2, reason="Not working with pydantic v2 for now")
def test_subclass_config():
class MyDoc(BaseDoc):
x: str

class Config(BaseDoc.Config):
arbitrary_types_allowed = True # just an example setting

assert MyDoc.Config.json_loads == orjson.loads
assert MyDoc.Config.json_dumps == orjson_dumps_and_decode
assert (
MyDoc.Config.json_encoders[AbstractTensor](3) == 3
) # dirty check that it is identity
assert MyDoc.Config.validate_assignment
assert not MyDoc.Config._load_extra_fields_from_protobuf
assert MyDoc.Config.arbitrary_types_allowed
41 changes: 41 additions & 0 deletions tests/units/document/test_base_document.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from typing import Any, List, Optional, Tuple

import numpy as np
import orjson
import pytest

from docarray import DocList, DocVec
from docarray.base_doc.doc import BaseDoc
from docarray.base_doc.io.json import orjson_dumps_and_decode
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal.pydantic import is_pydantic_v2


def test_base_document_init():
Expand Down Expand Up @@ -146,3 +150,40 @@ class MyDoc(BaseDoc):
field_type = MyDoc._get_field_inner_type("tuple_")

assert field_type == Any


@pytest.mark.skipif(
is_pydantic_v2, reason="syntax only working with pydantic v1 for now"
)
def test_subclass_config():
class MyDoc(BaseDoc):
x: str

class Config(BaseDoc.Config):
arbitrary_types_allowed = True # just an example setting

assert MyDoc.Config.json_loads == orjson.loads
assert MyDoc.Config.json_dumps == orjson_dumps_and_decode
assert (
MyDoc.Config.json_encoders[AbstractTensor](3) == 3
) # dirty check that it is identity
assert MyDoc.Config.validate_assignment
assert not MyDoc.Config._load_extra_fields_from_protobuf
assert MyDoc.Config.arbitrary_types_allowed


@pytest.mark.skipif(not (is_pydantic_v2), reason="syntax only working with pydantic v2")
def test_subclass_config_v2():
class MyDoc(BaseDoc):
x: str

model_config = BaseDoc.ConfigDocArray(
arbitrary_types_allowed=True
) # just an example setting

assert (
MyDoc.model_config['json_encoders'][AbstractTensor](3) == 3
) # dirty check that it is identity
assert MyDoc.model_config['validate_assignment']
assert not MyDoc.model_config['_load_extra_fields_from_protobuf']
assert MyDoc.model_config['arbitrary_types_allowed']