Skip to content

Commit b7d8678

Browse files
authored
Make DNSRecords hashable (#611)
- Allows storing them in a set for de-duplication - Needed to be able to check for duplicates to solve #604
1 parent 22bd147 commit b7d8678

2 files changed

Lines changed: 209 additions & 9 deletions

File tree

tests/test_dns.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,3 +851,182 @@ def test_qu_packet_parser():
851851
parsed = DNSIncoming(qu_packet)
852852
assert parsed.questions[0].unicast is True
853853
assert ",QU," in str(parsed.questions[0])
854+
855+
856+
def test_dns_record_hashablity_does_not_consider_ttl():
857+
"""Test DNSRecord are hashable."""
858+
859+
# Verify the TTL is not considered in the hash
860+
record1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same')
861+
record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same')
862+
863+
record_set = set([record1, record2])
864+
assert len(record_set) == 1
865+
866+
record_set.add(record1)
867+
assert len(record_set) == 1
868+
869+
record3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same')
870+
871+
record_set.add(record3_dupe)
872+
assert len(record_set) == 1
873+
874+
875+
def test_dns_address_record_hashablity():
876+
"""Test DNSAddress are hashable."""
877+
address1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'a')
878+
address2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'b')
879+
address3 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c')
880+
address4 = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 1, b'c')
881+
882+
record_set = set([address1, address2, address3, address4])
883+
assert len(record_set) == 4
884+
885+
record_set.add(address1)
886+
assert len(record_set) == 4
887+
888+
address3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c')
889+
890+
record_set.add(address3_dupe)
891+
assert len(record_set) == 4
892+
893+
# Verify we can remove records
894+
additional_set = set([address1, address2])
895+
record_set -= additional_set
896+
assert record_set == set([address3, address4])
897+
898+
899+
def test_dns_hinfo_record_hashablity():
900+
"""Test DNSHinfo are hashable."""
901+
hinfo1 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu1', 'os')
902+
hinfo2 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os')
903+
904+
record_set = set([hinfo1, hinfo2])
905+
assert len(record_set) == 2
906+
907+
record_set.add(hinfo1)
908+
assert len(record_set) == 2
909+
910+
hinfo2_dupe = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os')
911+
912+
record_set.add(hinfo2_dupe)
913+
assert len(record_set) == 2
914+
915+
916+
def test_dns_pointer_record_hashablity():
917+
"""Test DNSPointer are hashable."""
918+
ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')
919+
ptr2 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456')
920+
921+
record_set = set([ptr1, ptr2])
922+
assert len(record_set) == 2
923+
924+
record_set.add(ptr1)
925+
assert len(record_set) == 2
926+
927+
ptr2_dupe = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456')
928+
929+
record_set.add(ptr2_dupe)
930+
assert len(record_set) == 2
931+
932+
933+
def test_dns_text_record_hashablity():
934+
"""Test DNSText are hashable."""
935+
text1 = r.DNSText('irrelevant', 0, 0, 0, b'12345678901')
936+
text2 = r.DNSText('irrelevant', 1, 0, 0, b'12345678901')
937+
text3 = r.DNSText('irrelevant', 0, 1, 0, b'12345678901')
938+
text4 = r.DNSText('irrelevant', 0, 0, 1, b'12345678901')
939+
text5 = r.DNSText('irrelevant', 0, 0, 0, b'ABCDEFGHIJK')
940+
941+
record_set = set([text1, text2, text3, text4, text5])
942+
assert len(record_set) == 5
943+
944+
record_set.add(text1)
945+
assert len(record_set) == 5
946+
947+
text1_dupe = r.DNSText('irrelevant', 0, 0, 0, b'12345678901')
948+
949+
record_set.add(text1_dupe)
950+
assert len(record_set) == 5
951+
952+
953+
def test_dns_text_record_hashablity():
954+
"""Test DNSText are hashable."""
955+
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
956+
text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901')
957+
text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901')
958+
text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK')
959+
960+
record_set = set([text1, text2, text3, text4])
961+
962+
assert len(record_set) == 4
963+
964+
record_set.add(text1)
965+
assert len(record_set) == 4
966+
967+
text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
968+
969+
record_set.add(text1_dupe)
970+
assert len(record_set) == 4
971+
972+
973+
def test_dns_text_record_hashablity():
974+
"""Test DNSText are hashable."""
975+
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
976+
text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901')
977+
text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901')
978+
text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK')
979+
980+
record_set = set([text1, text2, text3, text4])
981+
982+
assert len(record_set) == 4
983+
984+
record_set.add(text1)
985+
assert len(record_set) == 4
986+
987+
text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
988+
989+
record_set.add(text1_dupe)
990+
assert len(record_set) == 4
991+
992+
993+
def test_dns_text_record_hashablity():
994+
"""Test DNSText are hashable."""
995+
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
996+
text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901')
997+
text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901')
998+
text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK')
999+
1000+
record_set = set([text1, text2, text3, text4])
1001+
1002+
assert len(record_set) == 4
1003+
1004+
record_set.add(text1)
1005+
assert len(record_set) == 4
1006+
1007+
text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
1008+
1009+
record_set.add(text1_dupe)
1010+
assert len(record_set) == 4
1011+
1012+
1013+
def test_dns_service_record_hashablity():
1014+
"""Test DNSService are hashable."""
1015+
srv1 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a')
1016+
srv2 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 1, 80, 'a')
1017+
srv3 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 81, 'a')
1018+
srv4 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab')
1019+
1020+
record_set = set([srv1, srv2, srv3, srv4])
1021+
1022+
assert len(record_set) == 4
1023+
1024+
record_set.add(srv1)
1025+
assert len(record_set) == 4
1026+
1027+
srv1_dupe = r.DNSService(
1028+
'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a'
1029+
)
1030+
1031+
record_set.add(srv1_dupe)
1032+
assert len(record_set) == 4

