Skip to content

Commit 2335cf8

Browse files
feat(api): Use protocols and overload to type hint mixins
Currently mixins like ListMixin are type hinted to return base RESTObject instead of a specific class like `MergeRequest`. The GetMixin and GetWithoutIdMixin solve this problem by defining a new `get` method for every defined class. However, this creates a lot of duplicated code. `typing.Protocol` can be used to type hint that the mixed in method will return a class matching attribute `_obj_cls`. The type checker will lookup the mixed in class attribute and adjust the return type accordinly. Delete `tests/unit/meta/test_ensure_type_hints.py` file as the `get` method is no required to be defined for every class. Signed-off-by: Igor Ponomarev <igor.ponomarev@collabora.com>
1 parent f4f7d7a commit 2335cf8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+270
-1049
lines changed

gitlab/base.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,6 @@ class RESTManager:
347347
_create_attrs: g_types.RequiredOptional = g_types.RequiredOptional()
348348
_update_attrs: g_types.RequiredOptional = g_types.RequiredOptional()
349349
_path: Optional[str] = None
350-
_obj_cls: Optional[Type[RESTObject]] = None
351350
_from_parent_attrs: Dict[str, Any] = {}
352351
_types: Dict[str, Type[g_types.GitlabAttribute]] = {}
353352

gitlab/client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -397,10 +397,11 @@ def auth(self) -> None:
397397
The `user` attribute will hold a `gitlab.objects.CurrentUser` object on
398398
success.
399399
"""
400-
self.user = self._objects.CurrentUserManager(self).get()
400+
user = self._objects.CurrentUserManager(self).get()
401+
self.user = user
401402

402-
if hasattr(self.user, "web_url") and hasattr(self.user, "username"):
403-
self._check_url(self.user.web_url, path=self.user.username)
403+
if hasattr(user, "web_url") and hasattr(user, "username"):
404+
self._check_url(user.web_url, path=user.username)
404405

405406
def version(self) -> Tuple[str, str]:
406407
"""Returns the version and revision of the gitlab server.

gitlab/mixins.py

Lines changed: 106 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import enum
24
from types import ModuleType
35
from typing import (
@@ -6,10 +8,14 @@
68
Dict,
79
Iterator,
810
List,
11+
Literal,
912
Optional,
13+
overload,
14+
Protocol,
1015
Tuple,
1116
Type,
1217
TYPE_CHECKING,
18+
TypeVar,
1319
Union,
1420
)
1521

@@ -52,6 +58,12 @@
5258
_RestManagerBase = object
5359
_RestObjectBase = object
5460

61+
TObjCls = TypeVar("TObjCls", bound=base.RESTObject)
62+
63+
64+
class ObjClsProtocol(Protocol[TObjCls]):
65+
_obj_cls: Type[TObjCls]
66+
5567

