Skip to content

Commit 4ed9036

Browse files
authored
Fix deadlock when event loop is shutdown during service registration (#869)
1 parent 22ff6b5 commit 4ed9036

5 files changed

Lines changed: 67 additions & 9 deletions

File tree

tests/test_core.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import socket
1313
import sys
1414
import time
15+
import threading
1516
import unittest
1617
import unittest.mock
1718
from typing import cast
@@ -715,3 +716,34 @@ def test_guard_against_duplicate_packets():
715716
assert listener.suppress_duplicate_packet(b"other packet", current_time_millis() + 1000) is False
716717
assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False
717718
zc.close()
719+
720+
721+
def test_shutdown_while_register_in_process():
722+
"""Test we can shutdown while registering a service in another thread."""
723+
724+
# instantiate a zeroconf instance
725+
zc = Zeroconf(interfaces=['127.0.0.1'])
726+
727+
# start a browser
728+
type_ = "_homeassistant._tcp.local."
729+
name = "MyTestHome"
730+
info_service = r.ServiceInfo(
731+
type_,
732+
'%s.%s' % (name, type_),
733+
80,
734+
0,
735+
0,
736+
{'path': '/~paulsm/'},
737+
"ash-90.local.",
738+
addresses=[socket.inet_aton("10.0.1.2")],
739+
)
740+
741+
def _background_register():
742+
zc.register_service(info_service)
743+
744+
bgthread = threading.Thread(target=_background_register, daemon=True)
745+
bgthread.start()
746+
time.sleep(0.3)
747+
748+
zc.close()
749+
bgthread.join()

tests/utils/test_aio.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def test_shutdown_loop() -> None:
6464
"""Test shutting down an event loop."""
6565
loop = None
6666
loop_thread_ready = threading.Event()
67+
runcoro_thread_ready = threading.Event()
6768

6869
def _run_loop() -> None:
6970
nonlocal loop
@@ -76,10 +77,23 @@ def _run_loop() -> None:
7677
loop_thread.start()
7778
loop_thread_ready.wait()
7879

80+
async def _still_running():
81+
await asyncio.sleep(5)
82+
83+
def _run_coro() -> None:
84+
runcoro_thread_ready.set()
85+
asyncio.run_coroutine_threadsafe(_still_running(), loop).result(1)
86+
87+
runcoro_thread = threading.Thread(target=_run_coro, daemon=True)
88+
runcoro_thread.start()
89+
runcoro_thread_ready.wait()
90+
91+
time.sleep(0.1)
7992
aioutils.shutdown_loop(loop)
8093
for _ in range(5):
8194
if not loop.is_running():
8295
break
8396
time.sleep(0.05)
8497

8598
assert loop.is_running() is False
99+
runcoro_thread.join()

zeroconf/_core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
"""
2222

2323
import asyncio
24+
import concurrent.futures
2425
import contextlib
2526
import itertools
2627
import random
@@ -71,6 +72,7 @@
7172
)
7273

7374
_TC_DELAY_RANDOM_INTERVAL = (400, 500)
75+
_CLOSE_TIMEOUT = 3
7476

7577

7678
class AsyncEngine:
@@ -170,7 +172,7 @@ def close(self) -> None:
170172
return
171173
if not self.loop.is_running():
172174
return
173-
asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result()
175+
asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result(_CLOSE_TIMEOUT)
174176

175177

176178
class AsyncListener(asyncio.Protocol, QuietLogger):
@@ -416,7 +418,10 @@ def listeners(self) -> List[RecordUpdateListener]:
416418
def wait(self, timeout: float) -> None:
417419
"""Calling task waits for a given number of milliseconds or until notified."""
418420
assert self.loop is not None
419-
asyncio.run_coroutine_threadsafe(self.async_wait(timeout), self.loop).result()
421+
with contextlib.suppress(concurrent.futures.TimeoutError):
422+
asyncio.run_coroutine_threadsafe(self.async_wait(timeout), self.loop).result(
423+
millis_to_seconds(timeout)
424+
)
420425

421426
async def async_wait(self, timeout: float) -> None:
422427
"""Calling task waits for a given number of milliseconds or until notified."""

zeroconf/_services/info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
_is_v6_address,
3838
)
3939
from .._utils.struct import int2byte
40-
from .._utils.time import current_time_millis
40+
from .._utils.time import current_time_millis, millis_to_seconds
4141
from ..const import (
4242
_CLASS_IN,
4343
_CLASS_UNIQUE,
@@ -427,7 +427,7 @@ def request(
427427
raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop")
428428
return asyncio.run_coroutine_threadsafe(
429429
self.async_request(zc, timeout, question_type), zc.loop
430-
).result()
430+
).result(millis_to_seconds(timeout) + 1)
431431

432432
async def async_request(
433433
self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None

zeroconf/_utils/aio.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
import queue
2626
from typing import Any, List, Optional, Set, cast
2727

28+
_TASK_AWAIT_TIMEOUT = 1
29+
_GET_ALL_TASKS_TIMEOUT = 1
30+
_WAIT_FOR_LOOP_TASKS_TIMEOUT = 2 # Must be larger than _TASK_AWAIT_TIMEOUT
31+
2832

2933
def get_best_available_queue() -> queue.Queue:
3034
"""Create the best available queue type."""
@@ -73,16 +77,19 @@ async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.
7377

7478
async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None:
7579
"""Wait for the event loop thread we started to shutdown."""
76-
await asyncio.wait(wait_tasks, timeout=1)
80+
await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT)
7781

7882

7983
def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
8084
"""Wait for pending tasks and stop an event loop."""
81-
pending_tasks = set(asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result())
82-
done_tasks = set(task for task in pending_tasks if not task.done())
83-
pending_tasks -= done_tasks
85+
pending_tasks = set(
86+
asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result(_GET_ALL_TASKS_TIMEOUT)
87+
)
88+
pending_tasks -= set(task for task in pending_tasks if task.done())
8489
if pending_tasks:
85-
asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result()
90+
asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result(
91+
_WAIT_FOR_LOOP_TASKS_TIMEOUT
92+
)
8693
loop.call_soon_threadsafe(loop.stop)
8794

8895

0 commit comments

Comments
 (0)