zeroconf/_dns.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def __init__(self, name: str, type_: int, class_: int) -> None:
7070
self.class_ = class_ & _CLASS_MASK
7171
self.unique = (class_ & _CLASS_UNIQUE) != 0
7272

73+
def _entry_tuple(self) -> Tuple[str, int, int]:
74+
"""Entry Tuple for DNSEntry."""
75+
return (self.key, self.type, self.class_)
76+
7377
def __eq__(self, other: Any) -> bool:
7478
"""Equality test on key (lowercase name), type, and class"""
7579
return (
@@ -105,9 +109,6 @@ class DNSQuestion(DNSEntry):
105109

106110
"""A DNS question entry"""
107111

108-
def __init__(self, name: str, type_: int, class_: int) -> None:
109-
DNSEntry.__init__(self, name, type_, class_)
110-
111112
def answered_by(self, rec: 'DNSRecord') -> bool:
112113
"""Returns true if the question is answered by the record"""
113114
return (
@@ -141,7 +142,7 @@ class DNSRecord(DNSEntry):
141142

142143
# TODO: Switch to just int ttl
143144
def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None:
144-
DNSEntry.__init__(self, name, type_, class_)
145+
super().__init__(name, type_, class_)
145146
self.ttl = ttl
146147
self.created = current_time_millis()
147148
self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT)
@@ -205,7 +206,7 @@ class DNSAddress(DNSRecord):
205206
"""A DNS address record"""
206207

207208
def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None:
208-
DNSRecord.__init__(self, name, type_, class_, ttl)
209+
super().__init__(name, type_, class_, ttl)
209210
self.address = address
210211

211212
def write(self, out: 'DNSOutgoing') -> None:
@@ -218,6 +219,10 @@ def __eq__(self, other: Any) -> bool:
218219
isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address
219220
)
220221

222+
def __hash__(self) -> int:
223+
"""Hash to compare like DNSAddresses."""
224+
return hash((*self._entry_tuple(), self.address))
225+
221226
def __repr__(self) -> str:
222227
"""String representation"""
223228
try:
@@ -235,7 +240,7 @@ class DNSHinfo(DNSRecord):
235240
"""A DNS host information record"""
236241

237242
def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None:
238-
DNSRecord.__init__(self, name, type_, class_, ttl)
243+
super().__init__(name, type_, class_, ttl)
239244
self.cpu = cpu
240245
self.os = os
241246

@@ -253,6 +258,10 @@ def __eq__(self, other: Any) -> bool:
253258
and self.os == other.os
254259
)
255260

261+
def __hash__(self) -> int:
262+
"""Hash to compare like DNSHinfo."""
263+
return hash((*self._entry_tuple(), self.cpu, self.os))
264+
256265
def __repr__(self) -> str:
257266
"""String representation"""
258267
return self.to_string(self.cpu + " " + self.os)
@@ -263,7 +272,7 @@ class DNSPointer(DNSRecord):
263272
"""A DNS pointer record"""
264273

265274
def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None:
266-
DNSRecord.__init__(self, name, type_, class_, ttl)
275+
super().__init__(name, type_, class_, ttl)
267276
self.alias = alias
268277

269278
def write(self, out: 'DNSOutgoing') -> None:
@@ -274,6 +283,10 @@ def __eq__(self, other: Any) -> bool:
274283
"""Tests equality on alias"""
275284
return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other)
276285

286+
def __hash__(self) -> int:
287+
"""Hash to compare like DNSPointer."""
288+
return hash((*self._entry_tuple(), self.alias))
289+
277290
def __repr__(self) -> str:
278291
"""String representation"""
279292
return self.to_string(self.alias)
@@ -285,13 +298,17 @@ class DNSText(DNSRecord):
285298

286299
def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None:
287300
assert isinstance(text, (bytes, type(None)))
288-
DNSRecord.__init__(self, name, type_, class_, ttl)
301+
super().__init__(name, type_, class_, ttl)
289302
self.text = text
290303

291304
def write(self, out: 'DNSOutgoing') -> None:
292305
"""Used in constructing an outgoing packet"""
293306
out.write_string(self.text)
294307

308+
def __hash__(self) -> int:
309+
"""Hash to compare like DNSText."""
310+
return hash((*self._entry_tuple(), self.text))
311+
295312
def __eq__(self, other: Any) -> bool:
296313
"""Tests equality on text"""
297314
return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other)
@@ -318,7 +335,7 @@ def __init__(
318335
port: int,
319336
server: str,
320337
) -> None:
321-
DNSRecord.__init__(self, name, type_, class_, ttl)
338+
super().__init__(name, type_, class_, ttl)
322339
self.priority = priority
323340
self.weight = weight
324341
self.port = port
@@ -342,6 +359,10 @@ def __eq__(self, other: Any) -> bool:
342359
and DNSEntry.__eq__(self, other)
343360
)
344361

362+
def __hash__(self) -> int:
363+
"""Hash to compare like DNSService."""
364+
return hash((*self._entry_tuple(), self.priority, self.weight, self.port, self.server))
365+
345366
def __repr__(self) -> str:
346367
"""String representation"""
347368
return self.to_string("%s:%s" % (self.server, self.port))

0 commit comments

Comments
 (0)