Skip to content
4 changes: 1 addition & 3 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,4 @@ def has_working_ipv6():


def _clear_cache(zc):
for name in zc.cache.names():
for record in zc.cache.entries_with_name(name):
zc.cache.remove(record)
zc.cache.cache.clear()
28 changes: 14 additions & 14 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_order(self):
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
entry = r.DNSEntry('a', const._TYPE_SOA, const._CLASS_IN)
cached_record = cache.get(entry)
assert cached_record == record2
Expand All @@ -45,7 +45,7 @@ def test_adding_same_record_to_cache_different_ttls(self):
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
entry = r.DNSEntry(record2)
cached_record = cache.get(entry)
assert cached_record == record2
Expand All @@ -61,26 +61,26 @@ def test_adding_same_record_to_cache_different_ttls(self):
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 10, b'a')
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
cached_records = cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)
assert cached_records == [record2]

def test_cache_empty_does_not_leak_memory_by_leaving_empty_list(self):
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert 'a' in cache.cache
cache.remove_records([record1, record2])
cache.async_remove_records([record1, record2])
assert 'a' not in cache.cache

def test_cache_empty_multiple_calls(self):
record1 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert 'a' in cache.cache
cache.remove_records([record1, record2])
cache.async_remove_records([record1, record2])
assert 'a' not in cache.cache


Expand All @@ -91,22 +91,22 @@ def test_get(self):
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert cache.get(record1) == record1
assert cache.get(record2) == record2

def test_get_by_details(self):
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert cache.get_by_details('a', const._TYPE_A, const._CLASS_IN) == record2

def test_get_all_by_details(self):
record1 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'a')
record2 = r.DNSAddress('a', const._TYPE_A, const._CLASS_IN, 1, b'b')
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert set(cache.get_all_by_details('a', const._TYPE_A, const._CLASS_IN)) == set([record1, record2])

def test_entries_with_server(self):
Expand All @@ -117,7 +117,7 @@ def test_entries_with_server(self):
'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
)
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert set(cache.entries_with_server('ab')) == set([record1, record2])

def test_entries_with_name(self):
Expand All @@ -128,7 +128,7 @@ def test_entries_with_name(self):
'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
)
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert set(cache.entries_with_name('irrelevant')) == set([record1, record2])

def test_current_entry_with_name_and_alias(self):
Expand All @@ -139,7 +139,7 @@ def test_current_entry_with_name_and_alias(self):
'irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, 'y.irrelevant'
)
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert cache.current_entry_with_name_and_alias('irrelevant', 'x.irrelevant') == record1

def test_entries_with_name(self):
Expand All @@ -150,5 +150,5 @@ def test_entries_with_name(self):
'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab'
)
cache = r.DNSCache()
cache.add_records([record1, record2])
cache.async_add_records([record1, record2])
assert cache.names() == ['irrelevant']
22 changes: 12 additions & 10 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import zeroconf as r
from zeroconf import _core, const, ServiceBrowser, Zeroconf, current_time_millis
from zeroconf.aio import AsyncZeroconf

from . import has_working_ipv6, _clear_cache, _inject_response

Expand All @@ -36,22 +37,23 @@ def teardown_module():
log.setLevel(original_logging_level)


class TestReaper(unittest.TestCase):
@unittest.mock.patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10)
def test_reaper(self):
zeroconf = _core.Zeroconf(interfaces=['127.0.0.1'])
# This test uses asyncio because it needs to access the cache directly
# which is not threadsafe
@pytest.mark.asyncio
async def test_reaper():
with unittest.mock.patch.object(_core, "_CACHE_CLEANUP_INTERVAL", 10):
assert _core._CACHE_CLEANUP_INTERVAL == 10
aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
zeroconf = aiozc.zeroconf
cache = zeroconf.cache
original_entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()]))
record_with_10s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 10, b'a')
record_with_1s_ttl = r.DNSAddress('a', const._TYPE_SOA, const._CLASS_IN, 1, b'b')
zeroconf.cache.add(record_with_10s_ttl)
zeroconf.cache.add(record_with_1s_ttl)
zeroconf.cache.async_add_records([record_with_10s_ttl, record_with_1s_ttl])
entries_with_cache = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()]))
time.sleep(1)
zeroconf.notify_all()
time.sleep(0.1)
await asyncio.sleep(1.2)
entries = list(itertools.chain(*[cache.entries_with_name(name) for name in cache.names()]))
zeroconf.close()
await aiozc.async_close()
assert entries != original_entries
assert entries_with_cache != original_entries
assert record_with_10s_ttl in entries
Expand Down
25 changes: 15 additions & 10 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import zeroconf as r
from zeroconf import ServiceInfo, Zeroconf, current_time_millis
from zeroconf import const
from zeroconf.aio import AsyncZeroconf

from . import _clear_cache, _inject_response

Expand Down Expand Up @@ -703,10 +704,14 @@ def test_known_answer_supression_service_type_enumeration_query():
zc.close()


def test_qu_response_only_sends_additionals_if_sends_answer():
# This test uses asyncio because it needs to access the cache directly
# which is not threadsafe
@pytest.mark.asyncio
async def test_qu_response_only_sends_additionals_if_sends_answer():
"""Test that a QU response does not send additionals unless it sends the answer as well."""
# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])
aiozc = AsyncZeroconf(interfaces=['127.0.0.1'])
zc = aiozc.zeroconf

type_ = "_addtest1._tcp.local."
name = "knownname"
Expand All @@ -731,13 +736,13 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
ptr_record = info.dns_pointer()

