Skip to content

Commit 66bf28f

Browse files
committed
AsyncServiceBrowser must recheck for handlers to call when holding condition
- There was a short race condition window where the AsyncServiceBrowser could add to _handlers_to_call in the Engine thread, have the condition notify_all called, but since the AsyncServiceBrowser was not yet holding the condition it would not know to stop waiting and process the handlers to call.
1 parent ed53f62 commit 66bf28f

1 file changed

Lines changed: 76 additions & 21 deletions

File tree

zeroconf/asyncio.py

Lines changed: 76 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,47 @@
4747
)
4848

4949

50+
def _millis_to_seconds(ms: float) -> float:
51+
"""Convert miliseconds to seconds."""
52+
return ms / 1000
53+
54+
5055
def _get_best_available_queue() -> queue.Queue:
5156
"""Create the best available queue type."""
5257
if hasattr(queue, "SimpleQueue"):
5358
return queue.SimpleQueue() # type: ignore # pylint: disable=all
5459
return queue.Queue()
5560

5661

62+
# Switch to asyncio.wait_for once https://bugs.python.org/issue39032 is fixed
63+
async def wait_condition_or_timeout(condition: asyncio.Condition, timeout: float) -> None:
64+
"""Wait for a condition or timeout."""
65+
loop = asyncio.get_event_loop()
66+
future = loop.create_future()
67+
68+
def _handle_timeout() -> None:
69+
if not future.done():
70+
future.set_result(None)
71+
72+
timer_handle = loop.call_later(timeout, _handle_timeout)
73+
condition_wait = loop.create_task(condition.wait())
74+
75+
def _handle_wait_complete(_: asyncio.Task) -> None:
76+
if not future.done():
77+
future.set_result(None)
78+
79+
condition_wait.add_done_callback(_handle_wait_complete)
80+
81+
try:
82+
await future
83+
finally:
84+
timer_handle.cancel()
85+
if not condition_wait.done():
86+
condition_wait.cancel()
87+
with contextlib.suppress(asyncio.CancelledError):
88+
await condition_wait
89+
90+
5791
class _AsyncSender(threading.Thread):
5892
"""A thread to handle sending DNSOutgoing for asyncio."""
5993

@@ -86,19 +120,19 @@ def run(self) -> None:
86120
class AsyncNotifyListener(NotifyListener):
87121
"""A NotifyListener that async code can use to wait for events."""
88122

89-
def __init__(self) -> None:
123+
def __init__(self, aiozc: 'AsyncZeroconf') -> None:
90124
"""Create an event for async listeners to wait for."""
91-
self.event = asyncio.Event()
125+
self.aiozc = aiozc
92126
self.loop = asyncio.get_event_loop()
93127

94128
def notify_all(self) -> None:
95129
"""Schedule an async_notify_all."""
96-
self.loop.call_soon_threadsafe(self.async_notify_all)
130+
self.loop.call_soon_threadsafe(asyncio.ensure_future, self._async_notify_all())
97131

98-
def async_notify_all(self) -> None:
132+
async def _async_notify_all(self) -> None:
99133
"""Notify all async listeners."""
100-
self.event.set()
101-
self.event.clear()
134+
async with self.aiozc.condition:
135+
self.aiozc.condition.notify_all()
102136

103137

104138
class AsyncServiceListener:
@@ -139,7 +173,7 @@ async def async_request(self, aiozc: 'AsyncZeroconf', timeout: float) -> bool:
139173
next_ = now + delay
140174
delay *= 2
141175

