Skip to content

Commit 6de0b27

Browse files
committed
Add py.typed marker and gate pyright in CI
Ship a py.typed marker (PEP 561) and the Typing :: Typed classifier so downstream projects type check against PRAW's inline annotations; hatchling bundles the marker into the wheel automatically. Add a type dependency group with pyright, a [tool.pyright] config, and a tox type env, and add that env to the tox envlist so the shared CI lint job enforces zero pyright errors under standard mode. Enable reportUnnecessaryTypeIgnoreComment to keep ignores from going stale. Most fixes declare host-provided attributes on the various mixins, add Optional narrowing, correct return/argument annotations, and add @overload where a return type depends on argument values (e.g. DraftHelper.__call__). Notable changes: - Config: declare its dynamically-populated attributes (client_id, oauth_url, ratelimit_seconds, etc.) and widen **settings; drop the redundant None pre-init. - FullnameMixin._kind and LiveUpdate._kind are now properties so the property overrides in Comment/Submission/Message/Redditor/Subreddit are compatible. - ThingModerationMixin.thing is declared so pyright can resolve self.thing access within the mixin. - MoreComments, InlineMedia, and similar classes declare attributes set elsewhere in the object model. A handful of targeted # pyright: ignore comments remain where the root cause is prawcore's session() authorizer annotation (it accepts any BaseAuthorizer).
1 parent 74a51ba commit 6de0b27

59 files changed

Lines changed: 636 additions & 262 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ praw follows `semantic versioning <https://semver.org/>`_.
2929
``reddit_url`` endpoint, as such a file can redirect credentials to an untrusted host.
3030
The warning can be silenced by setting the ``PRAW_ALLOW_ENDPOINT_OVERRIDE``
3131
environment variable.
32+
- A ``py.typed`` marker (:PEP:`561`) so that downstream projects can type check against
33+
PRAW's inline annotations.
3234

3335
**Changed**
3436

praw/config.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,34 @@ def __str__(self) -> str:
2727
class Config:
2828
"""A class containing the configuration for a Reddit site."""
2929

30-
CONFIG = None
30+
CONFIG: configparser.ConfigParser | None = None
3131
CONFIG_NOT_SET = _NotSet() # Represents a config value that is not set.
3232
LOCK = Lock()
3333
INTERPOLATION_LEVEL = MappingProxyType({
3434
"basic": configparser.BasicInterpolation,
3535
"extended": configparser.ExtendedInterpolation,
3636
})
3737

38+
# Attributes populated by _initialize_attributes. client_id and user_agent are
39+
# validated as present by Reddit.__init__, so they are typed as required.
40+
client_id: str
41+
client_secret: str | None
42+
oauth_url: str
43+
password: str | None
44+
ratelimit_seconds: int
45+
reddit_url: str
46+
redirect_uri: str | None
47+
refresh_token: str | None
48+
timeout: int
49+
user_agent: str
50+
username: str | None
51+
3852
@staticmethod
39-
def _config_boolean(*, item: bool | str) -> bool:
53+
def _config_boolean(*, item: bool | str | _NotSet) -> bool:
4054
if isinstance(item, bool):
4155
return item
56+
if isinstance(item, _NotSet):
57+
return False
4258
return item.lower() in {"1", "yes", "true", "on"}
4359

4460
@classmethod
@@ -50,6 +66,7 @@ def _load_config(cls, *, config_interpolation: str | None = None) -> None:
5066
interpolator_class = None
5167

5268
config = configparser.ConfigParser(interpolation=interpolator_class)
69+
assert __package__ is not None
5370
with files(__package__).joinpath("praw.ini").open("r") as hdl:
5471
config.read_file(hdl)
5572

@@ -114,7 +131,7 @@ def short_url(self) -> str:
114131
:raises: :class:`.ClientException` if it is not set.
115132
116133
"""
117-
if self._short_url is self.CONFIG_NOT_SET:
134+
if isinstance(self._short_url, _NotSet):
118135
msg = "No short domain specified."
119136
raise ClientException(msg)
120137
return self._short_url
@@ -123,20 +140,17 @@ def __init__(
123140
self,
124141
site_name: str,
125142
config_interpolation: str | None = None,
126-
**settings: str,
143+
**settings: str | bool | int | None,
127144
) -> None:
128145
"""Initialize a :class:`.Config` instance."""
129146
with Config.LOCK:
130147
if Config.CONFIG is None:
131148
self._load_config(config_interpolation=config_interpolation)
132149

133150
self._settings = settings
151+
assert Config.CONFIG is not None
134152
self.custom = dict(Config.CONFIG.items(site_name), **settings)
135153

136-
self.client_id = self.client_secret = self.oauth_url = None
137-
self.reddit_url = self.refresh_token = self.redirect_uri = None
138-
self.password = self.user_agent = self.username = None
139-
140154
self._initialize_attributes()
141155

142156
def _fetch(self, key: str) -> Any:

praw/exceptions.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from __future__ import annotations
1212

13+
from typing import cast
14+
1315

1416
class PRAWException(Exception): # noqa: N818
1517
"""The base PRAW Exception that all other exception classes extend."""
@@ -28,7 +30,7 @@ def error_message(self) -> str:
2830
error_str += f" on field {self.field!r}"
2931
return error_str
3032

31-
def __eq__(self, other: RedditErrorItem | list[str]) -> bool:
33+
def __eq__(self, other: object) -> bool:
3234
"""Check for equality."""
3335
if isinstance(other, RedditErrorItem):
3436
return (self.error_type, self.message, self.field) == (
@@ -196,6 +198,8 @@ def __init__(self, items: list[RedditErrorItem | list[str] | str]) -> None:
196198
197199
"""
198200
if isinstance(items, list) and isinstance(items[0], str):
199-
items = [items]
200-
self.items = self.parse_exception_list(items)
201+
parsed_items: list[RedditErrorItem | list[str]] = [cast("list[str]", items)]
202+
else:
203+
parsed_items = cast("list[RedditErrorItem | list[str]]", items)
204+
self.items = self.parse_exception_list(parsed_items)
201205
super().__init__(*self.items)

