Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,33 @@ jobs:
# fail_ci_if_error: false
# token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos

docarray-test-proto3:
needs: [lint-ruff, check-black, import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.7]
steps:
- uses: actions/checkout@v2.5.0
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
poetry install --all-extras
pip install protobuf==3.19.0 # we check that we support 3.19

- name: Test
id: test
run: |
poetry run pytest -k 'proto' tests
timeout-minutes: 30


# just for blocking the merge until all parallel core-test are successful
success-all-test:
needs: [docarray-test, check-mypy]
Expand Down
28 changes: 20 additions & 8 deletions docarray/proto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
from docarray.proto.pb2.docarray_pb2 import (
DocumentArrayProto,
DocumentArrayStackedProto,
DocumentProto,
NdArrayProto,
NodeProto,
UnionArrayProto,
)
from google.protobuf import __version__ as __pb__version__

if __pb__version__.startswith('4'):
from docarray.proto.pb.docarray_pb2 import (
DocumentArrayProto,
DocumentArrayStackedProto,
DocumentProto,
NdArrayProto,
NodeProto,
UnionArrayProto,
)
else:
from docarray.proto.pb2.docarray_pb2 import (
DocumentArrayProto,
DocumentArrayStackedProto,
DocumentProto,
NdArrayProto,
NodeProto,
UnionArrayProto,
)
Comment on lines +3 to +20
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aren't these two exactly the same? or am I blind? 😅

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh i see the difference now, nevermind

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this PR should be named "spot the difference"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nerd version of "where is waldo?"


__all__ = [
'DocumentArrayProto',
Expand Down
48 changes: 48 additions & 0 deletions docarray/proto/pb/docarray_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

116 changes: 113 additions & 3 deletions docarray/proto/pb2/docarray_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ license='Apache 2.0'
python = "^3.7"
pydantic = "^1.10.2"
numpy = "^1.17.3"
protobuf = { version = "^4.21.9", optional = true }
protobuf = { version = ">=3.19.0", optional = true }
torch = { version = "^1.0.0", optional = true }
orjson = "^3.8.2"
pillow = {version = "^9.3.0", optional = true }
Expand Down Expand Up @@ -71,15 +71,15 @@ exclude = 'docarray/proto/pb2/*'


[tool.isort]
skip_glob= ['docarray/proto/pb2/*']
skip_glob= ['docarray/proto/pb2/*', 'docarray/proto/pb/*']

[tool.ruff]
exclude = ['docarray/proto/pb2/*', 'docs/*']
exclude = ['docarray/proto/pb2/*', 'docarray/proto/pb/*', 'docs/*']

[tool.pytest.ini_options]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"internet: marks tests as requiring internet (deselect with '-m \"not internet\"')",
"asyncio: marks that run async tests",

"proto: mark tests that run with proto"
]
3 changes: 3 additions & 0 deletions tests/integrations/document/test_proto.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
import torch

from docarray import BaseDocument, DocumentArray
Expand All @@ -19,6 +20,7 @@
from docarray.typing.tensor import NdArrayEmbedding


@pytest.mark.proto
def test_multi_modal_doc_proto():
class MyMultiModalDoc(BaseDocument):
image: Image
Expand All @@ -31,6 +33,7 @@ class MyMultiModalDoc(BaseDocument):
MyMultiModalDoc.from_protobuf(doc.to_protobuf())


@pytest.mark.proto
def test_all_types():
class NestedDoc(BaseDocument):
tensor: NdArray
Expand Down
5 changes: 5 additions & 0 deletions tests/units/array/test_array_proto.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import numpy as np
import pytest

from docarray import BaseDocument, DocumentArray
from docarray.array.array_stacked import DocumentArrayStacked
from docarray.documents import Image, Text
from docarray.typing import NdArray


@pytest.mark.proto
def test_simple_proto():
class CustomDoc(BaseDocument):
text: str
Expand All @@ -22,6 +24,7 @@ class CustomDoc(BaseDocument):
assert (doc1.tensor == doc2.tensor).all()


@pytest.mark.proto
def test_nested_proto():
class CustomDocument(BaseDocument):
text: Text
Expand All @@ -39,6 +42,7 @@ class CustomDocument(BaseDocument):
DocumentArray[CustomDocument].from_protobuf(da.to_protobuf())


@pytest.mark.proto
def test_nested_proto_any_doc():
class CustomDocument(BaseDocument):
text: Text
Expand All @@ -56,6 +60,7 @@ class CustomDocument(BaseDocument):
DocumentArray.from_protobuf(da.to_protobuf())


@pytest.mark.proto
def test_stacked_proto():
class CustomDocument(BaseDocument):
image: NdArray
Expand Down
6 changes: 2 additions & 4 deletions tests/units/array/test_array_stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,12 @@ def test_iterator(batch):


def test_stack_setter(batch):

batch.tensor = torch.ones(10, 3, 224, 224)

assert (batch.tensor == torch.ones(10, 3, 224, 224)).all()


def test_stack_optional(batch):

assert (batch._columns['tensor'] == torch.zeros(10, 3, 224, 224)).all()
assert (batch.tensor == torch.zeros(10, 3, 224, 224)).all()

Expand All @@ -67,7 +65,6 @@ class Image(BaseDocument):


def test_stack(batch):

assert (batch._columns['tensor'] == torch.zeros(10, 3, 224, 224)).all()
assert (batch.tensor == torch.zeros(10, 3, 224, 224)).all()
assert batch._columns['tensor'].data_ptr() == batch.tensor.data_ptr()
Expand Down Expand Up @@ -138,11 +135,12 @@ class MMdoc(BaseDocument):
assert (doc.img.tensor == torch.zeros(3, 224, 224)).all()


@pytest.mark.proto
def test_proto_stacked_mode_torch(batch):

batch.from_protobuf(batch.to_protobuf())


@pytest.mark.proto
def test_proto_stacked_mode_numpy():
class MyDoc(BaseDocument):
tensor: NdArray[3, 224, 224]
Expand Down
Loading