Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,28 @@ async def test_open_close_twice_from_async() -> None:
zc = Zeroconf(interfaces=['127.0.0.1'])
zc.close()
zc.close()
await asyncio.sleep(0)


@pytest.mark.asyncio
async def test_multiple_sync_instances_stared_from_async_close():
"""Test we can shutdown multiple sync instances from async."""

# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])
zc2 = Zeroconf(interfaces=['127.0.0.1'])

assert zc.loop == zc2.loop

zc.close()
assert zc.loop.is_running()
zc2.close()
assert zc2.loop.is_running()

zc3 = Zeroconf(interfaces=['127.0.0.1'])
assert zc3.loop == zc2.loop

zc3.close()
assert zc3.loop.is_running()

await asyncio.sleep(0)
15 changes: 15 additions & 0 deletions tests/utils/test_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,27 @@

import asyncio
import contextlib
import unittest.mock

import pytest

from zeroconf._utils import aio as aioutils


@pytest.mark.asyncio
async def test_async_get_all_tasks() -> None:
"""Test we can get all tasks in the event loop.

We make sure we handle RuntimeError here as
this is not thread safe under PyPy
"""
await aioutils._async_get_all_tasks(aioutils.get_running_loop())
if not hasattr(asyncio, 'all_tasks'):
return
with unittest.mock.patch("zeroconf._utils.aio.asyncio.all_tasks", side_effect=RuntimeError):
await aioutils._async_get_all_tasks(aioutils.get_running_loop())


@pytest.mark.asyncio
async def test_get_running_loop_from_async() -> None:
"""Test we can get the event loop."""
Expand Down
17 changes: 11 additions & 6 deletions zeroconf/_utils/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,18 @@ def _handle_wait_complete(_: asyncio.Task) -> None:
await event_wait


async def _get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.Task]:
async def _async_get_all_tasks(loop: asyncio.AbstractEventLoop) -> List[asyncio.Task]:
"""Return all tasks running."""
await asyncio.sleep(0) # flush out any call_soon_threadsafe
# Make a copy of the tasks in case they change during iteration
if hasattr(asyncio, 'all_tasks'):
return list(asyncio.all_tasks(loop)) # type: ignore # pylint: disable=no-member
return list(asyncio.Task.all_tasks(loop)) # type: ignore # pylint: disable=no-member
# If there are multiple event loops running, all_tasks is not
# safe EVEN WHEN CALLED FROM THE EVENTLOOP
# under PyPy so we have to try a few times.
for _ in range(3):
with contextlib.suppress(RuntimeError):
if hasattr(asyncio, 'all_tasks'):
return asyncio.all_tasks(loop) # type: ignore # pylint: disable=no-member
return asyncio.Task.all_tasks(loop) # type: ignore # pylint: disable=no-member
return []


async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None:
Expand All @@ -78,7 +83,7 @@ async def _wait_for_loop_tasks(wait_tasks: Set[asyncio.Task]) -> None:

def shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
"""Wait for pending tasks and stop an event loop."""
pending_tasks = set(asyncio.run_coroutine_threadsafe(_get_all_tasks(loop), loop).result())
pending_tasks = set(asyncio.run_coroutine_threadsafe(_async_get_all_tasks(loop), loop).result())
done_tasks = set(task for task in pending_tasks if not task.done())
pending_tasks -= done_tasks
if pending_tasks:
Expand Down