Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
import platform
import socket
import time
from collections.abc import Iterable
from functools import cache
from unittest import mock

import ifaddr

from zeroconf import DNSIncoming, DNSQuestion, DNSRecord, Zeroconf
from zeroconf import DNSIncoming, DNSOutgoing, DNSQuestion, DNSRecord, Zeroconf, const
from zeroconf._history import QuestionHistory

_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
Expand Down Expand Up @@ -70,6 +71,14 @@ def suppresses(self, question: DNSQuestion, now: float, known_answers: set[DNSRe
return False


def mock_incoming_msg(records: Iterable[DNSRecord]) -> DNSIncoming:
"""Build a `DNSIncoming` response message from a list of `DNSRecord`s."""
generated = DNSOutgoing(const._FLAGS_QR_RESPONSE)
for record in records:
generated.add_answer_at_time(record, 0)
return DNSIncoming(generated.packets()[0])


def _inject_responses(zc: Zeroconf, msgs: list[DNSIncoming]) -> None:
"""Inject a DNSIncoming response."""
assert zc.loop is not None
Expand Down
36 changes: 34 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
from __future__ import annotations

import threading
from collections.abc import Generator
from collections.abc import AsyncGenerator, Generator
from unittest.mock import patch

import pytest
import pytest_asyncio

from zeroconf import _core, const
from zeroconf import Zeroconf, _core, const
from zeroconf._handlers import query_handler
from zeroconf._services import browser as service_browser
from zeroconf._services import info as service_info
from zeroconf.asyncio import AsyncZeroconf


@pytest.fixture(autouse=True)
Expand All @@ -23,6 +25,36 @@ def verify_threads_ended():
assert not threads


@pytest.fixture
def zc_loopback() -> Generator[Zeroconf]:
"""Yield a loopback `Zeroconf` and close it on teardown.

Replaces the inline `zc = Zeroconf(interfaces=["127.0.0.1"])` +
explicit `zc.close()` pattern duplicated across the suite. Calling
`zc.close()` inside a test is still safe — `close()` is idempotent.
"""
zc = Zeroconf(interfaces=["127.0.0.1"])
try:
yield zc
finally:
zc.close()


@pytest_asyncio.fixture
async def aiozc_loopback() -> AsyncGenerator[AsyncZeroconf]:
"""Yield a loopback `AsyncZeroconf` and close it on teardown.

Replaces the inline `aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])`
+ explicit `await aiozc.async_close()` pattern duplicated across the
suite. Calling `async_close()` inside a test is still safe.
"""
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
try:
yield aiozc
finally:
await aiozc.async_close()


@pytest.fixture
def run_isolated():
"""Change the mDNS port to run the test in isolation."""
Expand Down
9 changes: 1 addition & 8 deletions tests/services/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import socket
import time
import unittest
from collections.abc import Iterable
from threading import Event
from typing import cast
from unittest.mock import patch
Expand Down Expand Up @@ -36,6 +35,7 @@
_inject_response,
_wait_for_start,
has_working_ipv6,
mock_incoming_msg,
time_changed_millis,
)

Expand All @@ -54,13 +54,6 @@ def teardown_module():
log.setLevel(original_logging_level)


def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
for record in records:
generated.add_answer_at_time(record, 0)
return r.DNSIncoming(generated.packets()[0])


def test_service_browser_cancel_multiple_times():
"""Test we can cancel a ServiceBrowser multiple times before close."""

Expand Down
25 changes: 1 addition & 24 deletions tests/services/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import socket
import threading
import unittest
from collections.abc import Iterable
from ipaddress import ip_address
from threading import Event
from unittest.mock import patch
Expand All @@ -23,7 +22,7 @@
from zeroconf._utils.net import IPVersion
from zeroconf.asyncio import AsyncZeroconf

from .. import QUICK_REQUEST_TIMEOUT_MS, _inject_response, has_working_ipv6
from .. import QUICK_REQUEST_TIMEOUT_MS, _inject_response, has_working_ipv6, mock_incoming_msg

log = logging.getLogger("zeroconf")
original_logging_level = logging.NOTSET
Expand Down Expand Up @@ -279,14 +278,6 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
# patch the zeroconf send
with patch.object(zc, "async_send", send):

def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)

for record in records:
generated.add_answer_at_time(record, 0)

return r.DNSIncoming(generated.packets()[0])

def get_service_info_helper(zc, type, name):
nonlocal service_info
service_info = zc.get_service_info(type, name)
Expand Down Expand Up @@ -422,14 +413,6 @@ def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
# patch the zeroconf send
with patch.object(zc, "async_send", send):

def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)

for record in records:
generated.add_answer_at_time(record, 0)

return r.DNSIncoming(generated.packets()[0])

def get_service_info_helper(zc, type, name, timeout):
nonlocal service_info
service_info = zc.get_service_info(type, name, timeout)
Expand Down Expand Up @@ -552,12 +535,6 @@ def test_get_info_single(self):
),
]

def mock_incoming_msg(records: Iterable[r.DNSRecord]) -> r.DNSIncoming:
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
for record in records:
generated.add_answer_at_time(record, 0)
return r.DNSIncoming(generated.packets()[0])

sent_queries: list[r.DNSOutgoing] = []

def send(out, addr=const._MDNS_ADDR, port=const._MDNS_PORT, v6_flow_scope=()):
Expand Down
44 changes: 18 additions & 26 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,39 +74,31 @@ async def test_reaper():


@pytest.mark.asyncio
async def test_setup_releases_socket_ownership() -> None:
async def test_setup_releases_socket_ownership(aiozc_loopback: AsyncZeroconf) -> None:
"""Engine releases its pending-socket refs once each socket has a transport."""
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
try:
await aiozc.zeroconf.async_wait_for_start()
engine = aiozc.zeroconf.engine
assert engine._listen_socket is None
assert engine._respond_sockets == []
assert engine.readers
assert engine.senders
finally:
await aiozc.async_close()
await aiozc_loopback.zeroconf.async_wait_for_start()
engine = aiozc_loopback.zeroconf.engine
assert engine._listen_socket is None
assert engine._respond_sockets == []
assert engine.readers
assert engine.senders


@pytest.mark.asyncio
async def test_async_close_propagates_outer_cancellation() -> None:
async def test_async_close_propagates_outer_cancellation(aiozc_loopback: AsyncZeroconf) -> None:
"""Outer-task cancellation while awaiting setup propagates to the caller."""
aiozc = AsyncZeroconf(interfaces=["127.0.0.1"])
await aiozc_loopback.zeroconf.async_wait_for_start()
engine = aiozc_loopback.zeroconf.engine
loop = asyncio.get_running_loop()
original_task = engine._setup_task
fake_task = loop.create_future()
fake_task.set_exception(asyncio.CancelledError())
engine._setup_task = fake_task # type: ignore[assignment]
try:
await aiozc.zeroconf.async_wait_for_start()
engine = aiozc.zeroconf.engine
loop = asyncio.get_running_loop()
original_task = engine._setup_task
fake_task = loop.create_future()
fake_task.set_exception(asyncio.CancelledError())
engine._setup_task = fake_task # type: ignore[assignment]
try:
with pytest.raises(asyncio.CancelledError):
await engine._async_close()
finally:
engine._setup_task = original_task
with pytest.raises(asyncio.CancelledError):
await engine._async_close()
finally:
await aiozc.async_close()
engine._setup_task = original_task


@pytest.mark.asyncio
Expand Down
Loading