5668
class HeadMixin(_RestManagerBase):
5769
@exc.on_http_error(exc.GitlabHeadError)
@@ -84,13 +96,29 @@ def head(
8496
class GetMixin(HeadMixin, _RestManagerBase):
8597
_computed_path: Optional[str]
8698
_from_parent_attrs: Dict[str, Any]
87-
_obj_cls: Optional[Type[base.RESTObject]]
8899
_optional_get_attrs: Tuple[str, ...] = ()
100+
_obj_cls: Type[base.RESTObject]
89101
_parent: Optional[base.RESTObject]
90102
_parent_attrs: Dict[str, Any]
91103
_path: Optional[str]
92104
gitlab: gitlab.Gitlab
93105

106+
@overload
107+
def get(
108+
self: ObjClsProtocol[TObjCls],
109+
id: Union[str, int],
110+
lazy: bool = False,
111+
**kwargs: Any,
112+
) -> TObjCls: ...
113+
114+
@overload
115+
def get(
116+
self: Any,
117+
id: Union[str, int],
118+
lazy: bool = False,
119+
**kwargs: Any,
120+
) -> base.RESTObject: ...
121+
94122
@exc.on_http_error(exc.GitlabGetError)
95123
def get(
96124
self, id: Union[str, int], lazy: bool = False, **kwargs: Any
@@ -129,13 +157,19 @@ def get(
129157
class GetWithoutIdMixin(HeadMixin, _RestManagerBase):
130158
_computed_path: Optional[str]
131159
_from_parent_attrs: Dict[str, Any]
132-
_obj_cls: Optional[Type[base.RESTObject]]
133160
_optional_get_attrs: Tuple[str, ...] = ()
161+
_obj_cls: Type[base.RESTObject]
134162
_parent: Optional[base.RESTObject]
135163
_parent_attrs: Dict[str, Any]
136164
_path: Optional[str]
137165
gitlab: gitlab.Gitlab
138166

167+
@overload
168+
def get(self: ObjClsProtocol[TObjCls], **kwargs: Any) -> TObjCls: ...
169+
170+
@overload
171+
def get(self: Any, **kwargs: Any) -> base.RESTObject: ...
172+
139173
@exc.on_http_error(exc.GitlabGetError)
140174
def get(self, **kwargs: Any) -> base.RESTObject:
141175
"""Retrieve a single object.
@@ -196,14 +230,54 @@ class ListMixin(HeadMixin, _RestManagerBase):
196230
_computed_path: Optional[str]
197231
_from_parent_attrs: Dict[str, Any]
198232
_list_filters: Tuple[str, ...] = ()
199-
_obj_cls: Optional[Type[base.RESTObject]]
233+
_obj_cls: Type[base.RESTObject]
200234
_parent: Optional[base.RESTObject]
201235
_parent_attrs: Dict[str, Any]
202236
_path: Optional[str]
203237
gitlab: gitlab.Gitlab
204238

239+
@overload
240+
def list(
241+
self: ObjClsProtocol[TObjCls],
242+
*,
243+
iterator: Literal[False] = False,
244+
**kwargs: Any,
245+
) -> List[TObjCls]: ...
246+
247+
@overload
248+
def list(
249+
self: ObjClsProtocol[TObjCls],
250+
*,
251+
iterator: Literal[True] = True,
252+
**kwargs: Any,
253+
) -> base.RESTObjectList: ...
254+
255+
@overload
256+
def list(
257+
self: Any,
258+
*,
259+
iterator: Literal[False] = False,
260+
**kwargs: Any,
261+
) -> List[base.RESTObject]: ...
262+
263+
@overload
264+
def list(
265+
self: Any,
266+
*,
267+
iterator: Literal[True] = True,
268+
**kwargs: Any,
269+
) -> base.RESTObjectList: ...
270+
271+
@overload
272+
def list(
273+
self: Any,
274+
**kwargs: Any,
275+
) -> Union[base.RESTObjectList, List[base.RESTObject]]: ...
276+
205277
@exc.on_http_error(exc.GitlabListError)
206-
def list(self, **kwargs: Any) -> Union[base.RESTObjectList, List[base.RESTObject]]:
278+
def list(
279+
self, *, iterator: bool = False, **kwargs: Any
280+
) -> Union[base.RESTObjectList, List[Any]]:
207281
"""Retrieve a list of objects.
208282
209283
Args:
@@ -221,6 +295,7 @@ def list(self, **kwargs: Any) -> Union[base.RESTObjectList, List[base.RESTObject
221295
GitlabAuthenticationError: If authentication is not correct
222296
GitlabListError: If the server cannot perform the request
223297
"""
298+
kwargs.update(iterator=iterator)
224299

225300
data, _ = utils._transform_types(
226301
data=kwargs,
@@ -253,7 +328,6 @@ def list(self, **kwargs: Any) -> Union[base.RESTObjectList, List[base.RESTObject
253328
class RetrieveMixin(ListMixin, GetMixin):
254329
_computed_path: Optional[str]
255330
_from_parent_attrs: Dict[str, Any]
256-
_obj_cls: Optional[Type[base.RESTObject]]
257331
_parent: Optional[base.RESTObject]
258332
_parent_attrs: Dict[str, Any]
259333
_path: Optional[str]
@@ -263,12 +337,24 @@ class RetrieveMixin(ListMixin, GetMixin):
263337
class CreateMixin(_RestManagerBase):
264338
_computed_path: Optional[str]
265339
_from_parent_attrs: Dict[str, Any]
266-
_obj_cls: Optional[Type[base.RESTObject]]
340+
_obj_cls: Type[base.RESTObject]
267341
_parent: Optional[base.RESTObject]
268342
_parent_attrs: Dict[str, Any]
269343
_path: Optional[str]
270344
gitlab: gitlab.Gitlab
271345

346+
@overload
347+
def create(
348+
self: ObjClsProtocol[TObjCls],
349+
data: Optional[Dict[str, Any]] = None,
350+
**kwargs: Any,
351+
) -> TObjCls: ...
352+
353+
@overload
354+
def create(
355+
self: Any, data: Optional[Dict[str, Any]] = None, **kwargs: Any
356+
) -> base.RESTObject: ...
357+
272358
@exc.on_http_error(exc.GitlabCreateError)
273359
def create(
274360
self, data: Optional[Dict[str, Any]] = None, **kwargs: Any
@@ -385,12 +471,23 @@ def update(
385471
class SetMixin(_RestManagerBase):
386472
_computed_path: Optional[str]
387473
_from_parent_attrs: Dict[str, Any]
388-
_obj_cls: Optional[Type[base.RESTObject]]
474+
_obj_cls: Type[base.RESTObject]
389475
_parent: Optional[base.RESTObject]
390476
_parent_attrs: Dict[str, Any]
391477
_path: Optional[str]
392478
gitlab: gitlab.Gitlab
393479

480+
@overload
481+
def set(
482+
self: ObjClsProtocol[TObjCls],
483+
key: str,
484+
value: str,
485+
**kwargs: Any,
486+
) -> TObjCls: ...
487+
488+
@overload
489+
def set(self: Any, key: str, value: str, **kwargs: Any) -> base.RESTObject: ...
490+
394491
@exc.on_http_error(exc.GitlabSetError)
395492
def set(self, key: str, value: str, **kwargs: Any) -> base.RESTObject:
396493
"""Create or update the object.
@@ -450,7 +547,7 @@ def delete(self, id: Optional[Union[str, int]] = None, **kwargs: Any) -> None:
450547
class CRUDMixin(GetMixin, ListMixin, CreateMixin, UpdateMixin, DeleteMixin):
451548
_computed_path: Optional[str]
452549
_from_parent_attrs: Dict[str, Any]
453-
_obj_cls: Optional[Type[base.RESTObject]]
550+
_obj_cls: Type[base.RESTObject]
454551
_parent: Optional[base.RESTObject]
455552
_parent_attrs: Dict[str, Any]
456553
_path: Optional[str]
@@ -460,7 +557,7 @@ class CRUDMixin(GetMixin, ListMixin, CreateMixin, UpdateMixin, DeleteMixin):
460557
class NoUpdateMixin(GetMixin, ListMixin, CreateMixin, DeleteMixin):
461558
_computed_path: Optional[str]
462559
_from_parent_attrs: Dict[str, Any]
463-
_obj_cls: Optional[Type[base.RESTObject]]
560+
_obj_cls: Type[base.RESTObject]
464561
_parent: Optional[base.RESTObject]
465562
_parent_attrs: Dict[str, Any]
466563
_path: Optional[str]

gitlab/v4/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def extend_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
401401
if not isinstance(cls, type):
402402
continue
403403
if issubclass(cls, gitlab.base.RESTManager):
404-
if cls._obj_cls is not None:
404+
if hasattr(cls, "_obj_cls"):
405405
classes.add(cls._obj_cls)
406406

407407
for cls in sorted(classes, key=operator.attrgetter("__name__")):

gitlab/v4/objects/appearance.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, cast, Dict, Optional, Union
1+
from typing import Any, Dict, Optional, Union
22

33
from gitlab import exceptions as exc
44
from gitlab.base import RESTManager, RESTObject
@@ -58,6 +58,3 @@ def update(
5858
new_data = new_data or {}
5959
data = new_data.copy()
6060
return super().update(id, data, **kwargs)
61-
62-
def get(self, **kwargs: Any) -> ApplicationAppearance:
63-
return cast(ApplicationAppearance, super().get(**kwargs))

gitlab/v4/objects/audit_events.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
https://docs.gitlab.com/ee/api/audit_events.html
44
"""
55

6-
from typing import Any, cast, Union
7-
86
from gitlab.base import RESTManager, RESTObject
97
from gitlab.mixins import RetrieveMixin
108

@@ -29,9 +27,6 @@ class AuditEventManager(RetrieveMixin, RESTManager):
2927
_obj_cls = AuditEvent
3028
_list_filters = ("created_after", "created_before", "entity_type", "entity_id")
3129

32-
def get(self, id: Union[str, int], lazy: bool = False, **kwargs: Any) -> AuditEvent:
33-
return cast(AuditEvent, super().get(id=id, lazy=lazy, **kwargs))
34-
3530

3631
class GroupAuditEvent(RESTObject):
3732
_id_attr = "id"
@@ -43,11 +38,6 @@ class GroupAuditEventManager(RetrieveMixin, RESTManager):
4338
_from_parent_attrs = {"group_id": "id"}
4439
_list_filters = ("created_after", "created_before")
4540

46-
def get(
47-
self, id: Union[str, int], lazy: bool = False, **kwargs: Any
48-
) -> GroupAuditEvent:
49-
return cast(GroupAuditEvent, super().get(id=id, lazy=lazy, **kwargs))
50-
5141

5242
class ProjectAuditEvent(RESTObject):
5343
_id_attr = "id"
@@ -59,11 +49,6 @@ class ProjectAuditEventManager(RetrieveMixin, RESTManager):
5949
_from_parent_attrs = {"project_id": "id"}
6050
_list_filters = ("created_after", "created_before")
6151

62-
def get(
63-
self, id: Union[str, int], lazy: bool = False, **kwargs: Any
64-
) -> ProjectAuditEvent:
65-
return cast(ProjectAuditEvent, super().get(id=id, lazy=lazy, **kwargs))
66-
6752

6853
class ProjectAudit(ProjectAuditEvent):
6954
pass

0 commit comments

Comments
 (0)