Skip to content

Commit d4c109c

Browse files
authored
Cache DNS record and question hashes (#960)
1 parent 3b482e2 commit d4c109c

3 files changed

Lines changed: 58 additions & 18 deletions

File tree

tests/test_asyncio.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ async def test_async_unregister_all_services() -> None:
579579
assert results[1] is not None
580580

581581
await aiozc.async_unregister_all_services()
582+
_clear_cache(aiozc.zeroconf)
582583

583584
tasks = []
584585
tasks.append(aiozc.async_get_service_info(type_, registration_name))

tests/test_dns.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,33 @@ def test_dns_record_is_recent(self):
163163
assert record.is_recent(now + (8 * 1000)) is False
164164

165165

166+
def test_dns_question_hashablity():
167+
"""Test DNSQuestions are hashable."""
168+
169+
record1 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)
170+
record2 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)
171+
172+
record_set = {record1, record2}
173+
assert len(record_set) == 1
174+
175+
record_set.add(record1)
176+
assert len(record_set) == 1
177+
178+
record3_dupe = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)
179+
assert record2 == record3_dupe
180+
assert record2.__hash__() == record3_dupe.__hash__()
181+
182+
record_set.add(record3_dupe)
183+
assert len(record_set) == 1
184+
185+
record4_dupe = r.DNSQuestion('notsame', const._TYPE_A, const._CLASS_IN)
186+
assert record2 != record4_dupe
187+
assert record2.__hash__() != record4_dupe.__hash__()
188+
189+
record_set.add(record4_dupe)
190+
assert len(record_set) == 2
191+
192+
166193
def test_dns_record_hashablity_does_not_consider_ttl():
167194
"""Test DNSRecord are hashable."""
168195

zeroconf/_dns.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
import enum
2424
import socket
25-
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast
25+
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union, cast
2626

2727
from ._exceptions import AbstractMethodException
2828
from ._utils.net import _is_v6_address
@@ -81,10 +81,6 @@ def __init__(self, name: str, type_: int, class_: int) -> None:
8181
self.class_ = class_ & _CLASS_MASK
8282
self.unique = (class_ & _CLASS_UNIQUE) != 0
8383

84-
def _entry_tuple(self) -> Tuple[str, int, int]:
85-
"""Entry Tuple for DNSEntry."""
86-
return (self.key, self.type, self.class_)
87-
8884
def __eq__(self, other: Any) -> bool:
8985
"""Equality test on key (lowercase name), type, and class"""
9086
return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry)
@@ -115,12 +111,22 @@ class DNSQuestion(DNSEntry):
115111

116112
"""A DNS question entry"""
117113

114+
__slots__ = ('_hash',)
115+
116+
def __init__(self, name: str, type_: int, class_: int) -> None:
117+
super().__init__(name, type_, class_)
118+
self._hash = hash((self.key, type_, class_))
119+
118120
def answered_by(self, rec: 'DNSRecord') -> bool:
119121
"""Returns true if the question is answered by the record"""
120122
return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name
121123

122124
def __hash__(self) -> int:
123-
return hash((self.name, self.class_, self.type))
125+
return self._hash
126+
127+
def __eq__(self, other: Any) -> bool:
128+
"""Tests equality on dns question."""
129+
return isinstance(other, DNSQuestion) and DNSEntry.__eq__(self, other)
124130

125131
@property
126132
def max_size(self) -> int:
@@ -225,7 +231,7 @@ class DNSAddress(DNSRecord):
225231

226232
"""A DNS address record"""
227233

228-
__slots__ = ('address', 'scope_id')
234+
__slots__ = ('_hash', 'address', 'scope_id')
229235