praw/models/auth.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def limits(self) -> dict[str, str | int | None]:
2525
requests.
2626
2727
"""
28+
assert self._reddit._core is not None
2829
data = self._reddit._core._rate_limiter
2930
return {
3031
"remaining": data.remaining,
@@ -41,6 +42,8 @@ def authorize(self, code: str) -> str | None:
4142
The session's active authorization will be updated upon success.
4243
4344
"""
45+
assert self._reddit._read_only_core is not None
46+
assert self._reddit._read_only_core._authorizer is not None
4447
authenticator = self._reddit._read_only_core._authorizer._authenticator
4548
authorizer = Authorizer(authenticator)
4649
authorizer.authorize(code)
@@ -64,11 +67,14 @@ def implicit(self, *, access_token: str, expires_in: int, scope: str) -> None:
6467
non-installed application type.
6568
6669
"""
70+
assert self._reddit._read_only_core is not None
71+
assert self._reddit._read_only_core._authorizer is not None
6772
authenticator = self._reddit._read_only_core._authorizer._authenticator
6873
if not isinstance(authenticator, UntrustedAuthenticator):
6974
raise InvalidImplicitAuth
7075
implicit_session = session(
71-
authorizer=ImplicitAuthorizer(authenticator, access_token, expires_in, scope),
76+
authorizer=ImplicitAuthorizer(authenticator, access_token, expires_in, scope), # pyright: ignore[reportArgumentType] # prawcore's session() is annotated Authorizer but Session accepts any BaseAuthorizer
77+
7278
window_size=self._reddit.config.window_size,
7379
)
7480
self._reddit._core = self._reddit._authorized_core = implicit_session
@@ -79,9 +85,12 @@ def scopes(self) -> set[str]:
7985
For read-only authorizations this should return ``{"*"}``.
8086
8187
"""
88+
assert self._reddit._core is not None
8289
authorizer = self._reddit._core._authorizer
90+
assert authorizer is not None
8391
if not authorizer.is_valid():
84-
authorizer.refresh()
92+
authorizer.refresh() # pyright: ignore[reportAttributeAccessIssue] # refresh is defined on Authorizer subclasses; the active core authorizer is always refreshable here
93+
assert authorizer.scopes is not None
8594
return authorizer.scopes
8695

