Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/zeroconf/_listener.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cdef class AsyncListener:
cdef public object sock_description
cdef public cython.dict _deferred
cdef public cython.dict _timers
cdef public cython.dict _deferred_deadlines

@cython.locals(now=double, debug=cython.bint)
cpdef datagram_received(self, cython.bytes bytes, cython.tuple addrs)
Expand All @@ -38,7 +39,10 @@ cdef class AsyncListener:

cdef _cancel_any_timers_for_addr(self, object addr)

@cython.locals(incoming=DNSIncoming, deferred=list)
@cython.locals(deadline=object, fire_at=double)
cdef double _compute_deferred_fire_at(self, object addr, double now, double delay)

@cython.locals(incoming=DNSIncoming, deferred=list, now=double, delay=double, fire_at=double)
cpdef handle_query_or_defer(
self,
DNSIncoming msg,
Expand Down
35 changes: 33 additions & 2 deletions src/zeroconf/_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class AsyncListener:

__slots__ = (
"_deferred",
"_deferred_deadlines",
"_query_handler",
"_record_manager",
"_registry",
Expand All @@ -82,6 +83,7 @@ def __init__(self, zc: Zeroconf) -> None:
self.sock_description: str | None = None
self._deferred: dict[str, list[DNSIncoming]] = {}
self._timers: dict[str, asyncio.TimerHandle] = {}
self._deferred_deadlines: dict[str, float] = {}
super().__init__()

def datagram_received(self, data: _bytes, addrs: tuple[str, int] | tuple[str, int, int, int]) -> None:
Expand Down Expand Up @@ -203,12 +205,19 @@ def handle_query_or_defer(
if incoming.data == msg.data:
return
deferred.append(msg)
delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) # noqa: S311
loop = self.zc.loop
assert loop is not None
now = loop.time()
delay = millis_to_seconds(random.randint(*_TC_DELAY_RANDOM_INTERVAL)) # noqa: S311
fire_at = self._compute_deferred_fire_at(addr, now, delay)
if fire_at < 0.0:
# Sentinel: a new reset would push the flush past the
# per-addr reassembly deadline, so leave the existing
# TimerHandle in place rather than re-arming it.
return
Comment thread
bdraco marked this conversation as resolved.
self._cancel_any_timers_for_addr(addr)
self._timers[addr] = loop.call_at(
loop.time() + delay,
fire_at,
self._respond_query,
None,
addr,
Expand All @@ -217,6 +226,27 @@ def handle_query_or_defer(
v6_flow_scope,
)

def _compute_deferred_fire_at(self, addr: _str, now: _float, delay: _float) -> _float:
"""Return the bounded call_at time for a TC-deferred flush, or -1.0 to keep the existing timer."""
# RFC 6762 §18.5 frames the random delay as a fixed reassembly budget
# starting at first arrival, not a sliding heartbeat.
deadline = self._deferred_deadlines.get(addr)
if deadline is None:
deadline = now + millis_to_seconds(_TC_DELAY_RANDOM_INTERVAL[1])
self._deferred_deadlines[addr] = deadline
fire_at = now + delay
if fire_at >= deadline:
if addr in self._timers:
# Existing timer already fires at or before the deadline;
# signal the caller to leave it alone rather than reset it.
return -1.0
# First packet for this addr already proposes a fire-time at
# or past the deadline — clamp to the deadline so the flush
# still happens within the reassembly budget.
return deadline
# Within budget: schedule at the proposed fire-time.
return fire_at

def _cancel_any_timers_for_addr(self, addr: _str) -> None:
"""Cancel any future truncated packet timers for the address."""
if addr in self._timers:
Expand All @@ -232,6 +262,7 @@ def _respond_query(
) -> None:
"""Respond to a query and reassemble any truncated deferred packets."""
self._cancel_any_timers_for_addr(addr)
self._deferred_deadlines.pop(addr, None)
packets = self._deferred.pop(addr, [])
if msg:
packets.append(msg)
Expand Down
113 changes: 82 additions & 31 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import zeroconf as r
from zeroconf import NotRunningException, Zeroconf, const, current_time_millis
from zeroconf._listener import AsyncListener, _WrappedTransport
from zeroconf._listener import _TC_DELAY_RANDOM_INTERVAL, AsyncListener, _WrappedTransport
from zeroconf._protocol.incoming import DNSIncoming
from zeroconf.asyncio import AsyncZeroconf

Expand Down Expand Up @@ -699,36 +699,41 @@ def test_tc_bit_defers_last_response_missing():
assert len(packets) == 4
expected_deferred = []

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
timer1 = protocol._timers[source_ip]

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
timer2 = protocol._timers[source_ip]
assert timer1.cancelled()
assert timer2 != timer1

# Send the same packet again to similar multi interfaces
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer3 = protocol._timers[source_ip]
assert not timer3.cancelled()
assert timer3 == timer2

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer4 = protocol._timers[source_ip]
assert timer3.cancelled()
assert timer4 != timer3
# Pin per-packet delay to the minimum so each successive fire_at lands
# before the deadline established by the first packet — keeps the
# timer-replacement assertions deterministic under bounded TC-deferral.
min_delay_ms = _TC_DELAY_RANDOM_INTERVAL[0]
with patch("zeroconf._listener.random.randint", return_value=min_delay_ms):
next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
timer1 = protocol._timers[source_ip]

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
timer2 = protocol._timers[source_ip]
assert timer1.cancelled()
assert timer2 != timer1

# Send the same packet again to similar multi interfaces
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer3 = protocol._timers[source_ip]
assert not timer3.cancelled()
assert timer3 == timer2

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(zc, protocol, next_packet, source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._deferred[source_ip] == expected_deferred
assert source_ip in protocol._timers
timer4 = protocol._timers[source_ip]
assert timer3.cancelled()
assert timer4 != timer3

for _ in range(8):
time.sleep(0.1)
Expand All @@ -743,6 +748,52 @@ def test_tc_bit_defers_last_response_missing():
zc.close()


def test_tc_bit_defer_window_is_bounded():
"""TC-deferral assembly window must not slide past first_arrival + max delay."""
zc = Zeroconf(interfaces=["127.0.0.1"])
_wait_for_start(zc)
type_ = "_boundeddefer._tcp.local."
registration_name = f"knownname.{type_}"

info = r.ServiceInfo(
type_,
registration_name,
80,
0,
0,
{"path": "/~paulsm/"},
"ash-2.local.",
addresses=[socket.inet_aton("10.0.1.2")],
)
zc.registry.async_add(info)

protocol = zc.engine.protocols[0]
now_ms = r.current_time_millis()
_clear_cache(zc)
source_ip = "203.0.113.99"

generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
generated.add_question(r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN))
for _ in range(300):
generated.add_answer_at_time(info.dns_pointer(), now_ms)
packets = generated.packets()
assert len(packets) >= 3

# Pin the per-packet delay at its maximum so any subsequent reset would
# land past the deadline established by the first packet.
max_delay_ms = _TC_DELAY_RANDOM_INTERVAL[1]
with patch("zeroconf._listener.random.randint", return_value=max_delay_ms):
threadsafe_query(zc, protocol, r.DNSIncoming(packets[0]), source_ip, const._MDNS_PORT, Mock(), ())
first_when = protocol._timers[source_ip].when()
Comment thread
bdraco marked this conversation as resolved.

for raw in packets[1:-1]:
threadsafe_query(zc, protocol, r.DNSIncoming(raw), source_ip, const._MDNS_PORT, Mock(), ())
assert protocol._timers[source_ip].when() <= first_when

zc.registry.async_remove(info)
zc.close()


@pytest.mark.asyncio
async def test_open_close_twice_from_async() -> None:
"""Test we can close twice from a coroutine when using Zeroconf.
Expand Down
Loading