230236
def __init__(
231237
self,
@@ -241,6 +247,7 @@ def __init__(
241247
super().__init__(name, type_, class_, ttl, created)
242248
self.address = address
243249
self.scope_id = scope_id
250+
self._hash = hash((self.key, type_, class_, address, scope_id))
244251

245252
def write(self, out: 'DNSOutgoing') -> None:
246253
"""Used in constructing an outgoing packet"""
@@ -257,7 +264,7 @@ def __eq__(self, other: Any) -> bool:
257264

258265
def __hash__(self) -> int:
259266
"""Hash to compare like DNSAddresses."""
260-
return hash((*self._entry_tuple(), self.address, self.scope_id))
267+
return self._hash
261268

262269
def __repr__(self) -> str:
263270
"""String representation"""
@@ -275,14 +282,15 @@ class DNSHinfo(DNSRecord):
275282

276283
"""A DNS host information record"""
277284

278-
__slots__ = ('cpu', 'os')
285+
__slots__ = ('_hash', 'cpu', 'os')
279286

280287
def __init__(
281288
self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None
282289
) -> None:
283290
super().__init__(name, type_, class_, ttl, created)
284291
self.cpu = cpu
285292
self.os = os
293+
self._hash = hash((self.key, type_, class_, cpu, os))
286294

287295
def write(self, out: 'DNSOutgoing') -> None:
288296
"""Used in constructing an outgoing packet"""
@@ -300,7 +308,7 @@ def __eq__(self, other: Any) -> bool:
300308

301309
def __hash__(self) -> int:
302310
"""Hash to compare like DNSHinfo."""
303-
return hash((*self._entry_tuple(), self.cpu, self.os))
311+
return self._hash
304312

305313
def __repr__(self) -> str:
306314
"""String representation"""
@@ -311,13 +319,14 @@ class DNSPointer(DNSRecord):
311319

312320
"""A DNS pointer record"""
313321

314-
__slots__ = ('alias',)
322+
__slots__ = ('_hash', 'alias')
315323

316324
def __init__(
317325
self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None
318326
) -> None:
319327
super().__init__(name, type_, class_, ttl, created)
320328
self.alias = alias
329+
self._hash = hash((self.key, type_, class_, alias))
321330

322331
@property
323332
def max_size_compressed(self) -> int:
@@ -339,7 +348,7 @@ def __eq__(self, other: Any) -> bool:
339348

340349
def __hash__(self) -> int:
341350
"""Hash to compare like DNSPointer."""
342-
return hash((*self._entry_tuple(), self.alias))
351+
return self._hash
343352

344353
def __repr__(self) -> str:
345354
"""String representation"""
@@ -350,22 +359,23 @@ class DNSText(DNSRecord):
350359

351360
"""A DNS text record"""
352361

353-
__slots__ = ('text',)
362+
__slots__ = ('_hash', 'text')
354363

355364
def __init__(
356365
self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None
357366
) -> None:
358367
assert isinstance(text, (bytes, type(None)))
359368
super().__init__(name, type_, class_, ttl, created)
360369
self.text = text
370+
self._hash = hash((self.key, type_, class_, text))
361371

362372
def write(self, out: 'DNSOutgoing') -> None:
363373
"""Used in constructing an outgoing packet"""
364374
out.write_string(self.text)
365375

366376
def __hash__(self) -> int:
367377
"""Hash to compare like DNSText."""
368-
return hash((*self._entry_tuple(), self.text))
378+
return self._hash
369379

370380
def __eq__(self, other: Any) -> bool:
371381
"""Tests equality on text"""
@@ -382,7 +392,7 @@ class DNSService(DNSRecord):
382392

383393
"""A DNS service record"""
384394

385-
__slots__ = ('priority', 'weight', 'port', 'server')
395+
__slots__ = ('_hash', 'priority', 'weight', 'port', 'server')
386396

387397
def __init__(
388398
self,
@@ -401,6 +411,7 @@ def __init__(
401411
self.weight = weight
402412
self.port = port
403413
self.server = server
414+
self._hash = hash((self.key, type_, class_, priority, weight, port, server))
404415

405416
def write(self, out: 'DNSOutgoing') -> None:
406417
"""Used in constructing an outgoing packet"""
@@ -422,7 +433,7 @@ def __eq__(self, other: Any) -> bool:
422433

423434
def __hash__(self) -> int:
424435
"""Hash to compare like DNSService."""
425-
return hash((*self._entry_tuple(), self.priority, self.weight, self.port, self.server))
436+
return self._hash
426437

427438
def __repr__(self) -> str:
428439
"""String representation"""
@@ -433,7 +444,7 @@ class DNSNsec(DNSRecord):
433444

434445
"""A DNS NSEC record"""
435446

436-
__slots__ = ('next_name', 'rdtypes')
447+
__slots__ = ('_hash', 'next_name', 'rdtypes')
437448

438449
def __init__(
439450
self,
@@ -448,6 +459,7 @@ def __init__(
448459
super().__init__(name, type_, class_, ttl, created)
449460
self.next_name = next_name
450461
self.rdtypes = rdtypes
462+
self._hash = hash((self.key, type_, class_, next_name, *self.rdtypes))
451463

452464
def __eq__(self, other: Any) -> bool:
453465
"""Tests equality on cpu and os"""
@@ -460,7 +472,7 @@ def __eq__(self, other: Any) -> bool:
460472

461473
def __hash__(self) -> int:
462474
"""Hash to compare like DNSNSec."""
463-
return hash((*self._entry_tuple(), self.next_name, *self.rdtypes))
475+
return self._hash
464476

465477
def __repr__(self) -> str:
466478
"""String representation"""

0 commit comments

Comments
 (0)