8796
def url(
@@ -109,6 +118,8 @@ def url(
109118
whom the URL was generated for.
110119
111120
"""
121+
assert self._reddit._read_only_core is not None
122+
assert self._reddit._read_only_core._authorizer is not None
112123
authenticator = self._reddit._read_only_core._authorizer._authenticator
113124
if authenticator.redirect_uri is self._reddit.config.CONFIG_NOT_SET:
114125
msg = "redirect_uri must be provided"

praw/models/comment_forest.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from heapq import heappop, heappush
6-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, cast
77

88
from praw.exceptions import DuplicateReplaceException
99
from praw.models.reddit.more import MoreComments
@@ -19,7 +19,7 @@ class CommentForest:
1919
2020
"""
2121

22-
def __getitem__(self, index: int) -> models.Comment:
22+
def __getitem__(self, index: int) -> models.Comment | models.MoreComments:
2323
"""Return the comment at position ``index`` in the list.
2424
2525
This method is to be used like an array access, such as:
@@ -43,7 +43,7 @@ def __len__(self) -> int:
4343
"""Return the number of top-level comments in the forest."""
4444
return len(self._comments)
4545

46-
def _insert_comment(self, comment: models.Comment) -> None:
46+
def _insert_comment(self, comment: models.Comment | models.MoreComments) -> None:
4747
if comment.name in self._submission._comments_by_id:
4848
raise DuplicateReplaceException
4949
comment.submission = self._submission
@@ -65,24 +65,26 @@ def list(
6565
was not called first.
6666
6767
"""
68-
comments = []
69-
queue = list(self)
68+
comments: list[models.Comment | models.MoreComments] = []
69+
queue = list(self._comments)
7070
while queue:
7171
comment = queue.pop(0)
7272
comments.append(comment)
7373
if not isinstance(comment, MoreComments):
74-
queue.extend(comment.replies)
74+
queue.extend(comment.replies._comments)
7575
return comments
7676

7777
@staticmethod
7878
def _gather_more_comments(
79-
tree: list[models.MoreComments],
79+
tree: list[models.Comment | models.MoreComments],
8080
*,
81-
parent_tree: list[models.MoreComments] | None = None,
81+
parent_tree: list[models.Comment | models.MoreComments] | None = None,
8282
) -> list[MoreComments]:
8383
"""Return a list of :class:`.MoreComments` objects obtained from tree."""
84-
more_comments = []
85-
queue = [(None, x) for x in tree]
84+
more_comments: list[MoreComments] = []
85+
queue: list[tuple[models.Comment | None, models.Comment | models.MoreComments]] = [
86+
(None, x) for x in tree
87+
]
8688
while queue:
8789
parent, comment = queue.pop(0)
8890
if isinstance(comment, MoreComments):
@@ -99,7 +101,7 @@ def _gather_more_comments(
99101
def __init__(
100102
self,
101103
submission: models.Submission,
102-
comments: list[models.Comment] | None = None,
104+
comments: list[models.Comment | models.MoreComments] | None = None,
103105
) -> None:
104106
"""Initialize a :class:`.CommentForest` instance.
105107
@@ -109,10 +111,12 @@ def __init__(
109111
``None``).
110112
111113
"""
112-
self._comments = comments
114+
self._comments: list[models.Comment | models.MoreComments] = (
115+
comments if comments is not None else []
116+
)
113117
self._submission = submission
114118

115-
def _update(self, comments: list[models.Comment]) -> None:
119+
def _update(self, comments: list[models.Comment | models.MoreComments]) -> None:
116120
self._comments = comments
117121
for comment in comments:
118122
comment.submission = self._submission
@@ -184,7 +188,10 @@ def replace_more(self, *, limit: int | None = 32, threshold: int = 0) -> list[mo
184188
item._remove_from.remove(item)
185189
continue
186190

187-
new_comments = item.comments(update=False)
191+
new_comments = cast(
192+
"list[models.Comment | models.MoreComments]",
193+
item.comments(update=False),
194+
)
188195
if remaining is not None:
189196
remaining -= 1
190197

praw/models/front.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import TYPE_CHECKING
5+
from typing import TYPE_CHECKING, Any
66
from urllib.parse import urljoin
77

88
from praw.models.listing.generator import ListingGenerator
@@ -23,7 +23,7 @@ def __init__(self, reddit: praw.Reddit) -> None:
2323
super().__init__(reddit, _data=None)
2424
self._path = "/"
2525

26-
def best(self, **generator_kwargs: str | int) -> Iterator[models.Submission]:
26+
def best(self, **generator_kwargs: Any) -> Iterator[models.Submission]:
2727
"""Return a :class:`.ListingGenerator` for best items.
2828
2929
Additional keyword arguments are passed in the initialization of

praw/models/helpers.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from json import dumps
6-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING, Any, overload
77

88
from praw.const import API_PATH
99
from praw.models.base import PRAWBase
@@ -27,6 +27,12 @@ class DraftHelper(PRAWBase):
2727
2828
"""
2929

30+
@overload
31+
def __call__(self, draft_id: None = None) -> list[models.Draft]: ...
32+
33+
@overload
34+
def __call__(self, draft_id: str) -> models.Draft: ...
35+
3036
def __call__(self, draft_id: str | None = None) -> list[models.Draft] | models.Draft:
3137
"""Return a list of :class:`.Draft` instances.
3238
@@ -217,7 +223,7 @@ def generator() -> Iterator[models.LiveThread]:
217223
for position in range(0, len(ids), 100):
218224
ids_chunk = ids[position : position + 100]
219225
url = API_PATH["live_info"].format(ids=",".join(ids_chunk))
220-
params = {"limit": 100} # 25 is used if not specified
226+
params: dict[str, str | int] = {"limit": 100} # 25 is used if not specified
221227
yield from self._reddit.get(url, params=params)
222228

223229
return generator()
@@ -259,7 +265,7 @@ def create(
259265
display_name: str,
260266
icon_name: str | None = None,
261267
key_color: str | None = None,
262-
subreddits: str | models.Subreddit,
268+
subreddits: list[str | models.Subreddit],
263269
visibility: str = "private",
264270
weighting_scheme: str = "classic",
265271
) -> models.Multireddit:
@@ -317,7 +323,7 @@ def create(
317323
subreddit_type: str = "public",
318324
title: str | None = None,
319325
wikimode: str = "disabled",
320-
**other_settings: str | None,
326+
**other_settings: Any,
321327
) -> models.Subreddit:
322328
"""Create a new :class:`.Subreddit`.
323329

0 commit comments

Comments
 (0)