-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
91 lines (68 loc) · 3.4 KB
/
main.py
File metadata and controls
91 lines (68 loc) · 3.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import dataclasses
import typing
from collections.abc import Awaitable, Callable
import faststream
import modern_di
from faststream.asgi import AsgiFastStream
from faststream.types import DecodedMessage
from modern_di import Container, Scope, providers
# Covariant type variable for the value a provider or type resolves to; used by Dependency and FromDI.
T_co = typing.TypeVar("T_co", covariant=True)
# Captures the extra positional/keyword arguments forwarded from _DIMiddlewareFactory to _DiMiddleware.
P = typing.ParamSpec("P")
# REQUEST-scoped context provider keyed on faststream.StreamMessage. setup_di registers it on the
# container, and _DiMiddleware.consume_scope supplies the incoming message as the matching context value.
faststream_message_provider = providers.ContextProvider(scope=Scope.REQUEST, context_type=faststream.StreamMessage)
class _DIMiddlewareFactory:
    """Callable that builds a ``_DiMiddleware`` bound to one DI container.

    ``setup_di`` registers an instance of this class as a broker middleware.
    Each call constructs a new ``_DiMiddleware`` and forwards the call
    arguments to it unchanged, with the stored container prepended.
    """

    __slots__ = ("di_container",)

    def __init__(self, di_container: Container) -> None:
        # Container handed to every middleware this factory creates.
        self.di_container = di_container

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> "_DiMiddleware[P]":
        """Create a middleware instance for the stored container."""
        container = self.di_container
        return _DiMiddleware(container, *args, **kwargs)
class _DiMiddleware(faststream.BaseMiddleware, typing.Generic[P]):
    """FastStream middleware that gives each consumed message its own REQUEST-scoped container."""

    def __init__(self, di_container: Container, *args: P.args, **kwargs: P.kwargs) -> None:
        """Remember the parent container, then pass the remaining arguments to ``BaseMiddleware``."""
        self.di_container = di_container
        # BaseMiddleware.__init__ expects (msg, /, *, context: ContextRepo); ParamSpec forwarding can't prove that.
        super().__init__(*args, **kwargs)  # ty: ignore[invalid-argument-type]

    @property
    def faststream_context(self) -> faststream.ContextRepo:
        """The middleware's ``context`` attribute, annotated as a ``ContextRepo``."""
        return self.context

    async def consume_scope(
        self,
        call_next: Callable[[typing.Any], Awaitable[typing.Any]],
        msg: faststream.StreamMessage[typing.Any],
    ) -> typing.AsyncIterator[DecodedMessage]:
        """Run the rest of the consume chain with a per-message child container.

        A child container is built in ``Scope.REQUEST`` with ``msg`` supplied
        as the ``faststream.StreamMessage`` context value. While ``call_next``
        runs, that container is available in the FastStream context under the
        key ``"request_container"``.

        Returns:
            Whatever ``call_next(msg)`` produced, cast to the annotated type
            without a runtime check.
        """
        child_container = self.di_container.build_child_container(
            scope=modern_di.Scope.REQUEST, context={faststream.StreamMessage: msg}
        )
        try:
            with self.faststream_context.scope("request_container", child_container):
                handler_result = await call_next(msg)
                return typing.cast(typing.AsyncIterator[DecodedMessage], handler_result)
        finally:
            # Close the child container whether the handler returned or raised.
            await child_container.close_async()
def fetch_di_container(app_: faststream.FastStream | AsgiFastStream) -> Container:
    """Return the object stored under ``"di_container"`` in the application context.

    The value is cast to ``Container`` for type checkers only; nothing is
    validated at runtime, so the result is whatever ``app_.context.get``
    returns for that key.
    """
    stored = app_.context.get("di_container")
    return typing.cast(Container, stored)
def setup_di(
    app: faststream.FastStream | AsgiFastStream,
    container: Container,
) -> Container:
    """Wire ``container`` into a FastStream application.

    Steps, in order: register ``faststream_message_provider`` on the
    container, publish the container in the app's global context under
    ``"di_container"``, schedule ``container.close_async`` to run after
    shutdown, and add the DI middleware factory to the broker.

    Args:
        app: Application whose broker receives the DI middleware.
        container: Root container to attach.

    Returns:
        The same ``container`` that was passed in.

    Raises:
        RuntimeError: If ``app.broker`` is falsy.
    """
    broker = app.broker
    if not broker:
        error_text = "Broker must be defined to setup DI"
        raise RuntimeError(error_text)

    container.providers_registry.add_providers(faststream_message_provider)
    app.context.set_global("di_container", container)
    app.after_shutdown(container.close_async)

    middleware_factory = _DIMiddlewareFactory(container)
    # _DIMiddlewareFactory.__call__ ParamSpec doesn't structurally match BrokerMiddleware[Any, Any].
    broker.add_middleware(middleware_factory)  # ty: ignore[invalid-argument-type]
    return container
@dataclasses.dataclass(frozen=True, slots=True)
class Dependency(typing.Generic[T_co]):
    """Async callable that resolves one dependency from the per-message container.

    The container is read from the FastStream context under the key
    ``"request_container"``.
    """

    # Either a provider instance or a plain type to resolve.
    dependency: providers.AbstractProvider[T_co] | type[T_co]

    async def __call__(self, context: faststream.ContextRepo) -> T_co:
        """Resolve ``self.dependency`` using the container found in ``context``.

        Provider instances go through ``resolve_provider``; anything else is
        passed to ``resolve`` as ``dependency_type``.
        """
        container: Container = context.get("request_container")
        target = self.dependency
        if not isinstance(target, providers.AbstractProvider):
            return container.resolve(dependency_type=target)
        return container.resolve_provider(target)
def FromDI(  # noqa: N802
    dependency: providers.AbstractProvider[T_co] | type[T_co], *, use_cache: bool = True, cast: bool = False
) -> T_co:
    """Build a ``faststream.Depends`` marker that resolves ``dependency`` through ``Dependency``.

    Args:
        dependency: Provider instance or type to resolve.
        use_cache: Forwarded unchanged to ``faststream.Depends``.
        cast: Forwarded unchanged to ``faststream.Depends``.

    Returns:
        The ``Depends`` object, cast to ``T_co`` for type checkers only; the
        runtime value is the marker itself.
    """
    resolver = Dependency(dependency)
    marker = faststream.Depends(dependency=resolver, use_cache=use_cache, cast=cast)
    return typing.cast(T_co, marker)