Skip to content

Commit aeeb8ab

Browse files
committed
Switch from using an asyncio.Event to asyncio.Condition for waiting
- Using a asyncio.Condition permits fixing a race condition in processing handlers similar to the sync version in #477
1 parent 849e9bc commit aeeb8ab

1 file changed

Lines changed: 44 additions & 12 deletions

File tree

zeroconf/asyncio.py

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,35 @@ def _get_best_available_queue() -> queue.Queue:
5555
return queue.Queue()
5656

5757

58+
# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed
59+
async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None:
60+
"""Wait for a condition or timeout."""
61+
loop = asyncio.get_event_loop()
62+
future = loop.create_future()
63+
64+
def _handle_timeout() -> None:
65+
if not future.done():
66+
future.set_result(None)
67+
68+
timer_handle = loop.call_later(timeout, _handle_timeout)
69+
condition_wait = loop.create_task(condition.wait())
70+
71+
def _handle_wait_complete(_: asyncio.Task) -> None:
72+
if not future.done():
73+
future.set_result(None)
74+
75+
condition_wait.add_done_callback(_handle_wait_complete)
76+
77+
try:
78+
await future
79+
finally:
80+
timer_handle.cancel()
81+
if not condition_wait.done():
82+
condition_wait.cancel()
83+
with contextlib.suppress(asyncio.CancelledError):
84+
await condition_wait
85+
86+
5887
class _AsyncSender(threading.Thread):
5988
"""A thread to handle sending DNSOutgoing for asyncio."""
6089

@@ -87,19 +116,19 @@ def run(self) -> None:
87116
class AsyncNotifyListener(NotifyListener):
88117
"""A NotifyListener that async code can use to wait for events."""
89118

90-
def __init__(self) -> None:
119+
def __init__(self, aiozc: 'AsyncZeroconf') -> None:
91120
"""Create an event for async listeners to wait for."""
92-
self.event = asyncio.Event()
121+
self.aiozc = aiozc
93122
self.loop = asyncio.get_event_loop()
94123

95124
def notify_all(self) -> None:
96125
"""Schedule an async_notify_all."""
97-
self.loop.call_soon_threadsafe(self.async_notify_all)
126+
self.loop.call_soon_threadsafe(asyncio.ensure_future, self._async_notify_all())
98127

99-
def async_notify_all(self) -> None:
128+
async def _async_notify_all(self) -> None:
100129
"""Notify all async listeners."""
101-
self.event.set()
102-
self.event.clear()
130+
async with self.aiozc.condition:
131+
self.aiozc.condition.notify_all()
103132

104133

105134
class AsyncServiceListener:
@@ -169,10 +198,12 @@ def __init__(
169198
super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore
170199
self._browser_task = asyncio.ensure_future(self.async_run())
171200

172-
def cancel(self) -> None:
201+
async def async_cancel(self) -> None:
173202
"""Cancel the browser."""
174-
super().cancel()
203+
self.cancel()
175204
self._browser_task.cancel()
205+
with contextlib.suppress(asyncio.CancelledError):
206+
await self._browser_task
176207

177208
async def async_run(self) -> None:
178209
"""Run the browser task."""
@@ -240,10 +271,11 @@ def __init__(
240271
apple_p2p=apple_p2p,
241272
)
242273
self.loop = asyncio.get_event_loop()
243-
self.async_notify = AsyncNotifyListener()
274+
self.async_notify = AsyncNotifyListener(self)
244275
self.zeroconf.add_notify_listener(self.async_notify)
245276
self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {}
246277
self.sender = _AsyncSender(self.zeroconf)
278+
self.condition = asyncio.Condition()
247279

248280
async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
249281
"""Send a broadcasts to announce a service at intervals."""
@@ -333,8 +365,8 @@ async def async_get_service_info(
333365

334366
async def async_wait(self, timeout: float) -> None:
335367
"""Calling task waits for a given number of milliseconds or until notified."""
336-
with contextlib.suppress(asyncio.TimeoutError):
337-
await asyncio.wait_for(self.async_notify.event.wait(), millis_to_seconds(timeout))
368+
async with self.condition:
369+
await wait_condition_or_timeout(self.condition, millis_to_seconds(timeout))
338370

339371
async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None:
340372
"""Adds a listener for a particular service type. This object
@@ -346,7 +378,7 @@ async def async_add_service_listener(self, type_: str, listener: AsyncServiceLis
346378
async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None:
347379
"""Removes a listener from the set that is currently listening."""
348380
if listener in self.async_browsers:
349-
self.async_browsers[listener].cancel()
381+
await self.async_browsers[listener].async_cancel()
350382
del self.async_browsers[listener]
351383

352384
async def async_remove_all_service_listeners(self) -> None:

0 commit comments

Comments
 (0)