Skip to content

Commit ebe0ede

Browse files
authored
feat: support proto 3 (#1078)
* feat: allow to source for pb
  Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
* feat: add proto 3 files
  Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
* chore: add ci run for proto 3
  Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
* chore: remove upper limit on proto
  Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
---------
Signed-off-by: samsja <sami.jaghouar@hotmail.fr>
1 parent 37b6d3e commit ebe0ede

File tree

13 files changed

+236
-19
lines changed

13 files changed

+236
-19
lines changed

.github/workflows/ci.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,33 @@ jobs:
133133
# fail_ci_if_error: false
134134
# token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos
135135

136+
docarray-test-proto3:
137+
needs: [lint-ruff, check-black, import-test]
138+
runs-on: ubuntu-latest
139+
strategy:
140+
fail-fast: false
141+
matrix:
142+
python-version: [3.7]
143+
steps:
144+
- uses: actions/checkout@v2.5.0
145+
- name: Set up Python ${{ matrix.python-version }}
146+
uses: actions/setup-python@v4
147+
with:
148+
python-version: ${{ matrix.python-version }}
149+
- name: Prepare environment
150+
run: |
151+
python -m pip install --upgrade pip
152+
python -m pip install poetry
153+
poetry install --all-extras
154+
poetry run pip install protobuf==3.19.0 # pin inside Poetry's virtualenv so `poetry run pytest` really exercises protobuf 3.19
155+
156+
- name: Test
157+
id: test
158+
run: |
159+
poetry run pytest -k 'proto' tests
160+
timeout-minutes: 30
161+
162+
136163
# just for blocking the merge until all parallel core-test are successful
137164
success-all-test:
138165
needs: [docarray-test, check-mypy]

docarray/proto/__init__.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
1-
from docarray.proto.pb2.docarray_pb2 import (
2-
DocumentArrayProto,
3-
DocumentArrayStackedProto,
4-
DocumentProto,
5-
NdArrayProto,
6-
NodeProto,
7-
UnionArrayProto,
8-
)
1+
from google.protobuf import __version__ as __pb__version__
2+
3+
if __pb__version__.startswith('4'):
4+
from docarray.proto.pb.docarray_pb2 import (
5+
DocumentArrayProto,
6+
DocumentArrayStackedProto,
7+
DocumentProto,
8+
NdArrayProto,
9+
NodeProto,
10+
UnionArrayProto,
11+
)
12+
else:
13+
from docarray.proto.pb2.docarray_pb2 import (
14+
DocumentArrayProto,
15+
DocumentArrayStackedProto,
16+
DocumentProto,
17+
NdArrayProto,
18+
NodeProto,
19+
UnionArrayProto,
20+
)
921

1022
__all__ = [
1123
'DocumentArrayProto',

docarray/proto/pb/docarray_pb2.py

Lines changed: 48 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docarray/proto/pb2/docarray_pb2.py

Lines changed: 113 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ license='Apache 2.0'
99
python = "^3.7"
1010
pydantic = "^1.10.2"
1111
numpy = "^1.17.3"
12-
protobuf = { version = "^4.21.9", optional = true }
12+
protobuf = { version = ">=3.19.0", optional = true }
1313
torch = { version = "^1.0.0", optional = true }
1414
orjson = "^3.8.2"
1515
pillow = {version = "^9.3.0", optional = true }
@@ -71,15 +71,15 @@ exclude = 'docarray/proto/pb2/*'
7171

7272

7373
[tool.isort]
74-
skip_glob= ['docarray/proto/pb2/*']
74+
skip_glob= ['docarray/proto/pb2/*', 'docarray/proto/pb/*']
7575

7676
[tool.ruff]
77-
exclude = ['docarray/proto/pb2/*', 'docs/*']
77+
exclude = ['docarray/proto/pb2/*', 'docarray/proto/pb/*', 'docs/*']
7878

7979
[tool.pytest.ini_options]
8080
markers = [
8181
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
8282
"internet: marks tests as requiring internet (deselect with '-m \"not internet\"')",
8383
"asyncio: marks that run async tests",
84-
84+
"proto: mark tests that run with proto"
8585
]

tests/integrations/document/test_proto.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23
import torch
34

45
from docarray import BaseDocument, DocumentArray
@@ -19,6 +20,7 @@
1920
from docarray.typing.tensor import NdArrayEmbedding
2021

2122

23+
@pytest.mark.proto
2224
def test_multi_modal_doc_proto():
2325
class MyMultiModalDoc(BaseDocument):
2426
image: Image
@@ -31,6 +33,7 @@ class MyMultiModalDoc(BaseDocument):
3133
MyMultiModalDoc.from_protobuf(doc.to_protobuf())
3234

3335

36+
@pytest.mark.proto
3437
def test_all_types():
3538
class NestedDoc(BaseDocument):
3639
tensor: NdArray

tests/units/array/test_array_proto.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import numpy as np
2+
import pytest
23

34
from docarray import BaseDocument, DocumentArray
45
from docarray.array.array_stacked import DocumentArrayStacked
56
from docarray.documents import Image, Text
67
from docarray.typing import NdArray
78

89

10+
@pytest.mark.proto
911
def test_simple_proto():
1012
class CustomDoc(BaseDocument):
1113
text: str
@@ -22,6 +24,7 @@ class CustomDoc(BaseDocument):
2224
assert (doc1.tensor == doc2.tensor).all()
2325

2426

27+
@pytest.mark.proto
2528
def test_nested_proto():
2629
class CustomDocument(BaseDocument):
2730
text: Text
@@ -39,6 +42,7 @@ class CustomDocument(BaseDocument):
3942
DocumentArray[CustomDocument].from_protobuf(da.to_protobuf())
4043

4144

45+
@pytest.mark.proto
4246
def test_nested_proto_any_doc():
4347
class CustomDocument(BaseDocument):
4448
text: Text
@@ -56,6 +60,7 @@ class CustomDocument(BaseDocument):
5660
DocumentArray.from_protobuf(da.to_protobuf())
5761

5862

63+
@pytest.mark.proto
5964
def test_stacked_proto():
6065
class CustomDocument(BaseDocument):
6166
image: NdArray

tests/units/array/test_array_stacked.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,12 @@ def test_iterator(batch):
3737

3838

3939
def test_stack_setter(batch):
40-
4140
batch.tensor = torch.ones(10, 3, 224, 224)
4241

4342
assert (batch.tensor == torch.ones(10, 3, 224, 224)).all()
4443

4544

4645
def test_stack_optional(batch):
47-
4846
assert (batch._columns['tensor'] == torch.zeros(10, 3, 224, 224)).all()
4947
assert (batch.tensor == torch.zeros(10, 3, 224, 224)).all()
5048

@@ -67,7 +65,6 @@ class Image(BaseDocument):
6765

6866

6967
def test_stack(batch):
70-
7168
assert (batch._columns['tensor'] == torch.zeros(10, 3, 224, 224)).all()
7269
assert (batch.tensor == torch.zeros(10, 3, 224, 224)).all()
7370
assert batch._columns['tensor'].data_ptr() == batch.tensor.data_ptr()
@@ -138,11 +135,12 @@ class MMdoc(BaseDocument):
138135
assert (doc.img.tensor == torch.zeros(3, 224, 224)).all()
139136

140137

138+
@pytest.mark.proto
141139
def test_proto_stacked_mode_torch(batch):
142-
143140
batch.from_protobuf(batch.to_protobuf())
144141

145142

143+
@pytest.mark.proto
146144
def test_proto_stacked_mode_numpy():
147145
class MyDoc(BaseDocument):
148146
tensor: NdArray[3, 224, 224]

0 commit comments

Comments
 (0)