Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import socket
import sys
import time
import threading
import unittest
import unittest.mock
from typing import cast
Expand Down Expand Up @@ -715,3 +716,34 @@ def test_guard_against_duplicate_packets():
assert listener.suppress_duplicate_packet(b"other packet", current_time_millis() + 1000) is False
assert listener.suppress_duplicate_packet(b"first packet", current_time_millis()) is False
zc.close()


def test_shutdown_while_register_in_process():
"""Test we can shutdown while registering a service in another thread."""

# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])

# start a browser
type_ = "_homeassistant._tcp.local."
name = "MyTestHome"
info_service = r.ServiceInfo(
type_,
'%s.%s' % (name, type_),
80,
0,
0,
{'path': '/~paulsm/'},
"ash-90.local.",
addresses=[socket.inet_aton("10.0.1.2")],
)

def _background_register():
zc.register_service(info_service)

bgthread = threading.Thread(target=_background_register, daemon=True)
bgthread.start()
time.sleep(0.3)

zc.close()
bgthread.join()
14 changes: 14 additions & 0 deletions tests/utils/test_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def test_shutdown_loop() -> None:
"""Test shutting down an event loop."""
loop = None
loop_thread_ready = threading.Event()
runcoro_thread_ready = threading.Event()

def _run_loop() -> None:
nonlocal loop
Expand All @@ -76,10 +77,23 @@ def _run_loop() -> None:
loop_thread.start()
loop_thread_ready.wait()

async def _still_running():
await asyncio.sleep(5)

def _run_coro() -> None:
runcoro_thread_ready.set()
asyncio.run_coroutine_threadsafe(_still_running(), loop).result(1)

runcoro_thread = threading.Thread(target=_run_coro, daemon=True)
runcoro_thread.start()
runcoro_thread_ready.wait()

time.sleep(0.1)
aioutils.shutdown_loop(loop)
for _ in range(5):
if not loop.is_running():
break
time.sleep(0.05)

assert loop.is_running() is False
runcoro_thread.join()
9 changes: 7 additions & 2 deletions zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

import asyncio
import concurrent.futures
import contextlib
import itertools
import random
Expand Down Expand Up @@ -71,6 +72,7 @@
)

_TC_DELAY_RANDOM_INTERVAL = (400, 500)
_CLOSE_TIMEOUT = 3


class AsyncEngine:
Expand Down Expand Up @@ -170,7 +172,7 @@ def close(self) -> None:
return
if not self.loop.is_running():
return
asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result()
asyncio.run_coroutine_threadsafe(self._async_close(), self.loop).result(_CLOSE_TIMEOUT)


class AsyncListener(asyncio.Protocol, QuietLogger):
Expand Down Expand Up @@ -416,7 +418,10 @@ def listeners(self) -> List[RecordUpdateListener]:
def wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
assert self.loop is not None
asyncio.run_coroutine_threadsafe(self.async_wait(timeout), self.loop).result()
with contextlib.suppress(concurrent.futures.TimeoutError):
asyncio.run_coroutine_threadsafe(self.async_wait(timeout), self.loop).result(
millis_to_seconds(timeout)
)

async def async_wait(self, timeout: float) -> None:
"""Calling task waits for a given number of milliseconds or until notified."""
Expand Down
4 changes: 2 additions & 2 deletions zeroconf/_services/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
_is_v6_address,
)
from .._utils.struct import int2byte
from .._utils.time import current_time_millis
from .._utils.time import current_time_millis, millis_to_seconds
from ..const import (
_CLASS_IN,
_CLASS_UNIQUE,
Expand Down Expand Up @@ -427,7 +427,7 @@ def request(
raise RuntimeError("Use AsyncServiceInfo.async_request from the event loop")
return asyncio.run_coroutine_threadsafe(
self.async_request(zc, timeout, question_type), zc.loop
).result()
).result(millis_to_seconds(timeout) + 1)

async def async_request(
self, zc: 'Zeroconf', timeout: float, question_type: Optional[DNSQuestionType] = None
Expand Down
17 changes: 12 additions & 5 deletions zeroconf/_utils/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
import queue
from typing import Any, List, Optional, Set, cast

_TASK_AWAIT_TIMEOUT = 1
_GET_ALL_TASKS_TIMEOUT = 1
_WAIT_FOR_LOOP_TASKS_TIMEOUT = 2 # Must be larger than _TASK_AWAIT_TIMEOUT


def get_best_available_queue() -> queue.Queue:
"""Create the best available queue type."""
Expand Down Expand Up @@ -73,16 +77,19 @@ async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.

async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None:
"""Wait for the event loop thread we started to shutdown."""
await asyncio.wait(wait_tasks, timeout=1)
await asyncio.wait(wait_tasks, timeout=_TASK_AWAIT_TIMEOUT)


def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
"""Wait for pending tasks and stop an event loop."""
pending_tasks = set(asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result())
done_tasks = set(task for task in pending_tasks if not task.done())
pending_tasks -= done_tasks
pending_tasks = set(
asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result(_GET_ALL_TASKS_TIMEOUT)
)
pending_tasks -= set(task for task in pending_tasks if task.done())
if pending_tasks:
asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result()
asyncio.run_coroutine_threadsafe(_wait_for_loop_tasks(pending_tasks), loop).result(
_WAIT_FOR_LOOP_TASKS_TIMEOUT
)
loop.call_soon_threadsafe(loop.stop)


Expand Down