142-
await aiozc.async_wait((min(next_, last) - now) / 1000)
176+
await aiozc.async_wait((min(next_, last) - now))
143177
now = current_time_millis()
144178
finally:
145179
aiozc.zeroconf.remove_listener(self)
@@ -168,21 +202,41 @@ def __init__(
168202
super().__init__(aiozc.zeroconf, type_, handlers, listener, addr, port, delay) # type: ignore
169203
self._browser_task = asyncio.ensure_future(self.async_run())
170204

171-
def cancel(self) -> None:
205+
async def async_cancel(self) -> None:
172206
"""Cancel the browser."""
173-
super().cancel()
207+
self.cancel()
174208
self._browser_task.cancel()
209+
with contextlib.suppress(asyncio.CancelledError):
210+
await self._browser_task
211+
212+
async def _wait_for_next_event(self) -> None:
213+
"""Wait for the next handler or time to send queries."""
214+
# If there are handlers to call
215+
# we want to process them right away
216+
if self._handlers_to_call:
217+
return
218+
219+
# Wait for the type has the smallest next time
220+
next_time = min(self._next_time.values())
221+
now = current_time_millis()
222+
223+
if next_time <= now:
224+
return
225+
226+
timeout = _millis_to_seconds(next_time - now)
227+
async with self.aiozc.condition:
228+
# We must check again while holding the condition
229+
# in case the other thread has added to _handlers_to_call
230+
# between when we checked above when we were not
231+
# holding the condition
232+
if not self._handlers_to_call:
233+
await wait_condition_or_timeout(self.aiozc.condition, timeout)
175234

176235
async def async_run(self) -> None:
177236
"""Run the browser task."""
178237
self.run()
179238
while True:
180-
if not self._handlers_to_call:
181-
# Wait for the type has the smallest next time
182-
next_time = min(self._next_time.values())
183-
now = current_time_millis()
184-
if next_time > now:
185-
await self.aiozc.async_wait(next_time - now)
239+
await self._wait_for_next_event()
186240

187241
out = self.generate_ready_queries()
188242
if out:
@@ -239,16 +293,17 @@ def __init__(
239293
apple_p2p=apple_p2p,
240294
)
241295
self.loop = asyncio.get_event_loop()
242-
self.async_notify = AsyncNotifyListener()
296+
self.async_notify = AsyncNotifyListener(self)
243297
self.zeroconf.add_notify_listener(self.async_notify)
244298
self.async_browsers: Dict[AsyncServiceListener, AsyncServiceBrowser] = {}
245299
self.sender = _AsyncSender(self.zeroconf)
300+
self.condition = asyncio.Condition()
246301

247302
async def _async_broadcast_service(self, info: ServiceInfo, interval: int, ttl: Optional[int]) -> None:
248303
"""Send a broadcasts to announce a service at intervals."""
249304
for i in range(3):
250305
if i != 0:
251-
await asyncio.sleep(interval / 1000)
306+
await asyncio.sleep(_millis_to_seconds(interval))
252307
self.sender.send(self.zeroconf.generate_service_broadcast(info, ttl))
253308

254309
async def async_register_service(
@@ -278,7 +333,7 @@ async def async_check_service(self, info: ServiceInfo, cooperating_responders: b
278333
self._raise_on_name_conflict(info)
279334
for i in range(3):
280335
if i != 0:
281-
await asyncio.sleep(_CHECK_TIME / 1000)
336+
await asyncio.sleep(_millis_to_seconds(_CHECK_TIME))
282337
self.sender.send(self.zeroconf.generate_service_query(info))
283338
self._raise_on_name_conflict(info)
284339

@@ -332,8 +387,8 @@ async def async_get_service_info(
332387

333388
async def async_wait(self, timeout: float) -> None:
334389
"""Calling task waits for a given number of milliseconds or until notified."""
335-
with contextlib.suppress(asyncio.TimeoutError):
336-
await asyncio.wait_for(self.async_notify.event.wait(), timeout / 1000)
390+
async with self.condition:
391+
await wait_condition_or_timeout(self.condition, _millis_to_seconds(timeout))
337392

338393
async def async_add_service_listener(self, type_: str, listener: AsyncServiceListener) -> None:
339394
"""Adds a listener for a particular service type. This object
@@ -345,7 +400,7 @@ async def async_add_service_listener(self, type_: str, listener: AsyncServiceLis
345400
async def async_remove_service_listener(self, listener: AsyncServiceListener) -> None:
346401
"""Removes a listener from the set that is currently listening."""
347402
if listener in self.async_browsers:
348-
self.async_browsers[listener].cancel()
403+
await self.async_browsers[listener].async_cancel()
349404
del self.async_browsers[listener]
350405

351406
async def async_remove_all_service_listeners(self) -> None:

0 commit comments

Comments
 (0)