Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
3ea813e
fix: jax backend boilerplate setup
agaraman0 May 22, 2023
825daf5
feat: typing init
agaraman0 May 23, 2023
fd11322
feat: JaxArray refactoring
agaraman0 May 24, 2023
bc4a698
fix: _docarray_from_native function for jaxarray
agaraman0 Jun 4, 2023
cb64d4f
feat: JAX array implementation is complete
agaraman0 Jun 14, 2023
d57a656
feat: JAX array implementation is complete
agaraman0 Jun 14, 2023
49b3764
feat: JaxCompBackend tests complete till nested Retrieval class
agaraman0 Jun 16, 2023
23589d6
fix: isort format fix
agaraman0 Jun 20, 2023
8793822
feat: Jax Added as dependency
agaraman0 Jun 20, 2023
aff2ff6
feat: poetry lock added
agaraman0 Jun 20, 2023
3d6f45f
fix: jax_comp review comments resolved
agaraman0 Jun 23, 2023
3072cea
fix: squashed commits and bypassing jax test cases which fails
agaraman0 Jul 11, 2023
e0a7d89
fix: add jax pytest marker for missing testcase
agaraman0 Jul 11, 2023
3f9a399
feat: added integration test
agaraman0 Jul 14, 2023
b7b81d6
fix poetry lock update
agaraman0 Jul 14, 2023
963c1b3
fix: test_jax_integration changes
agaraman0 Jul 14, 2023
6f571e2
fix: init comments change
agaraman0 Jul 14, 2023
c03d37c
fix: inmemory changes included
agaraman0 Jul 17, 2023
b64b31f
fix: include jax tests
agaraman0 Jul 17, 2023
a8a3545
fix: include jax tests
agaraman0 Jul 17, 2023
4e2762a
fix: include jax tests round#2
agaraman0 Jul 17, 2023
67dce6d
fix: install jax and run jax tests
agaraman0 Jul 17, 2023
2b3cf04
fix: install jaxlib and check for jax workflow
agaraman0 Jul 17, 2023
0fe883e
fix: failed test cases fixes
agaraman0 Jul 17, 2023
b6d7fe4
fix: failed jax test cases fixes
agaraman0 Jul 17, 2023
06afa37
Merge branch 'main' into feat-comp-jax-backend
agaraman0 Jul 17, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 46 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ jobs:
python -m pip install poetry
poetry install --without dev
poetry run pip install tensorflow==2.11.0
poetry run pip install jax
- name: Test basic import
run: poetry run python -c 'from docarray import DocList, BaseDoc'

Expand Down Expand Up @@ -111,7 +112,7 @@ jobs:
- name: Test
id: test
run: |
poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml ${{ matrix.test-path }} --ignore=tests/integrations/store/test_jac.py
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
Expand Down Expand Up @@ -158,7 +159,7 @@ jobs:
- name: Test
id: test
run: |
poetry run pytest -m "not (tensorflow or benchmark or index)" --cov=docarray --cov-report=xml tests/integrations/store/test_jac.py
poetry run pytest -m "not (tensorflow or benchmark or index or jax)" --cov=docarray --cov-report=xml tests/integrations/store/test_jac.py
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
Expand Down Expand Up @@ -357,6 +358,49 @@ jobs:
flags: ${{ steps.test.outputs.codecov_flag }}
fail_ci_if_error: false

docarray-test-jax:
needs: [import-test]
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: [3.8]
steps:
- uses: actions/checkout@v2.5.0
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Prepare environment
run: |
python -m pip install --upgrade pip
python -m pip install poetry
poetry install --all-extras
poetry run pip install jaxlib
poetry run pip install jax

