Skip to content

Commit febea8d

Browse files
committed
fix: make DocList properly a Generic
1 parent 951679c commit febea8d

4 files changed

Lines changed: 35 additions & 146 deletions

File tree

aux.py

Lines changed: 0 additions & 96 deletions
This file was deleted.

docarray/array/any_array.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
cast,
1818
overload,
1919
Tuple,
20-
get_args,
21-
get_origin,
2220
)
2321

2422
import numpy as np
@@ -28,13 +26,19 @@
2826
from docarray.exceptions.exceptions import UnusableObjectError
2927
from docarray.typing.abstract_type import AbstractType
3028
from docarray.utils._internal._typing import change_cls_name, safe_issubclass
29+
from docarray.utils._internal.pydantic import is_pydantic_v2
3130

3231
if TYPE_CHECKING:
3332
from docarray.proto import DocListProto, NodeProto
3433
from docarray.typing.tensor.abstract_tensor import AbstractTensor
3534

3635
if sys.version_info >= (3, 12):
3736
from types import GenericAlias
37+
else:
38+
try:
39+
from typing import GenericAlias
40+
except:
41+
from typing import _GenericAlias as GenericAlias
3842

3943
T = TypeVar('T', bound='AnyDocArray')
4044
T_doc = TypeVar('T_doc', bound=BaseDocWithoutId)
@@ -48,7 +52,7 @@
4852
)
4953

5054

51-
class AnyDocArray(AbstractType, Sequence[T_doc], Generic[T_doc]):
55+
class AnyDocArray(Sequence[T_doc], Generic[T_doc], AbstractType):
5256
doc_type: Type[BaseDocWithoutId]
5357
__typed_da__: Dict[Type['AnyDocArray'], Dict[Type[BaseDocWithoutId], Type]] = {}
5458

@@ -57,7 +61,6 @@ def __repr__(self):
5761

5862
@classmethod
5963
def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
60-
print(f' hey here {item}')
6164
if not isinstance(item, type):
6265
if sys.version_info < (3, 12):
6366
return Generic.__class_getitem__.__func__(cls, item) # type: ignore
@@ -76,10 +79,12 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
7679
if item not in cls.__typed_da__[cls]:
7780
# Promote to global scope so multiprocessing can pickle it
7881
global _DocArrayTyped
82+
7983
class _DocArrayTyped(cls, Generic[T_doc]): # type: ignore
8084
doc_type: Type[BaseDocWithoutId] = cast(Type[BaseDocWithoutId], item)
81-
# __origin__: Type['AnyDocArray'] = cls # add this
82-
# __args__: Tuple[Any, ...] = (item,) # add this
85+
if is_pydantic_v2:
86+
__origin__: Type['AnyDocArray'] = cls # add this
87+
__args__: Tuple[Any, ...] = (item,) # add this
8388

8489
for field in _DocArrayTyped.doc_type._docarray_fields().keys():
8590

@@ -109,13 +114,16 @@ def _setter(self, value):
109114
change_cls_name(
110115
_DocArrayTyped, f'{cls.__name__}[{item.__name__}]', globals()
111116
)
117+
if is_pydantic_v2:
118+
if sys.version_info < (3, 12):
119+
cls.__typed_da__[cls][item] = Generic.__class_getitem__.__func__(_DocArrayTyped, item) # type: ignore
120+
# this does nothing other than check that item is a valid type var or str
121+
# Keep the approach in #1147 to be compatible with lower versions of Python.
122+
else:
123+
cls.__typed_da__[cls][item] = GenericAlias(_DocArrayTyped, item)
124+
else:
125+
cls.__typed_da__[cls][item] = _DocArrayTyped
112126

113-
cls.__typed_da__[cls][item] = _DocArrayTyped
114-
115-
print(f'return {cls.__typed_da__[cls][item]}')
116-
a = get_args(cls.__typed_da__[cls][item])
117-
print(f'a {a}')
118-
print(f'get_origin {get_origin(cls.__typed_da__[cls][item])}')
119127
return cls.__typed_da__[cls][item]
120128

