Skip to content

Commit 12560a7

Browse files
authored
chore: split AsyncEngine into _engine.py (#1218)
1 parent 6d83f99 commit 12560a7

7 files changed

Lines changed: 655 additions & 578 deletions

File tree

src/zeroconf/_core.py

Lines changed: 4 additions & 331 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,15 @@
2121
"""
2222

2323
import asyncio
24-
import itertools
2524
import logging
26-
import random
27-
import socket
2825
import sys
2926
import threading
30-
from types import TracebackType # noqa # used in type hints
31-
from typing import Any, Awaitable, Dict, List, Optional, Tuple, Type, Union, cast
27+
from types import TracebackType
28+
from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union
3229

3330
from ._cache import DNSCache
3431
from ._dns import DNSQuestion, DNSQuestionType
32+
from ._engine import AsyncEngine, _WrappedTransport
3533
from ._exceptions import NonUniqueNameException, NotRunningException
3634
from ._handlers import (
3735
MulticastOutgoingQueue,
@@ -48,7 +46,7 @@
4846
from ._services.browser import ServiceBrowser
4947
from ._services.info import ServiceInfo, instance_name_from_service_info
5048
from ._services.registry import ServiceRegistry
51-
from ._updates import RecordUpdate, RecordUpdateListener
49+
from ._updates import RecordUpdateListener
5250
from ._utils.asyncio import (
5351
await_awaitable,
5452
get_running_loop,
@@ -67,11 +65,9 @@
6765
)
6866
from ._utils.time import current_time_millis, millis_to_seconds
6967
from .const import (
70-
_CACHE_CLEANUP_INTERVAL,
7168
_CHECK_TIME,
7269
_CLASS_IN,
7370
_CLASS_UNIQUE,
74-
_DUPLICATE_PACKET_SUPPRESSION_INTERVAL,
7571
_FLAGS_AA,
7672
_FLAGS_QR_QUERY,
7773
_FLAGS_QR_RESPONSE,
@@ -86,7 +82,6 @@
8682
_UNREGISTER_TIME,
8783
)
8884

89-
_TC_DELAY_RANDOM_INTERVAL = (400, 500)
9085
# The maximum amont of time to delay a multicast
9186
# response in order to aggregate answers
9287
_AGGREGATION_DELAY = 500 # ms
@@ -102,331 +97,9 @@
10297
# 3000ms
10398
_PROTECTED_AGGREGATION_DELAY = 200 # ms
10499

105-
_CLOSE_TIMEOUT = 3000 # ms
106100
_REGISTER_BROADCASTS = 3
107101

108102

109-
class _WrappedTransport:
110-
"""A wrapper for transports."""
111-
112-
__slots__ = (
113-
'transport',
114-
'is_ipv6',
115-
'sock',
116-
'fileno',
117-
'sock_name',
118-
)
119-
120-
def __init__(
121-
self,
122-
transport: asyncio.DatagramTransport,
123-
is_ipv6: bool,
124-
sock: socket.socket,
125-
fileno: int,
126-
sock_name: Any,
127-
) -> None:
128-
"""Initialize the wrapped transport.
129-
130-
These attributes are used when sending packets.
131-
"""
132-
self.transport = transport
133-
self.is_ipv6 = is_ipv6
134-
self.sock = sock
135-
self.fileno = fileno
136-
self.sock_name = sock_name
137-
138-
139-
def _make_wrapped_transport(transport: asyncio.DatagramTransport) -> _WrappedTransport:
140-
"""Make a wrapped transport."""
141-
sock: socket.socket = transport.get_extra_info('socket')
142-
return _WrappedTransport(
143-
transport=transport,
144-
is_ipv6=sock.family == socket.AF_INET6,
145-
sock=sock,
146-
fileno=sock.fileno(),
147-
sock_name=sock.getsockname(),
148-
)
149-
150-
151-
class AsyncEngine:
152-
"""An engine wraps sockets in the event loop."""
153-
154-
__slots__ = (
155-
'loop',
156-
'zc',
157-
'protocols',
158-
'readers',
159-
'senders',
160-
'running_event',
161-
'_listen_socket',
162-
'_respond_sockets',
163-
'_cleanup_timer',
164-
)
165-
166-
def __init__(
167-
self,
168-
zeroconf: 'Zeroconf',
169-
listen_socket: Optional[socket.socket],
170-
respond_sockets: List[socket.socket],
171-
) -> None:
172-
self.loop: Optional[asyncio.AbstractEventLoop] = None
173-
self.zc = zeroconf
174-
self.protocols: List[AsyncListener] = []
175-
self.readers: List[_WrappedTransport] = []
176-
self.senders: List[_WrappedTransport] = []
177-
self.running_event: Optional[asyncio.Event] = None
178-
self._listen_socket = listen_socket
179-
self._respond_sockets = respond_sockets
180-
self._cleanup_timer: Optional[asyncio.TimerHandle] = None
181-
182-
def setup(self, loop: asyncio.AbstractEventLoop, loop_thread_ready: Optional[threading.Event]) -> None:
183-
"""Set up the instance."""
184-
self.loop = loop
185-
self.running_event = asyncio.Event()
186-
self.loop.create_task(self._async_setup(loop_thread_ready))
187-
188-
async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None:
189-
"""Set up the instance."""
190-
assert self.loop is not None
191-
self._cleanup_timer = self.loop.call_later(_CACHE_CLEANUP_INTERVAL, self._async_cache_cleanup)
192-
await self._async_create_endpoints()
193-
assert self.running_event is not None
194-
self.running_event.set()
195-
if loop_thread_ready:
196-
loop_thread_ready.set()
197-
198-
async def _async_create_endpoints(self) -> None:
199-
"""Create endpoints to send and receive."""
200-
assert self.loop is not None
201-
loop = self.loop
202-
reader_sockets = []
203-
sender_sockets = []
204-
if self._listen_socket:
205-
reader_sockets.append(self._listen_socket)
206-
for s in self._respond_sockets:
207-
if s not in reader_sockets:
208-
reader_sockets.append(s)
209-
sender_sockets.append(s)
210-
211-
for s in reader_sockets:
212-
transport, protocol = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s)
213-
self.protocols.append(cast(AsyncListener, protocol))
214-
self.readers.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport)))
215-
if s in sender_sockets:
216-
self.senders.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport)))
217-
218-
def _async_cache_cleanup(self) -> None:
219-
"""Periodic cache cleanup."""
220-
now = current_time_millis()
221-
self.zc.question_history.async_expire(now)
222-
self.zc.record_manager.async_updates(
223-
now, [RecordUpdate(record, record) for record in self.zc.cache.async_expire(now)]
224-
)
225-
self.zc.record_manager.async_updates_complete(False)
226-
assert self.loop is not None
227-
self._cleanup_timer = self.loop.call_later(_CACHE_CLEANUP_INTERVAL, self._async_cache_cleanup)
228-
229-
async def _async_close(self) -> None:
230-
"""Cancel and wait for the cleanup task to finish."""
231-
self._async_shutdown()
232-
await asyncio.sleep(0) # flush out any call soons
233-
assert self._cleanup_timer is not None
234-
self._cleanup_timer.cancel()
235-
236-
def _async_shutdown(self) -> None:
237-
"""Shutdown transports and sockets."""
238-
assert self.running_event is not None
239-
self.running_event.clear()
240-
for wrapped_transport in itertools.chain(self.senders, self.readers):
241-
wrapped_transport.transport.close()
242-
243-
def close(self) -> None:
244-
"""Close from sync context.
245-
246-
While it is not expected during normal operation,
247-
this function may raise EventLoopBlocked if the underlying
248-
call to `_async_close` cannot be completed.
249-
"""
250-
assert self.loop is not None
251-
# Guard against Zeroconf.close() being called from the eventloop
252-
if get_running_loop() == self.loop:
253-
self._async_shutdown()
254-
return
255-
if not self.loop.is_running():
256-
return
257-
run_coro_with_timeout(self._async_close(), self.loop, _CLOSE_TIMEOUT)
258-
259-
260-
class AsyncListener(asyncio.Protocol, QuietLogger):
261-
262-
"""A Listener is used by this module to listen on the multicast
263-
group to which DNS messages are sent, allowing the implementation
264-
to cache information as it arrives.
265-
266-
It requires registration with an Engine object in order to have
267-
the read() method called when a socket is available for reading."""
268-
269-
__slots__ = ('zc', 'data', 'last_time', 'transport', 'sock_description', '_deferred', '_timers')
270-
271-
def __init__(self, zc: 'Zeroconf') -> None:
272-
self.zc = zc
273-
self.data: Optional[bytes] = None
274-
self.last_time: float = 0
275-
self.last_message: Optional[DNSIncoming] = None
276-
self.transport: Optional[_WrappedTransport] = None
277-
self.sock_description: Optional[str] = None
278-
self._deferred: Dict[str, List[DNSIncoming]] = {}
279-
self._timers: Dict[str, asyncio.TimerHandle] = {}
280-
super().__init__()
281-
282-
def datagram_received(
283-
self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]]
284-
) -> None:
285-
assert self.transport is not None
286-
data_len = len(data)
287-
debug = log.isEnabledFor(logging.DEBUG)
288-
289-
if data_len > _MAX_MSG_ABSOLUTE:
290-
# Guard against oversized packets to ensure bad implementations cannot overwhelm
291-
# the system.
292-
if debug:
293-
log.debug(
294-
"Discarding incoming packet with length %s, which is larger "
295-
"than the absolute maximum size of %s",
296-
data_len,
297-
_MAX_MSG_ABSOLUTE,
298-
)
299-
return
300-
301-
now = current_time_millis()
302-
if (
303-
self.data == data
304-
and (now - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL) < self.last_time
305-
and self.last_message is not None
306-
and not self.last_message.has_qu_question()
307-
):
308-
# Guard against duplicate packets
309-
if debug:
310-
log.debug(
311-
'Ignoring duplicate message with no unicast questions received from %s [socket %s] (%d bytes) as [%r]',
312-
addrs,
313-
self.sock_description,
314-
data_len,
315-
data,
316-
)
317-
return
318-
319-
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = ()
320-
if len(addrs) == 2:
321-
# https://github.com/python/mypy/issues/1178
322-
addr, port = addrs # type: ignore
323-
scope = None
324-
else:
325-
# https://github.com/python/mypy/issues/1178
326-
addr, port, flow, scope = addrs # type: ignore
327-
if debug:
328-
log.debug('IPv6 scope_id %d associated to the receiving interface', scope)
329-
v6_flow_scope = (flow, scope)
330-
331-
msg = DNSIncoming(data, (addr, port), scope, now)
332-
self.data = data
333-
self.last_time = now
334-
self.last_message = msg
335-
if msg.valid:
336-
if debug:
337-
log.debug(
338-
'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]',
339-
addr,
340-
port,
341-
self.sock_description,
342-
msg,
343-
data_len,
344-
data,
345-
)
346-
else:
347-
if debug:
348-
log.debug(
349-
'Received from %r:%r [socket %s]: (%d bytes) [%r]',
350-
addr,
351-
port,
352-
self.sock_description,
353-
data_len,
354-
data,
355-
)
356-
return
357-
358-
if not msg.is_query():
359-
self.zc.handle_response(msg)
360-
return
361-
362-
self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope)
363-
364-
def handle_query_or_defer(
365-
self,
366-
msg: DNSIncoming,
367-
addr: str,
368-
port: int,
369-
transport: _WrappedTransport,
370-
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
371-
) -> None:
372-
"""Deal with incoming query packets. Provides a response if
373-
possible."""
374-
if not msg.truncated:
375-
self._respond_query(msg, addr, port, transport, v6_flow_scope)
376-
return
377-
378-
deferred = self._deferred.setdefault(addr, [])
379-
# If we get the same packet we ignore it
380-
for incoming in reversed(deferred):
381-
if incoming.data == msg.data:
382-
return
383-
deferred.append(msg)
384-
delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL))
385-
assert self.zc.loop is not None
386-
self._cancel_any_timers_for_addr(addr)
387-
self._timers[addr] = self.zc.loop.call_later(
388-
delay, self._respond_query, None, addr, port, transport, v6_flow_scope
389-
)
390-
391-
def _cancel_any_timers_for_addr(self, addr: str) -> None:
392-
"""Cancel any future truncated packet timers for the address."""
393-
if addr in self._timers:
394-
self._timers.pop(addr).cancel()
395-
396-
def _respond_query(
397-
self,
398-
msg: Optional[DNSIncoming],
399-
addr: str,
400-
port: int,
401-
transport: _WrappedTransport,
402-
v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (),
403-
) -> None:
404-
"""Respond to a query and reassemble any truncated deferred packets."""
405-
self._cancel_any_timers_for_addr(addr)
406-
packets = self._deferred.pop(addr, [])
407-
if msg:
408-
packets.append(msg)
409-
410-
self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope)
411-
412-
def error_received(self, exc: Exception) -> None:
413-
"""Likely socket closed or IPv6."""
414-
# We preformat the message string with the socket as we want
415-
# log_exception_once to log a warrning message once PER EACH
416-
# different socket in case there are problems with multiple
417-
# sockets
418-
msg_str = f"Error with socket {self.sock_description}): %s"
419-
self.log_exception_once(exc, msg_str, exc)
420-
421-
def connection_made(self, transport: asyncio.BaseTransport) -> None:
422-
wrapped_transport = _make_wrapped_transport(cast(asyncio.DatagramTransport, transport))
423-
self.transport = wrapped_transport
424-
self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})"
425-
426-
def connection_lost(self, exc: Optional[Exception]) -> None:
427-
"""Handle connection lost."""
428-
429-
430103
def async_send_with_transport(
431104
log_debug: bool,
432105
transport: _WrappedTransport,

0 commit comments

Comments
 (0)