|
21 | 21 | """ |
22 | 22 |
|
23 | 23 | import asyncio |
24 | | -import itertools |
25 | 24 | import logging |
26 | | -import random |
27 | | -import socket |
28 | 25 | import sys |
29 | 26 | import threading |
30 | | -from types import TracebackType # noqa # used in type hints |
31 | | -from typing import Any, Awaitable, Dict, List, Optional, Tuple, Type, Union, cast |
| 27 | +from types import TracebackType |
| 28 | +from typing import Awaitable, Dict, List, Optional, Tuple, Type, Union |
32 | 29 |
|
33 | 30 | from ._cache import DNSCache |
34 | 31 | from ._dns import DNSQuestion, DNSQuestionType |
| 32 | +from ._engine import AsyncEngine, _WrappedTransport |
35 | 33 | from ._exceptions import NonUniqueNameException, NotRunningException |
36 | 34 | from ._handlers import ( |
37 | 35 | MulticastOutgoingQueue, |
|
48 | 46 | from ._services.browser import ServiceBrowser |
49 | 47 | from ._services.info import ServiceInfo, instance_name_from_service_info |
50 | 48 | from ._services.registry import ServiceRegistry |
51 | | -from ._updates import RecordUpdate, RecordUpdateListener |
| 49 | +from ._updates import RecordUpdateListener |
52 | 50 | from ._utils.asyncio import ( |
53 | 51 | await_awaitable, |
54 | 52 | get_running_loop, |
|
67 | 65 | ) |
68 | 66 | from ._utils.time import current_time_millis, millis_to_seconds |
69 | 67 | from .const import ( |
70 | | - _CACHE_CLEANUP_INTERVAL, |
71 | 68 | _CHECK_TIME, |
72 | 69 | _CLASS_IN, |
73 | 70 | _CLASS_UNIQUE, |
74 | | - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL, |
75 | 71 | _FLAGS_AA, |
76 | 72 | _FLAGS_QR_QUERY, |
77 | 73 | _FLAGS_QR_RESPONSE, |
|
86 | 82 | _UNREGISTER_TIME, |
87 | 83 | ) |
88 | 84 |
|
89 | | -_TC_DELAY_RANDOM_INTERVAL = (400, 500) |
90 | 85 | # The maximum amont of time to delay a multicast |
91 | 86 | # response in order to aggregate answers |
92 | 87 | _AGGREGATION_DELAY = 500 # ms |
|
102 | 97 | # 3000ms |
103 | 98 | _PROTECTED_AGGREGATION_DELAY = 200 # ms |
104 | 99 |
|
105 | | -_CLOSE_TIMEOUT = 3000 # ms |
106 | 100 | _REGISTER_BROADCASTS = 3 |
107 | 101 |
|
108 | 102 |
|
109 | | -class _WrappedTransport: |
110 | | - """A wrapper for transports.""" |
111 | | - |
112 | | - __slots__ = ( |
113 | | - 'transport', |
114 | | - 'is_ipv6', |
115 | | - 'sock', |
116 | | - 'fileno', |
117 | | - 'sock_name', |
118 | | - ) |
119 | | - |
120 | | - def __init__( |
121 | | - self, |
122 | | - transport: asyncio.DatagramTransport, |
123 | | - is_ipv6: bool, |
124 | | - sock: socket.socket, |
125 | | - fileno: int, |
126 | | - sock_name: Any, |
127 | | - ) -> None: |
128 | | - """Initialize the wrapped transport. |
129 | | -
|
130 | | - These attributes are used when sending packets. |
131 | | - """ |
132 | | - self.transport = transport |
133 | | - self.is_ipv6 = is_ipv6 |
134 | | - self.sock = sock |
135 | | - self.fileno = fileno |
136 | | - self.sock_name = sock_name |
137 | | - |
138 | | - |
139 | | -def _make_wrapped_transport(transport: asyncio.DatagramTransport) -> _WrappedTransport: |
140 | | - """Make a wrapped transport.""" |
141 | | - sock: socket.socket = transport.get_extra_info('socket') |
142 | | - return _WrappedTransport( |
143 | | - transport=transport, |
144 | | - is_ipv6=sock.family == socket.AF_INET6, |
145 | | - sock=sock, |
146 | | - fileno=sock.fileno(), |
147 | | - sock_name=sock.getsockname(), |
148 | | - ) |
149 | | - |
150 | | - |
151 | | -class AsyncEngine: |
152 | | - """An engine wraps sockets in the event loop.""" |
153 | | - |
154 | | - __slots__ = ( |
155 | | - 'loop', |
156 | | - 'zc', |
157 | | - 'protocols', |
158 | | - 'readers', |
159 | | - 'senders', |
160 | | - 'running_event', |
161 | | - '_listen_socket', |
162 | | - '_respond_sockets', |
163 | | - '_cleanup_timer', |
164 | | - ) |
165 | | - |
166 | | - def __init__( |
167 | | - self, |
168 | | - zeroconf: 'Zeroconf', |
169 | | - listen_socket: Optional[socket.socket], |
170 | | - respond_sockets: List[socket.socket], |
171 | | - ) -> None: |
172 | | - self.loop: Optional[asyncio.AbstractEventLoop] = None |
173 | | - self.zc = zeroconf |
174 | | - self.protocols: List[AsyncListener] = [] |
175 | | - self.readers: List[_WrappedTransport] = [] |
176 | | - self.senders: List[_WrappedTransport] = [] |
177 | | - self.running_event: Optional[asyncio.Event] = None |
178 | | - self._listen_socket = listen_socket |
179 | | - self._respond_sockets = respond_sockets |
180 | | - self._cleanup_timer: Optional[asyncio.TimerHandle] = None |
181 | | - |
182 | | - def setup(self, loop: asyncio.AbstractEventLoop, loop_thread_ready: Optional[threading.Event]) -> None: |
183 | | - """Set up the instance.""" |
184 | | - self.loop = loop |
185 | | - self.running_event = asyncio.Event() |
186 | | - self.loop.create_task(self._async_setup(loop_thread_ready)) |
187 | | - |
188 | | - async def _async_setup(self, loop_thread_ready: Optional[threading.Event]) -> None: |
189 | | - """Set up the instance.""" |
190 | | - assert self.loop is not None |
191 | | - self._cleanup_timer = self.loop.call_later(_CACHE_CLEANUP_INTERVAL, self._async_cache_cleanup) |
192 | | - await self._async_create_endpoints() |
193 | | - assert self.running_event is not None |
194 | | - self.running_event.set() |
195 | | - if loop_thread_ready: |
196 | | - loop_thread_ready.set() |
197 | | - |
198 | | - async def _async_create_endpoints(self) -> None: |
199 | | - """Create endpoints to send and receive.""" |
200 | | - assert self.loop is not None |
201 | | - loop = self.loop |
202 | | - reader_sockets = [] |
203 | | - sender_sockets = [] |
204 | | - if self._listen_socket: |
205 | | - reader_sockets.append(self._listen_socket) |
206 | | - for s in self._respond_sockets: |
207 | | - if s not in reader_sockets: |
208 | | - reader_sockets.append(s) |
209 | | - sender_sockets.append(s) |
210 | | - |
211 | | - for s in reader_sockets: |
212 | | - transport, protocol = await loop.create_datagram_endpoint(lambda: AsyncListener(self.zc), sock=s) |
213 | | - self.protocols.append(cast(AsyncListener, protocol)) |
214 | | - self.readers.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) |
215 | | - if s in sender_sockets: |
216 | | - self.senders.append(_make_wrapped_transport(cast(asyncio.DatagramTransport, transport))) |
217 | | - |
218 | | - def _async_cache_cleanup(self) -> None: |
219 | | - """Periodic cache cleanup.""" |
220 | | - now = current_time_millis() |
221 | | - self.zc.question_history.async_expire(now) |
222 | | - self.zc.record_manager.async_updates( |
223 | | - now, [RecordUpdate(record, record) for record in self.zc.cache.async_expire(now)] |
224 | | - ) |
225 | | - self.zc.record_manager.async_updates_complete(False) |
226 | | - assert self.loop is not None |
227 | | - self._cleanup_timer = self.loop.call_later(_CACHE_CLEANUP_INTERVAL, self._async_cache_cleanup) |
228 | | - |
229 | | - async def _async_close(self) -> None: |
230 | | - """Cancel and wait for the cleanup task to finish.""" |
231 | | - self._async_shutdown() |
232 | | - await asyncio.sleep(0) # flush out any call soons |
233 | | - assert self._cleanup_timer is not None |
234 | | - self._cleanup_timer.cancel() |
235 | | - |
236 | | - def _async_shutdown(self) -> None: |
237 | | - """Shutdown transports and sockets.""" |
238 | | - assert self.running_event is not None |
239 | | - self.running_event.clear() |
240 | | - for wrapped_transport in itertools.chain(self.senders, self.readers): |
241 | | - wrapped_transport.transport.close() |
242 | | - |
243 | | - def close(self) -> None: |
244 | | - """Close from sync context. |
245 | | -
|
246 | | - While it is not expected during normal operation, |
247 | | - this function may raise EventLoopBlocked if the underlying |
248 | | - call to `_async_close` cannot be completed. |
249 | | - """ |
250 | | - assert self.loop is not None |
251 | | - # Guard against Zeroconf.close() being called from the eventloop |
252 | | - if get_running_loop() == self.loop: |
253 | | - self._async_shutdown() |
254 | | - return |
255 | | - if not self.loop.is_running(): |
256 | | - return |
257 | | - run_coro_with_timeout(self._async_close(), self.loop, _CLOSE_TIMEOUT) |
258 | | - |
259 | | - |
260 | | -class AsyncListener(asyncio.Protocol, QuietLogger): |
261 | | - |
262 | | - """A Listener is used by this module to listen on the multicast |
263 | | - group to which DNS messages are sent, allowing the implementation |
264 | | - to cache information as it arrives. |
265 | | -
|
266 | | - It requires registration with an Engine object in order to have |
267 | | - the read() method called when a socket is available for reading.""" |
268 | | - |
269 | | - __slots__ = ('zc', 'data', 'last_time', 'transport', 'sock_description', '_deferred', '_timers') |
270 | | - |
271 | | - def __init__(self, zc: 'Zeroconf') -> None: |
272 | | - self.zc = zc |
273 | | - self.data: Optional[bytes] = None |
274 | | - self.last_time: float = 0 |
275 | | - self.last_message: Optional[DNSIncoming] = None |
276 | | - self.transport: Optional[_WrappedTransport] = None |
277 | | - self.sock_description: Optional[str] = None |
278 | | - self._deferred: Dict[str, List[DNSIncoming]] = {} |
279 | | - self._timers: Dict[str, asyncio.TimerHandle] = {} |
280 | | - super().__init__() |
281 | | - |
282 | | - def datagram_received( |
283 | | - self, data: bytes, addrs: Union[Tuple[str, int], Tuple[str, int, int, int]] |
284 | | - ) -> None: |
285 | | - assert self.transport is not None |
286 | | - data_len = len(data) |
287 | | - debug = log.isEnabledFor(logging.DEBUG) |
288 | | - |
289 | | - if data_len > _MAX_MSG_ABSOLUTE: |
290 | | - # Guard against oversized packets to ensure bad implementations cannot overwhelm |
291 | | - # the system. |
292 | | - if debug: |
293 | | - log.debug( |
294 | | - "Discarding incoming packet with length %s, which is larger " |
295 | | - "than the absolute maximum size of %s", |
296 | | - data_len, |
297 | | - _MAX_MSG_ABSOLUTE, |
298 | | - ) |
299 | | - return |
300 | | - |
301 | | - now = current_time_millis() |
302 | | - if ( |
303 | | - self.data == data |
304 | | - and (now - _DUPLICATE_PACKET_SUPPRESSION_INTERVAL) < self.last_time |
305 | | - and self.last_message is not None |
306 | | - and not self.last_message.has_qu_question() |
307 | | - ): |
308 | | - # Guard against duplicate packets |
309 | | - if debug: |
310 | | - log.debug( |
311 | | - 'Ignoring duplicate message with no unicast questions received from %s [socket %s] (%d bytes) as [%r]', |
312 | | - addrs, |
313 | | - self.sock_description, |
314 | | - data_len, |
315 | | - data, |
316 | | - ) |
317 | | - return |
318 | | - |
319 | | - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = () |
320 | | - if len(addrs) == 2: |
321 | | - # https://github.com/python/mypy/issues/1178 |
322 | | - addr, port = addrs # type: ignore |
323 | | - scope = None |
324 | | - else: |
325 | | - # https://github.com/python/mypy/issues/1178 |
326 | | - addr, port, flow, scope = addrs # type: ignore |
327 | | - if debug: |
328 | | - log.debug('IPv6 scope_id %d associated to the receiving interface', scope) |
329 | | - v6_flow_scope = (flow, scope) |
330 | | - |
331 | | - msg = DNSIncoming(data, (addr, port), scope, now) |
332 | | - self.data = data |
333 | | - self.last_time = now |
334 | | - self.last_message = msg |
335 | | - if msg.valid: |
336 | | - if debug: |
337 | | - log.debug( |
338 | | - 'Received from %r:%r [socket %s]: %r (%d bytes) as [%r]', |
339 | | - addr, |
340 | | - port, |
341 | | - self.sock_description, |
342 | | - msg, |
343 | | - data_len, |
344 | | - data, |
345 | | - ) |
346 | | - else: |
347 | | - if debug: |
348 | | - log.debug( |
349 | | - 'Received from %r:%r [socket %s]: (%d bytes) [%r]', |
350 | | - addr, |
351 | | - port, |
352 | | - self.sock_description, |
353 | | - data_len, |
354 | | - data, |
355 | | - ) |
356 | | - return |
357 | | - |
358 | | - if not msg.is_query(): |
359 | | - self.zc.handle_response(msg) |
360 | | - return |
361 | | - |
362 | | - self.handle_query_or_defer(msg, addr, port, self.transport, v6_flow_scope) |
363 | | - |
364 | | - def handle_query_or_defer( |
365 | | - self, |
366 | | - msg: DNSIncoming, |
367 | | - addr: str, |
368 | | - port: int, |
369 | | - transport: _WrappedTransport, |
370 | | - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), |
371 | | - ) -> None: |
372 | | - """Deal with incoming query packets. Provides a response if |
373 | | - possible.""" |
374 | | - if not msg.truncated: |
375 | | - self._respond_query(msg, addr, port, transport, v6_flow_scope) |
376 | | - return |
377 | | - |
378 | | - deferred = self._deferred.setdefault(addr, []) |
379 | | - # If we get the same packet we ignore it |
380 | | - for incoming in reversed(deferred): |
381 | | - if incoming.data == msg.data: |
382 | | - return |
383 | | - deferred.append(msg) |
384 | | - delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) |
385 | | - assert self.zc.loop is not None |
386 | | - self._cancel_any_timers_for_addr(addr) |
387 | | - self._timers[addr] = self.zc.loop.call_later( |
388 | | - delay, self._respond_query, None, addr, port, transport, v6_flow_scope |
389 | | - ) |
390 | | - |
391 | | - def _cancel_any_timers_for_addr(self, addr: str) -> None: |
392 | | - """Cancel any future truncated packet timers for the address.""" |
393 | | - if addr in self._timers: |
394 | | - self._timers.pop(addr).cancel() |
395 | | - |
396 | | - def _respond_query( |
397 | | - self, |
398 | | - msg: Optional[DNSIncoming], |
399 | | - addr: str, |
400 | | - port: int, |
401 | | - transport: _WrappedTransport, |
402 | | - v6_flow_scope: Union[Tuple[()], Tuple[int, int]] = (), |
403 | | - ) -> None: |
404 | | - """Respond to a query and reassemble any truncated deferred packets.""" |
405 | | - self._cancel_any_timers_for_addr(addr) |
406 | | - packets = self._deferred.pop(addr, []) |
407 | | - if msg: |
408 | | - packets.append(msg) |
409 | | - |
410 | | - self.zc.handle_assembled_query(packets, addr, port, transport, v6_flow_scope) |
411 | | - |
412 | | - def error_received(self, exc: Exception) -> None: |
413 | | - """Likely socket closed or IPv6.""" |
414 | | - # We preformat the message string with the socket as we want |
415 | | - # log_exception_once to log a warrning message once PER EACH |
416 | | - # different socket in case there are problems with multiple |
417 | | - # sockets |
418 | | - msg_str = f"Error with socket {self.sock_description}): %s" |
419 | | - self.log_exception_once(exc, msg_str, exc) |
420 | | - |
421 | | - def connection_made(self, transport: asyncio.BaseTransport) -> None: |
422 | | - wrapped_transport = _make_wrapped_transport(cast(asyncio.DatagramTransport, transport)) |
423 | | - self.transport = wrapped_transport |
424 | | - self.sock_description = f"{wrapped_transport.fileno} ({wrapped_transport.sock_name})" |
425 | | - |
426 | | - def connection_lost(self, exc: Optional[Exception]) -> None: |
427 | | - """Handle connection lost.""" |
428 | | - |
429 | | - |
430 | 103 | def async_send_with_transport( |
431 | 104 | log_debug: bool, |
432 | 105 | transport: _WrappedTransport, |
|
0 commit comments