Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions zeroconf/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,35 @@ def _get_best_available_queue() -> queue.Queue:
return queue.Queue()


# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed
async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None:
Comment thread
bdraco marked this conversation as resolved.
"""Wait for a condition or timeout."""
loop = asyncio.get_event_loop()
future = loop.create_future()

def _handle_timeout() -> None:
if not future.done():
future.set_result(None)

timer_handle = loop.call_later(timeout, _handle_timeout)
condition_wait = loop.create_task(condition.wait())

def _handle_wait_complete(_: asyncio.Task) -> None:
if not future.done():
future.set_result(None)

condition_wait.add_done_callback(_handle_wait_complete)

try:
await future
finally:
timer_handle.cancel()
if not condition_wait.done():
condition_wait.cancel()
with contextlib.suppress(asyncio.CancelledError):
await condition_wait


class _AsyncSender(threading.Thread):
"""A thread to handle sending DNSOutgoing for asyncio."""

Expand Down Expand Up @@ -87,19 +116,19 @@ def run(self) -> None:
class AsyncNotifyListener(NotifyListener):
"""A NotifyListener that async code can use to wait for events."""

def __init__(self) -> None:
def __init__(self, aiozc: 'AsyncZeroconf') -> None:
"""Create an event for async listeners to wait for."""
self.event = asyncio.Event()
self.aiozc = aiozc
self.loop = asyncio.get_event_loop()

def notify_all(self) -> None:
"""Schedule an async_notify_all."""
self.loop.call_soon_threadsafe(self.async_notify_all)
self.loop.call_soon_threadsafe(asyncio.ensure_future, self._async_notify_all())

def async_notify_all(self) -> None:
async def _async_notify_all(self) -> None:
"""Notify all async listeners."""
self.event.set()
self.event.clear()
async with self.aiozc.condition:
self.aiozc.condition.notify_all()


class AsyncServiceListener:
Expand Down Expand Up @@ -169,10 +198,12 @@ def __init__(
super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore
self._browser_task = asyncio.ensure_future(self.async_run())

def cancel(self) -> None:
async def async_cancel(self) -> None:
"""Cancel the browser."""
super().cancel()
self.cancel()
self._browser_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._browser_task

async def async_run(self) -> None:
"""Run the browser task."""
Expand Down Expand Up @@ -240,10 +271,11 @@ def __init__(
apple_p2p=apple_p2p,
)
self.loop = asyncio.get_event_loop()
self.async_notify = AsyncNotifyListener()
self.async_notify = AsyncNotifyListener(self)
self.zeroconf.add_notify_listener(self.async_notify)
self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {}
self.sender = _AsyncSender(self.zeroconf)
self.condition = asyncio.Condition()

async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
"""Send a broadcasts to announce a service at intervals."""
Expand Down Expand Up @@ -333,8 +365,8 @@ async def async_get_service_info(

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(self.async_notify.event.wait(), millis_to_seconds(timeout))
async with self.condition:
await wait_condition_or_timeout(self.condition, millis_to_seconds(timeout))

async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None:
"""Adds a listener for a particular service type. This object
Expand All @@ -346,7 +378,7 @@ async def async_add_service_listener(self, type_: str, listener: AsyncServiceLis
async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None:
"""Removes a listener from the set that is currently listening."""
if listener in self.async_browsers:
self.async_browsers[listener].cancel()
await self.async_browsers[listener].async_cancel()
del self.async_browsers[listener]

async def async_remove_all_service_listeners(self) -> None:
Expand Down