- name: Test
id: test
run: |
poetry run pytest -m 'jax' --cov=docarray --cov-report=xml tests
echo "flag it as docarray for codeoverage"
echo "codecov_flag=docarray" >> $GITHUB_OUTPUT
timeout-minutes: 30
- name: Check codecov file
id: check_files
uses: andstor/file-existence-action@v1
with:
files: "coverage.xml"
- name: Upload coverage from test to Codecov
uses: codecov/codecov-action@v3.1.1
if: steps.check_files.outputs.files_exists == 'true' && ${{ matrix.python-version }} == '3.8'
with:
file: coverage.xml
name: benchmark-test-codecov
flags: ${{ steps.test.outputs.codecov_flag }}
fail_ci_if_error: false



docarray-test-benchmarks:
needs: [import-test]
Expand Down
28 changes: 26 additions & 2 deletions docarray/array/doc_vec/doc_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
from docarray.typing import NdArray
from docarray.typing.tensor.abstract_tensor import AbstractTensor
from docarray.utils._internal._typing import is_tensor_union, safe_issubclass
from docarray.utils._internal.misc import is_tf_available, is_torch_available
from docarray.utils._internal.misc import (
is_jax_available,
is_tf_available,
is_torch_available,
)

if TYPE_CHECKING:
import csv
Expand Down Expand Up @@ -60,6 +64,14 @@
else:
TensorFlowTensor = None # type: ignore

jnp_available = is_jax_available()
if jnp_available:
import jax.numpy as jnp # type: ignore

from docarray.typing import JaxArray # noqa: F401
else:
JaxArray = None # type: ignore

T_doc = TypeVar('T_doc', bound=BaseDoc)
T = TypeVar('T', bound='DocVec')
T_io_mixin = TypeVar('T_io_mixin', bound='IOMixinArray')
Expand Down Expand Up @@ -262,6 +274,19 @@ def _check_doc_field_not_none(field_name, doc):

stacked: tf.Tensor = tf.stack(tf_stack)
tensor_columns[field_name] = TensorFlowTensor(stacked)
elif jnp_available and issubclass(field_type, JaxArray):
if first_doc_is_none:
_verify_optional_field_of_docs(docs)
tensor_columns[field_name] = None
else:
tf_stack = []
for i, doc in enumerate(docs):
val = getattr(doc, field_name)
_check_doc_field_not_none(field_name, doc)
tf_stack.append(val.tensor)

jax_stacked: jnp.ndarray = jnp.stack(tf_stack)
tensor_columns[field_name] = JaxArray(jax_stacked)

elif safe_issubclass(field_type, AbstractTensor):
if first_doc_is_none:
Expand Down Expand Up @@ -835,7 +860,6 @@ def to_doc_list(self: T) -> DocList[T_doc]:
unstacked_doc_column[field] = doc_col.to_doc_list() if doc_col else None

for field, da_col in self._storage.docs_vec_columns.items():

unstacked_da_column[field] = (
[docs.to_doc_list() for docs in da_col] if da_col else None
)
Expand Down
15 changes: 14 additions & 1 deletion docarray/array/list_advance_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from typing_extensions import SupportsIndex

from docarray.utils._internal.misc import (
is_torch_available,
is_jax_available,
is_tf_available,
is_torch_available,
)

torch_available = is_torch_available()
Expand All @@ -24,7 +25,13 @@
tf_available = is_tf_available()
if tf_available:
import tensorflow as tf # type: ignore

from docarray.typing.tensor.tensorflow_tensor import TensorFlowTensor
jax_available = is_jax_available()
if jax_available:
import jax.numpy as jnp

from docarray.typing.tensor.jaxarray import JaxArray

T_item = TypeVar('T_item')
T = TypeVar('T', bound='ListAdvancedIndexing')
Expand Down Expand Up @@ -100,6 +107,12 @@ def _normalize_index_item(
if isinstance(item, TensorFlowTensor):
return item.tensor.numpy().tolist()

if jax_available:
if isinstance(item, jnp.ndarray):
return item.__array__().tolist()
if isinstance(item, JaxArray):
return item.tensor.__array__().tolist()

return item

def _get_from_indices(self: T, item: Iterable[int]) -> T:
Expand Down
Loading