Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 76 additions & 21 deletions zeroconf/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,47 @@
)


def _millis_to_seconds(ms: float) -> float:
"""Convert miliseconds to seconds."""
return ms / 1000


def _get_best_available_queue() -> queue.Queue:
"""Create the best available queue type."""
if hasattr(queue, "SimpleQueue"):
return queue.SimpleQueue() # type: ignore # pylint: disable=all
return queue.Queue()


# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed
async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None:
"""Wait for a condition or timeout."""
loop = asyncio.get_event_loop()
future = loop.create_future()

def _handle_timeout() -> None:
if not future.done():
future.set_result(None)

timer_handle = loop.call_later(timeout, _handle_timeout)
condition_wait = loop.create_task(condition.wait())

def _handle_wait_complete(_: asyncio.Task) -> None:
if not future.done():
future.set_result(None)

condition_wait.add_done_callback(_handle_wait_complete)

try:
await future
finally:
timer_handle.cancel()
if not condition_wait.done():
condition_wait.cancel()
with contextlib.suppress(asyncio.CancelledError):
await condition_wait


class _AsyncSender(threading.Thread):
"""A thread to handle sending DNSOutgoing for asyncio."""

Expand Down Expand Up @@ -86,19 +120,19 @@ def run(self) -> None:
class AsyncNotifyListener(NotifyListener):
"""A NotifyListener that async code can use to wait for events."""

def __init__(self) -> None:
def __init__(self, aiozc: 'AsyncZeroconf') -> None:
"""Create an event for async listeners to wait for."""
self.event = asyncio.Event()
self.aiozc = aiozc
self.loop = asyncio.get_event_loop()

def notify_all(self) -> None:
"""Schedule an async_notify_all."""
self.loop.call_soon_threadsafe(self.async_notify_all)
self.loop.call_soon_threadsafe(asyncio.ensure_future, self._async_notify_all())

def async_notify_all(self) -> None:
async def _async_notify_all(self) -> None:
"""Notify all async listeners."""
self.event.set()
self.event.clear()
async with self.aiozc.condition:
self.aiozc.condition.notify_all()


class AsyncServiceListener:
Expand Down Expand Up @@ -139,7 +173,7 @@ async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool:
next_ = now + delay
delay *= 2

await aiozc.async_wait((min(next_, last) - now) / 1000)
await aiozc.async_wait((min(next_, last) - now))
now = current_time_millis()
finally:
aiozc.zeroconf.remove_listener(self)
Expand Down Expand Up @@ -168,21 +202,41 @@ def __init__(
super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore
self._browser_task = asyncio.ensure_future(self.async_run())

def cancel(self) -> None:
async def async_cancel(self) -> None:
"""Cancel the browser."""
super().cancel()
self.cancel()
self._browser_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._browser_task

async def _wait_for_next_event(self) -> None:
"""Wait for the next handler or time to send queries."""
# If there are handlers to call
# we want to process them right away
if self._handlers_to_call:
return

# Wait for the type has the smallest next time
next_time = min(self._next_time.values())
now = current_time_millis()

if next_time <= now:
return

timeout = _millis_to_seconds(next_time - now)
async with self.aiozc.condition:
# We must check again while holding the condition
# in case the other thread has added to _handlers_to_call
# between when we checked above when we were not
# holding the condition
if not self._handlers_to_call:
await wait_condition_or_timeout(self.aiozc.condition, timeout)

async def async_run(self) -> None:
"""Run the browser task."""
self.run()
while True:
if not self._handlers_to_call:
# Wait for the type has the smallest next time
next_time = min(self._next_time.values())
now = current_time_millis()
if next_time > now:
await self.aiozc.async_wait(next_time - now)
await self._wait_for_next_event()

out = self.generate_ready_queries()
if out:
Expand Down Expand Up @@ -239,16 +293,17 @@ def __init__(
apple_p2p=apple_p2p,
)
self.loop = asyncio.get_event_loop()
self.async_notify = AsyncNotifyListener()
self.async_notify = AsyncNotifyListener(self)
self.zeroconf.add_notify_listener(self.async_notify)
self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {}
self.sender = _AsyncSender(self.zeroconf)
self.condition = asyncio.Condition()

async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
"""Send a broadcasts to announce a service at intervals."""
for i in range(3):
if i != 0:
await asyncio.sleep(interval / 1000)
await asyncio.sleep(_millis_to_seconds(interval))
self.sender.send(self.zeroconf.generate_service_broadcast(info, ttl))

async def async_register_service(
Expand Down Expand Up @@ -278,7 +333,7 @@ async def async_check_service(self, info: ServiceInfo, cooperating_responders: b
self._raise_on_name_conflict(info)
for i in range(3):
if i != 0:
await asyncio.sleep(_CHECK_TIME / 1000)
await asyncio.sleep(_millis_to_seconds(_CHECK_TIME))
self.sender.send(self.zeroconf.generate_service_query(info))
self._raise_on_name_conflict(info)

Expand Down Expand Up @@ -332,8 +387,8 @@ async def async_get_service_info(

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(self.async_notify.event.wait(), timeout / 1000)
async with self.condition:
await wait_condition_or_timeout(self.condition, _millis_to_seconds(timeout))

async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None:
"""Adds a listener for a particular service type. This object
Expand All @@ -345,7 +400,7 @@ async def async_add_service_listener(self, type_: str, listener: AsyncServiceLis
async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None:
"""Removes a listener from the set that is currently listening."""
if listener in self.async_browsers:
self.async_browsers[listener].cancel()
await self.async_browsers[listener].async_cancel()
del self.async_browsers[listener]

async def async_remove_all_service_listeners(self) -> None:
Expand Down