Skip to content

Commit 09db184

Browse files
authored
feat: implement heapq for tracking cache expire times (#1465)
1 parent 6de7bb6 commit 09db184

10 files changed

Lines changed: 263 additions & 61 deletions

File tree

src/zeroconf/_cache.pxd

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@ from ._dns cimport (
1111
DNSText,
1212
)
1313

14+
cdef object heappop
15+
cdef object heappush
16+
cdef object heapify
1417

1518
cdef object _UNIQUE_RECORD_TYPES
1619
cdef unsigned int _TYPE_PTR
1720
cdef cython.uint _ONE_SECOND
21+
cdef unsigned int _MIN_SCHEDULED_RECORD_EXPIRATION
1822

1923
@cython.locals(
2024
record_cache=dict,
@@ -26,6 +30,8 @@ cdef class DNSCache:
2630

2731
cdef public cython.dict cache
2832
cdef public cython.dict service_cache
33+
cdef public list _expire_heap
34+
cdef public dict _expirations
2935

3036
cpdef bint async_add_records(self, object entries)
3137

@@ -65,7 +71,8 @@ cdef class DNSCache:
6571

6672
@cython.locals(
6773
store=cython.dict,
68-
service_record=DNSService
74+
service_record=DNSService,
75+
when=object
6976
)
7077
cdef bint _async_add(self, DNSRecord record)
7178

@@ -95,3 +102,10 @@ cdef class DNSCache:
95102
now=double
96103
)
97104
cpdef current_entry_with_name_and_alias(self, str name, str alias)
105+
106+
cpdef void _async_set_created_ttl(
107+
self,
108+
DNSRecord record,
109+
double now,
110+
cython.float ttl
111+
)

src/zeroconf/_cache.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
USA
2121
"""
2222

23+
from heapq import heapify, heappop, heappush
2324
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union, cast
2425

2526
from ._dns import (
@@ -43,6 +44,11 @@
4344
_float = float
4445
_int = int
4546

47+
# The minimum number of scheduled record expirations before we start cleaning up
48+
# the expiration heap. This is a performance optimization to avoid cleaning up the
49+
# heap too often when there are only a few scheduled expirations.
50+
_MIN_SCHEDULED_RECORD_EXPIRATION = 100
51+
4652

4753
def _remove_key(cache: _DNSRecordCacheType, key: _str, record: _DNSRecord) -> None:
4854
"""Remove a key from a DNSRecord cache
@@ -60,6 +66,8 @@ class DNSCache:
6066

6167
def __init__(self) -> None:
6268
self.cache: _DNSRecordCacheType = {}
69+
self._expire_heap: List[Tuple[float, DNSRecord]] = []
70+
self._expirations: Dict[DNSRecord, float] = {}
6371
self.service_cache: _DNSRecordCacheType = {}
6472

6573
# Functions prefixed with async_ are NOT threadsafe and must
@@ -81,6 +89,12 @@ def _async_add(self, record: _DNSRecord) -> bool:
8189
store = self.cache.setdefault(record.key, {})
8290
new = record not in store and not isinstance(record, DNSNsec)
8391
store[record] = record
92+
when = record.created + (record.ttl * 1000)
93+
if self._expirations.get(record) != when:
94+
# Avoid adding duplicates to the heap
95+
heappush(self._expire_heap, (when, record))
96+
self._expirations[record] = when
97+
8498
if isinstance(record, DNSService):
8599
service_record = record
86100
self.service_cache.setdefault(record.server_key, {})[service_record] = service_record
@@ -108,6 +122,7 @@ def _async_remove(self, record: _DNSRecord) -> None:
108122
service_record = record
109123
_remove_key(self.service_cache, service_record.server_key, service_record)
110124
_remove_key(self.cache, record.key, record)
125+
self._expirations.pop(record, None)
111126

112127
def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
113128
"""Remove multiple records.
@@ -121,8 +136,44 @@ def async_expire(self, now: _float) -> List[DNSRecord]:
121136
"""Purge expired entries from the cache.
122137
123138
This function must be run in from event loop.
139+
140+
:param now: The current time in milliseconds.
124141
"""
125-
expired = [record for records in self.cache.values() for record in records if record.is_expired(now)]
142+
if not (expire_heap_len := len(self._expire_heap)):
143+
return []
144+
145+
expired: List[DNSRecord] = []
146+
# Find any expired records and add them to the to-delete list
147+
while self._expire_heap:
148+
when, record = self._expire_heap[0]
149+
if when > now:
150+
break
151+
heappop(self._expire_heap)
152+
# Check if the record hasn't been re-added to the heap
153+
# with a different expiration time as it will be removed
154+
# later when it reaches the top of the heap and its
155+
# expiration time is met.
156+
if self._expirations.get(record) == when:
157+
expired.append(record)
158+
159+
# If the expiration heap grows larger than the number expirations
160+
# times two, we clean it up to avoid keeping expired entries in
161+
# the heap and consuming memory. We guard this with a minimum
162+
# threshold to avoid cleaning up the heap too often when there are
163+
# only a few scheduled expirations.
164+
if (
165+
expire_heap_len > _MIN_SCHEDULED_RECORD_EXPIRATION
166+
and expire_heap_len > len(self._expirations) * 2
167+
):
168+
# Remove any expired entries from the expiration heap
169+
# that do not match the expiration time in the expirations
170+
# as it means the record has been re-added to the heap
171+
# with a different expiration time.
172+
self._expire_heap = [
173+
entry for entry in self._expire_heap if self._expirations.get(entry[1]) == entry[0]
174+
]
175+
heapify(self._expire_heap)
176+
126177
self.async_remove_records(expired)
127178
return expired
128179

@@ -256,4 +307,11 @@ def async_mark_unique_records_older_than_1s_to_expire(
256307
created_double = record.created
257308
if (now - created_double > _ONE_SECOND) and record not in answers_rrset:
258309
# Expire in 1s
259-
record.set_created_ttl(now, 1)
310+
self._async_set_created_ttl(record, now, 1)
311+
312+
def _async_set_created_ttl(self, record: DNSRecord, now: _float, ttl: _float) -> None:
313+
"""Set the created time and ttl of a record."""
314+
# It would be better if we made a copy instead of mutating the record
315+
# in place, but records currently don't have a copy method.
316+
record._set_created_ttl(now, ttl)
317+
self._async_add(record)

src/zeroconf/_dns.pxd

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ cdef class DNSRecord(DNSEntry):
6666

6767
cpdef bint is_recent(self, double now)
6868

69-
cpdef reset_ttl(self, DNSRecord other)
70-
71-
cpdef set_created_ttl(self, double now, cython.float ttl)
69+
cdef _set_created_ttl(self, double now, cython.float ttl)
7270

7371
cdef class DNSAddress(DNSRecord):
7472

src/zeroconf/_dns.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ def __eq__(self, other: Any) -> bool: # pylint: disable=no-self-use
185185
"""Abstract method"""
186186
raise AbstractMethodException
187187

188+
def __lt__(self, other: "DNSRecord") -> bool:
189+
return self.ttl < other.ttl
190+
188191
def suppressed_by(self, msg: "DNSIncoming") -> bool:
189192
"""Returns true if any answer in a message can suffice for the
190193
information held in this record."""
@@ -222,13 +225,10 @@ def is_recent(self, now: _float) -> bool:
222225
"""Returns true if the record more than one quarter of its TTL remaining."""
223226
return self.created + (_RECENT_TIME_MS * self.ttl) > now
224227

225-
def reset_ttl(self, other) -> None: # type: ignore[no-untyped-def]
226-
"""Sets this record's TTL and created time to that of
227-
another record."""
228-
self.set_created_ttl(other.created, other.ttl)
229-
230-
def set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None:
228+
def _set_created_ttl(self, created: _float, ttl: Union[float, int]) -> None:
231229
"""Set the created and ttl of a record."""
230+
# It would be better if we made a copy instead of mutating the record
231+
# in place, but records currently don't have a copy method.
232232
self.created = created
233233
self.ttl = ttl
234234

src/zeroconf/_handlers/record_manager.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
103103
record,
104104
_DNS_PTR_MIN_TTL,
105105
)
106-
record.set_created_ttl(record.created, _DNS_PTR_MIN_TTL)
106+
# Safe because the record is never in the cache yet
107+
record._set_created_ttl(record.created, _DNS_PTR_MIN_TTL)
107108

108109
if record.unique: # https://tools.ietf.org/html/rfc6762#section-10.2
109110
unique_types.add((record.name, record_type, record.class_))
@@ -113,13 +114,10 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
113114

114115
maybe_entry = cache.async_get_unique(record)
115116
if not record.is_expired(now):
116-
if maybe_entry is not None:
117-
maybe_entry.reset_ttl(record)
117+
if record_type in _ADDRESS_RECORD_TYPES:
118+
address_adds.append(record)
118119
else:
119-
if record_type in _ADDRESS_RECORD_TYPES:
120-
address_adds.append(record)
121-
else:
122-
other_adds.append(record)
120+
other_adds.append(record)
123121
rec_update = RecordUpdate.__new__(RecordUpdate)
124122
rec_update._fast_init(record, maybe_entry)
125123
updates.append(rec_update)

tests/services/test_browser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1509,9 +1509,9 @@ def update_service(self, zc, type_, name) -> None: # type: ignore[no-untyped-de
15091509
)
15101510
# Force the ttl to be 1 second
15111511
now = current_time_millis()
1512-
for cache_record in zc.cache.cache.values():
1512+
for cache_record in list(zc.cache.cache.values()):
15131513
for record in cache_record:
1514-
record.set_created_ttl(now, 1)
1514+
zc.cache._async_set_created_ttl(record, now, 1)
15151515

15161516
time.sleep(0.3)
15171517
info.port = 400

tests/services/test_info.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def test_service_info_rejects_expired_records(self):
242242
ttl,
243243
b"\x04ff=0\x04ci=3\x04sf=0\x0bsh=6fLM5A==",
244244
)
245-
expired_record.set_created_ttl(1000, 1)
245+
zc.cache._async_set_created_ttl(expired_record, 1000, 1)
246246
info.async_update_records(zc, now, [RecordUpdate(expired_record, None)])
247247
assert info.properties[b"ci"] == b"2"
248248
zc.close()

tests/test_cache.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import logging
44
import unittest
55
import unittest.mock
6+
from heapq import heapify, heappop
7+
8+
import pytest
69

710
import zeroconf as r
811
from zeroconf import const
@@ -358,3 +361,123 @@ def test_async_get_unique_returns_newest_record():
358361
assert record is record2
359362
record = cache.async_get_unique(record2)
360363
assert record is record2
364+
365+
366+
@pytest.mark.asyncio
367+
async def test_cache_heap_cleanup() -> None:
368+
"""Test that the heap gets cleaned up when there are many old expirations."""
369+
cache = r.DNSCache()
370+
# The heap should not be cleaned up when there are less than 100 expiration changes
371+
min_records_to_cleanup = 100
372+
now = r.current_time_millis()
373+
name = "heap.local."
374+
ttl_seconds = 100
375+
ttl_millis = ttl_seconds * 1000
376+
377+
for i in range(min_records_to_cleanup):
378+
record = r.DNSAddress(name, const._TYPE_A, const._CLASS_IN, ttl_seconds, b"1", created=now + i)
379+
cache.async_add_records([record])
380+
381+
assert len(cache._expire_heap) == min_records_to_cleanup
382+
assert len(cache.async_entries_with_name(name)) == 1
383+
384+
# Now that we reached the minimum number of cookies to cleanup,
385+
# add one more cookie to trigger the cleanup
386+
record = r.DNSAddress(
387+
name, const._TYPE_A, const._CLASS_IN, ttl_seconds, b"1", created=now + min_records_to_cleanup
388+
)
389+
expected_expire_time = record.created + ttl_millis
390+
cache.async_add_records([record])
391+
assert len(cache.async_entries_with_name(name)) == 1
392+
entry = next(iter(cache.async_entries_with_name(name)))
393+
assert (entry.created + ttl_millis) == expected_expire_time
394+
assert entry is record
395+
396+
# Verify that the heap has been cleaned up
397+
assert len(cache.async_entries_with_name(name)) == 1
398+
cache.async_expire(now)
399+
400+
heap_copy = cache._expire_heap.copy()
401+
heapify(heap_copy)
402+
# Ensure heap order is maintained
403+
assert cache._expire_heap == heap_copy
404+
405+
# The heap should have been cleaned up
406+
assert len(cache._expire_heap) == 1
407+
assert len(cache.async_entries_with_name(name)) == 1
408+
409+
entry = next(iter(cache.async_entries_with_name(name)))
410+
assert entry is record
411+
412+
assert (entry.created + ttl_millis) == expected_expire_time
413+
414+
cache.async_expire(expected_expire_time)
415+
assert not cache.async_entries_with_name(name), cache._expire_heap
416+
417+
418+
@pytest.mark.asyncio
419+
async def test_cache_heap_multi_name_cleanup() -> None:
420+
"""Test cleanup with multiple names."""
421+
cache = r.DNSCache()
422+
# The heap should not be cleaned up when there are less than 100 expiration changes
423+
min_records_to_cleanup = 100
424+
now = r.current_time_millis()
425+
name = "heap.local."
426+
name2 = "heap2.local."
427+
ttl_seconds = 100
428+
ttl_millis = ttl_seconds * 1000
429+
430+
for i in range(min_records_to_cleanup):
431+
record = r.DNSAddress(name, const._TYPE_A, const._CLASS_IN, ttl_seconds, b"1", created=now + i)
432+
cache.async_add_records([record])
433+
expected_expire_time = record.created + ttl_millis
434+
435+
for i in range(5):
436+
record = r.DNSAddress(
437+
name2, const._TYPE_A, const._CLASS_IN, ttl_seconds, bytes((i,)), created=now + i
438+
)
439+
cache.async_add_records([record])
440+
441+
assert len(cache._expire_heap) == min_records_to_cleanup + 5
442+
assert len(cache.async_entries_with_name(name)) == 1
443+
assert len(cache.async_entries_with_name(name2)) == 5
444+
445+
cache.async_expire(now)
446+
# The heap and expirations should have been cleaned up
447+
assert len(cache._expire_heap) == 1 + 5
448+
assert len(cache._expirations) == 1 + 5
449+
450+
cache.async_expire(expected_expire_time)
451+
assert not cache.async_entries_with_name(name), cache._expire_heap
452+
453+
454+
@pytest.mark.asyncio
455+
async def test_cache_heap_pops_order() -> None:
456+
"""Test cache heap is popped in order."""
457+
cache = r.DNSCache()
458+
# The heap should not be cleaned up when there are less than 100 expiration changes
459+
min_records_to_cleanup = 100
460+
now = r.current_time_millis()
461+
name = "heap.local."
462+
name2 = "heap2.local."
463+
ttl_seconds = 100
464+
465+
for i in range(min_records_to_cleanup):
466+
record = r.DNSAddress(name, const._TYPE_A, const._CLASS_IN, ttl_seconds, b"1", created=now + i)
467+
cache.async_add_records([record])
468+
469+
for i in range(5):
470+
record = r.DNSAddress(
471+
name2, const._TYPE_A, const._CLASS_IN, ttl_seconds, bytes((i,)), created=now + i
472+
)
473+
cache.async_add_records([record])
474+
475+
assert len(cache._expire_heap) == min_records_to_cleanup + 5
476+
assert len(cache.async_entries_with_name(name)) == 1
477+
assert len(cache.async_entries_with_name(name2)) == 5
478+
479+
start_ts = 0.0
480+
while cache._expire_heap:
481+
ts, _ = heappop(cache._expire_heap)
482+
assert ts >= start_ts
483+
start_ts = ts

0 commit comments

Comments
 (0)