Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/zeroconf/_cache.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ cdef class DNSCache:
)
cpdef list async_all_by_details(self, str name, unsigned int type_, unsigned int class_)

cpdef cython.dict async_entries_with_name(self, str name)
cpdef list async_entries_with_name(self, str name)

cpdef cython.dict async_entries_with_server(self, str name)
cpdef list async_entries_with_server(self, str name)

@cython.locals(
cached_entry=DNSRecord,
records=dict
)
cpdef DNSRecord get_by_details(self, str name, unsigned int type_, unsigned int class_)

Expand All @@ -79,7 +80,15 @@ cdef class DNSCache:
)
cpdef void async_mark_unique_records_older_than_1s_to_expire(self, cython.set unique_types, object answers, double now)

cpdef entries_with_name(self, str name)
@cython.locals(
entries=dict
)
cpdef list entries_with_name(self, str name)

@cython.locals(
entries=dict
)
cpdef list entries_with_server(self, str server)

@cython.locals(
record=DNSRecord,
Expand Down
24 changes: 14 additions & 10 deletions src/zeroconf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,26 +149,26 @@ def async_all_by_details(self, name: _str, type_: _int, class_: _int) -> List[DN
matches: List[DNSRecord] = []
if records is None:
return matches
for record in records:
for record in records.values():
if type_ == record.type and class_ == record.class_:
matches.append(record)
return matches

def async_entries_with_name(self, name: str) -> Dict[DNSRecord, DNSRecord]:
def async_entries_with_name(self, name: str) -> List[DNSRecord]:
"""Returns a dict of entries whose key matches the name.

This function is not threadsafe and must be called from
the event loop.
"""
return self.cache.get(name.lower()) or {}
return self.entries_with_name(name)

def async_entries_with_server(self, name: str) -> Dict[DNSRecord, DNSRecord]:
def async_entries_with_server(self, name: str) -> List[DNSRecord]:
"""Returns a dict of entries whose key matches the server.

This function is not threadsafe and must be called from
the event loop.
"""
return self.service_cache.get(name.lower()) or {}
return self.entries_with_server(name)

# The below functions are threadsafe and do not need to be run in the
# event loop, however they all make copies so they significantly
Expand All @@ -179,7 +179,7 @@ def get(self, entry: DNSEntry) -> Optional[DNSRecord]:
matching entry."""
if isinstance(entry, _UNIQUE_RECORD_TYPES):
return self.cache.get(entry.key, {}).get(entry)
for cached_entry in reversed(list(self.cache.get(entry.key, []))):
for cached_entry in reversed(list(self.cache.get(entry.key, {}).values())):
if entry.__eq__(cached_entry):
return cached_entry
return None
Expand All @@ -200,7 +200,7 @@ def get_by_details(self, name: str, type_: _int, class_: _int) -> Optional[DNSRe
records = self.cache.get(key)
if records is None:
return None
for cached_entry in reversed(list(records)):
for cached_entry in reversed(list(records.values())):
if type_ == cached_entry.type and class_ == cached_entry.class_:
return cached_entry
return None
Expand All @@ -211,15 +211,19 @@ def get_all_by_details(self, name: str, type_: _int, class_: _int) -> List[DNSRe
records = self.cache.get(key)
if records is None:
return []
return [entry for entry in list(records) if type_ == entry.type and class_ == entry.class_]
return [entry for entry in list(records.values()) if type_ == entry.type and class_ == entry.class_]

def entries_with_server(self, server: str) -> List[DNSRecord]:
"""Returns a list of entries whose server matches the name."""
return list(self.service_cache.get(server.lower(), []))
if entries := self.service_cache.get(server.lower()):
return list(entries.values())
return []

def entries_with_name(self, name: str) -> List[DNSRecord]:
"""Returns a list of entries whose key matches the name."""
return list(self.cache.get(name.lower(), []))
if entries := self.cache.get(name.lower()):
return list(entries.values())
return []

def current_entry_with_name_and_alias(self, name: str, alias: str) -> Optional[DNSRecord]:
now = current_time_millis()
Expand Down
79 changes: 79 additions & 0 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,3 +279,82 @@ def test_name(self):
cache = r.DNSCache()
cache.async_add_records([record1, record2])
assert cache.names() == ["irrelevant"]


def test_async_entries_with_name_returns_newest_record():
cache = r.DNSCache()
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
cache.async_add_records([record1])
cache.async_add_records([record2])
assert next(iter(cache.async_entries_with_name("a"))) is record2


def test_async_entries_with_server_returns_newest_record():
cache = r.DNSCache()
record1 = r.DNSService("a", const._TYPE_SRV, const._CLASS_IN, 1, 1, 1, 1, "a", created=1.0)
record2 = r.DNSService("a", const._TYPE_SRV, const._CLASS_IN, 1, 1, 1, 1, "a", created=2.0)
cache.async_add_records([record1])
cache.async_add_records([record2])
assert next(iter(cache.async_entries_with_server("a"))) is record2


def test_async_get_returns_newest_record():
cache = r.DNSCache()
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
cache.async_add_records([record1])
cache.async_add_records([record2])
assert cache.get(record2) is record2


def test_async_get_returns_newest_nsec_record():
cache = r.DNSCache()
record1 = r.DNSNsec("a", const._TYPE_NSEC, const._CLASS_IN, 1, "a", [], created=1.0)
record2 = r.DNSNsec("a", const._TYPE_NSEC, const._CLASS_IN, 1, "a", [], created=2.0)
cache.async_add_records([record1])
cache.async_add_records([record2])
assert cache.get(record2) is record2


def test_get_by_details_returns_newest_record():
cache = r.DNSCache()
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
cache.async_add_records([record1])
cache.async_add_records([record2])
assert cache.get_by_details("a", const._TYPE_A, const._CLASS_IN) is record2


def test_get_all_by_details_returns_newest_record():
cache = r.DNSCache()
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
cache.async_add_records([record1])
cache.async_add_records([record2])
records = cache.get_all_by_details("a", const._TYPE_A, const._CLASS_IN)
assert len(records) == 1
assert records[0] is record2


def test_async_get_all_by_details_returns_newest_record():
cache = r.DNSCache()
record1 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=1.0)
record2 = r.DNSAddress("a", const._TYPE_A, const._CLASS_IN, 1, b"a", created=2.0)
cache.async_add_records([record1])
cache.async_add_records([record2])
records = cache.async_all_by_details("a", const._TYPE_A, const._CLASS_IN)
assert len(records) == 1
assert records[0] is record2


def test_async_get_unique_returns_newest_record():
cache = r.DNSCache()
record1 = r.DNSPointer("a", const._TYPE_PTR, const._CLASS_IN, 1, "a", created=1.0)
record2 = r.DNSPointer("a", const._TYPE_PTR, const._CLASS_IN, 1, "a", created=2.0)
cache.async_add_records([record1])
cache.async_add_records([record2])
record = cache.async_get_unique(record1)
assert record is record2
record = cache.async_get_unique(record2)
assert record is record2