# Add the PTR record to the cache
zc.cache.add(ptr_record)
zc.cache.async_add_records([ptr_record])

# Add the A record to the cache with 50% ttl remaining
a_record = info.dns_addresses()[0]
a_record.set_created_ttl(current_time_millis() - (a_record.ttl * 1000 / 2), a_record.ttl)
assert not a_record.is_recent(current_time_millis())
zc.cache.add(a_record)
zc.cache.async_add_records([a_record])

# With QU should respond to only unicast when the answer has been recently multicast
# even if the additional has not been recently multicast
Expand All @@ -755,10 +760,10 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
assert unicast_out.answers[0][0] == ptr_record

# Remove the 50% A record and add a 100% A record
zc.cache.remove(a_record)
zc.cache.async_remove_records([a_record])
a_record = info.dns_addresses()[0]
assert a_record.is_recent(current_time_millis())
zc.cache.add(a_record)
zc.cache.async_add_records([a_record])
# With QU should respond to only unicast when the answer has been recently multicast
# even if the additional has not been recently multicast
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
Expand All @@ -775,10 +780,10 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
assert unicast_out.answers[0][0] == ptr_record

# Remove the 100% PTR record and add a 50% PTR record
zc.cache.remove(ptr_record)
zc.cache.async_remove_records([ptr_record])
ptr_record.set_created_ttl(current_time_millis() - (ptr_record.ttl * 1000 / 2), ptr_record.ttl)
assert not ptr_record.is_recent(current_time_millis())
zc.cache.add(ptr_record)
zc.cache.async_add_records([ptr_record])
# With QU should respond to only multicast since the has less
# than 75% of its ttl remaining
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
Expand Down Expand Up @@ -811,7 +816,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
question.unicast = True # Set the QU bit
assert question.unicast is True
query.add_question(question)
zc.cache.add(info2.dns_pointer()) # Add 100% TTL for info2 to the cache
zc.cache.async_add_records([info2.dns_pointer()]) # Add 100% TTL for info2 to the cache

unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
Expand All @@ -828,4 +833,4 @@ def test_qu_response_only_sends_additionals_if_sends_answer():

# unregister
zc.registry.remove(info)
zc.close()
await aiozc.async_close()
3 changes: 1 addition & 2 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -960,8 +960,7 @@ async def test_multiple_a_addresses():
host = "multahost.local."
record1 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'a')
record2 = r.DNSAddress(host, const._TYPE_A, const._CLASS_IN, 1000, b'b')
cache.add(record1)
cache.add(record2)
cache.async_add_records([record1, record2])

# New kwarg way
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, host)
Expand Down
18 changes: 9 additions & 9 deletions zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def __init__(self) -> None:
self.cache: _DNSRecordCacheType = {}
self.service_cache: _DNSRecordCacheType = {}

# Functions prefixed with are NOT threadsafe and must
# Functions prefixed with async_ are NOT threadsafe and must
# be run in the event loop.

def add(self, entry: DNSRecord) -> None:
def _async_add(self, entry: DNSRecord) -> None:
"""Adds an entry.

This function must be run in from event loop.
Expand All @@ -65,15 +65,15 @@ def add(self, entry: DNSRecord) -> None:
if isinstance(entry, DNSService):
self.service_cache.setdefault(entry.server, {})[entry] = entry

def add_records(self, entries: Iterable[DNSRecord]) -> None:
def async_add_records(self, entries: Iterable[DNSRecord]) -> None:
"""Add multiple records.

This function must be run in from event loop.
"""
for entry in entries:
self.add(entry)
self._async_add(entry)

def remove(self, entry: DNSRecord) -> None:
def _async_remove(self, entry: DNSRecord) -> None:
"""Removes an entry.

This function must be run in from event loop.
Expand All @@ -82,23 +82,23 @@ def remove(self, entry: DNSRecord) -> None:
_remove_key(self.service_cache, entry.server, entry)
_remove_key(self.cache, entry.key, entry)

def remove_records(self, entries: Iterable[DNSRecord]) -> None:
def async_remove_records(self, entries: Iterable[DNSRecord]) -> None:
"""Remove multiple records.

This function must be run in from event loop.
"""
for entry in entries:
self.remove(entry)
self._async_remove(entry)

def expire(self, now: float) -> Iterable[DNSRecord]:
def async_expire(self, now: float) -> Iterable[DNSRecord]:
"""Purge expired entries from the cache.

This function must be run in from event loop.
"""
for name in self.names():
for record in self.entries_with_name(name):
if record.is_expired(now):
self.remove(record)
self._async_remove(record)
yield record

# The below functions are threadsafe and do not need to be run in the
Expand Down
2 changes: 1 addition & 1 deletion zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ async def _async_cache_cleanup(self) -> None:
"""Periodic cache cleanup."""
while not self.zc.done:
now = current_time_millis()
self.zc.record_manager.async_updates(now, list(self.zc.cache.expire(now)))
self.zc.record_manager.async_updates(now, list(self.zc.cache.async_expire(now)))
self.zc.record_manager.async_updates_complete()
await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL))

Expand Down
4 changes: 2 additions & 2 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,11 @@ def async_updates_from_response(self, msg: DNSIncoming) -> None:
# zc.get_service_info will see the cached value
# but ONLY after all the record updates have been
# processsed.
self.cache.add_records(itertools.chain(address_adds, other_adds))
self.cache.async_add_records(itertools.chain(address_adds, other_adds))
# Removes are processed last since
# ServiceInfo could generate an un-needed query
# because the data was not yet populated.
self.cache.remove_records(removes)
self.cache.async_remove_records(removes)
self.async_updates_complete()

def add_listener(
Expand Down