1717 cast ,
1818 overload ,
1919 Tuple ,
20- get_args ,
21- get_origin ,
2220)
2321
2422import numpy as np
2826from docarray .exceptions .exceptions import UnusableObjectError
2927from docarray .typing .abstract_type import AbstractType
3028from docarray .utils ._internal ._typing import change_cls_name , safe_issubclass
29+ from docarray .utils ._internal .pydantic import is_pydantic_v2
3130
3231if TYPE_CHECKING :
3332 from docarray .proto import DocListProto , NodeProto
3433 from docarray .typing .tensor .abstract_tensor import AbstractTensor
3534
3635if sys .version_info >= (3 , 12 ):
3736 from types import GenericAlias
37+ else :
38+ try :
39+ from typing import GenericAlias
40+ except :
41+ from typing import _GenericAlias as GenericAlias
3842
3943T = TypeVar ('T' , bound = 'AnyDocArray' )
4044T_doc = TypeVar ('T_doc' , bound = BaseDocWithoutId )
4852)
4953
5054
51- class AnyDocArray (AbstractType , Sequence [T_doc ], Generic [T_doc ]):
55+ class AnyDocArray (Sequence [T_doc ], Generic [T_doc ], AbstractType ):
5256 doc_type : Type [BaseDocWithoutId ]
5357 __typed_da__ : Dict [Type ['AnyDocArray' ], Dict [Type [BaseDocWithoutId ], Type ]] = {}
5458
@@ -57,7 +61,6 @@ def __repr__(self):
5761
5862 @classmethod
5963 def __class_getitem__ (cls , item : Union [Type [BaseDocWithoutId ], TypeVar , str ]):
60- print (f' hey here { item } ' )
6164 if not isinstance (item , type ):
6265 if sys .version_info < (3 , 12 ):
6366 return Generic .__class_getitem__ .__func__ (cls , item ) # type: ignore
@@ -76,10 +79,12 @@ def __class_getitem__(cls, item: Union[Type[BaseDocWithoutId], TypeVar, str]):
7679 if item not in cls .__typed_da__ [cls ]:
7780 # Promote to global scope so multiprocessing can pickle it
7881 global _DocArrayTyped
82+
7983 class _DocArrayTyped (cls , Generic [T_doc ]): # type: ignore
8084 doc_type : Type [BaseDocWithoutId ] = cast (Type [BaseDocWithoutId ], item )
81- # __origin__: Type['AnyDocArray'] = cls # add this
82- # __args__: Tuple[Any, ...] = (item,) # add this
85+ if is_pydantic_v2 :
86+ __origin__ : Type ['AnyDocArray' ] = cls # add this
87+ __args__ : Tuple [Any , ...] = (item ,) # add this
8388
8489 for field in _DocArrayTyped .doc_type ._docarray_fields ().keys ():
8590
@@ -109,13 +114,16 @@ def _setter(self, value):
109114 change_cls_name (
110115 _DocArrayTyped , f'{ cls .__name__ } [{ item .__name__ } ]' , globals ()
111116 )
117+ if is_pydantic_v2 :
118+ if sys .version_info < (3 , 12 ):
119+ cls .__typed_da__ [cls ][item ] = Generic .__class_getitem__ .__func__ (_DocArrayTyped , item ) # type: ignore
120+ # this do nothing that checking that item is valid type var or str
121+ # Keep the approach in #1147 to be compatible with lower versions of Python.
122+ else :
123+ cls .__typed_da__ [cls ][item ] = GenericAlias (_DocArrayTyped , item )
124+ else :
125+ cls .__typed_da__ [cls ][item ] = _DocArrayTyped
112126
113- cls .__typed_da__ [cls ][item ] = _DocArrayTyped
114-
115- print (f'return { cls .__typed_da__ [cls ][item ]} ' )
116- a = get_args (cls .__typed_da__ [cls ][item ])
117- print (f'a { a } ' )
118- print (f'get_origin { get_origin (cls .__typed_da__ [cls ][item ])} ' )
119127 return cls .__typed_da__ [cls ][item ]
120128
121129 @overload
0 commit comments