121129
@overload

docarray/array/doc_list/doc_list.py

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
overload,
1515
Callable,
1616
get_args,
17-
Generic
1817
)
1918

2019
from pydantic import parse_obj_as
@@ -31,7 +30,6 @@
3130
from docarray.utils._internal.pydantic import is_pydantic_v2
3231

3332
if is_pydantic_v2:
34-
from pydantic import GetCoreSchemaHandler
3533
from pydantic_core import core_schema
3634

3735
from docarray.utils._internal._typing import safe_issubclass
@@ -48,11 +46,7 @@
4846

4947

5048
class DocList(
51-
ListAdvancedIndexing[T_doc],
52-
PushPullMixin,
53-
IOMixinDocList,
54-
AnyDocArray[T_doc],
55-
Generic[T_doc]
49+
ListAdvancedIndexing[T_doc], PushPullMixin, IOMixinDocList, AnyDocArray[T_doc]
5650
):
5751
"""
5852
DocList is a container of Documents.
@@ -363,32 +357,15 @@ def __repr__(self):
363357
def __get_pydantic_core_schema__(
364358
cls, source: Any, handler: Callable[[Any], core_schema.CoreSchema]
365359
) -> core_schema.CoreSchema:
366-
def get_args_2(tp):
367-
"""Get type arguments with all substitutions performed.
368-
369-
For unions, basic simplifications used by Union constructor are performed.
370-
Examples::
371-
get_args(Dict[str, int]) == (str, int)
372-
get_args(int) == ()
373-
get_args(Union[int, Union[T, int], str][int]) == (int, str)
374-
get_args(Union[int, Tuple[T, int]][str]) == (int, Tuple[str, int])
375-
get_args(Callable[[], T][int]) == ([], int)
376-
"""
377-
from typing import _GenericAlias, get_origin
378-
import collections
379-
if isinstance(tp, _GenericAlias):
380-
res = tp.__args__
381-
if get_origin(tp) is collections.abc.Callable and res[0] is not Ellipsis:
382-
res = (list(res[:-1]), res[-1])
383-
return res
384-
else:
385-
print(f'IN ELSE')
386-
return ()
387-
388360
instance_schema = core_schema.is_instance_schema(cls)
389-
print(f'instance_schema {instance_schema} and {handler}')
390-
args = get_args_2(DocList[BaseDocWithoutId])
391-
print(f' args {args}')
392-
return core_schema.with_info_after_validator_function(
393-
function=cls.validate,
394-
schema=core_schema.list_schema(core_schema.any_schema()))
361+
362+
args = get_args(source)
363+
if args:
364+
sequence_t_schema = handler(Sequence[args[0]])
365+
else:
366+
sequence_t_schema = handler(Sequence)
367+
368+
non_instance_schema = core_schema.with_info_after_validator_function(
369+
lambda v, i: DocList(v), sequence_t_schema
370+
)
371+
return core_schema.union_schema([instance_schema, non_instance_schema])

docarray/documents/legacy/legacy_document.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# limitations under the License.
1616
from __future__ import annotations
1717

18-
from typing import Any, Dict, Optional, List, Union
18+
from typing import Any, Dict, Optional
1919

2020
from docarray import BaseDoc, DocList
2121
from docarray.typing import AnyEmbedding, AnyTensor
@@ -50,8 +50,8 @@ class LegacyDocument(BaseDoc):
5050
"""
5151

5252
tensor: Optional[AnyTensor] = None
53-
chunks: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None
54-
matches: Optional[Union[DocList[LegacyDocument], List[LegacyDocument]]] = None
53+
chunks: Optional[DocList[LegacyDocument]] = None
54+
matches: Optional[DocList[LegacyDocument]] = None
5555
blob: Optional[bytes] = None
5656
text: Optional[str] = None
5757
url: Optional[str] = None

0 commit comments

Comments
 (0)