Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,3 +851,182 @@ def test_qu_packet_parser():
parsed = DNSIncoming(qu_packet)
assert parsed.questions[0].unicast is True
assert ",QU," in str(parsed.questions[0])


def test_dns_record_hashablity_does_not_consider_ttl():
"""Test DNSRecord are hashable."""

# Verify the TTL is not considered in the hash
record1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_OTHER_TTL, b'same')
record2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same')

record_set = set([record1, record2])
assert len(record_set) == 1

record_set.add(record1)
assert len(record_set) == 1

record3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, const._DNS_HOST_TTL, b'same')

record_set.add(record3_dupe)
assert len(record_set) == 1


def test_dns_address_record_hashablity():
"""Test DNSAddress are hashable."""
address1 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'a')
address2 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'b')
address3 = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c')
address4 = r.DNSAddress('irrelevant', const._TYPE_AAAA, const._CLASS_IN, 1, b'c')

record_set = set([address1, address2, address3, address4])
assert len(record_set) == 4

record_set.add(address1)
assert len(record_set) == 4

address3_dupe = r.DNSAddress('irrelevant', const._TYPE_A, const._CLASS_IN, 1, b'c')

record_set.add(address3_dupe)
assert len(record_set) == 4

# Verify we can remove records
additional_set = set([address1, address2])
record_set -= additional_set
assert record_set == set([address3, address4])


def test_dns_hinfo_record_hashablity():
"""Test DNSHinfo are hashable."""
hinfo1 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu1', 'os')
hinfo2 = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os')

record_set = set([hinfo1, hinfo2])
assert len(record_set) == 2

record_set.add(hinfo1)
assert len(record_set) == 2

hinfo2_dupe = r.DNSHinfo('irrelevant', const._TYPE_HINFO, 0, 0, 'cpu2', 'os')

record_set.add(hinfo2_dupe)
assert len(record_set) == 2


def test_dns_pointer_record_hashablity():
"""Test DNSPointer are hashable."""
ptr1 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '123')
ptr2 = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456')

record_set = set([ptr1, ptr2])
assert len(record_set) == 2

record_set.add(ptr1)
assert len(record_set) == 2

ptr2_dupe = r.DNSPointer('irrelevant', const._TYPE_PTR, const._CLASS_IN, const._DNS_OTHER_TTL, '456')

record_set.add(ptr2_dupe)
assert len(record_set) == 2


def test_dns_text_record_hashablity():
"""Test DNSText are hashable."""
text1 = r.DNSText('irrelevant', 0, 0, 0, b'12345678901')
text2 = r.DNSText('irrelevant', 1, 0, 0, b'12345678901')
text3 = r.DNSText('irrelevant', 0, 1, 0, b'12345678901')
text4 = r.DNSText('irrelevant', 0, 0, 1, b'12345678901')
text5 = r.DNSText('irrelevant', 0, 0, 0, b'ABCDEFGHIJK')

record_set = set([text1, text2, text3, text4, text5])
assert len(record_set) == 5

record_set.add(text1)
assert len(record_set) == 5

text1_dupe = r.DNSText('irrelevant', 0, 0, 0, b'12345678901')

record_set.add(text1_dupe)
assert len(record_set) == 5


def test_dns_text_record_hashablity():
"""Test DNSText are hashable."""
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901')
text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901')
text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK')

record_set = set([text1, text2, text3, text4])

assert len(record_set) == 4

record_set.add(text1)
assert len(record_set) == 4

text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')

record_set.add(text1_dupe)
assert len(record_set) == 4


def test_dns_text_record_hashablity():
"""Test DNSText are hashable."""
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901')
text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901')
text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK')

record_set = set([text1, text2, text3, text4])

assert len(record_set) == 4

record_set.add(text1)
assert len(record_set) == 4

text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')

record_set.add(text1_dupe)
assert len(record_set) == 4


def test_dns_text_record_hashablity():
"""Test DNSText are hashable."""
text1 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')
text2 = r.DNSText('irrelevant', 1, 0, const._DNS_OTHER_TTL, b'12345678901')
text3 = r.DNSText('irrelevant', 0, 1, const._DNS_OTHER_TTL, b'12345678901')
text4 = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'ABCDEFGHIJK')

record_set = set([text1, text2, text3, text4])

assert len(record_set) == 4

record_set.add(text1)
assert len(record_set) == 4

text1_dupe = r.DNSText('irrelevant', 0, 0, const._DNS_OTHER_TTL, b'12345678901')

record_set.add(text1_dupe)
assert len(record_set) == 4


def test_dns_service_record_hashablity():
"""Test DNSService are hashable."""
srv1 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a')
srv2 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 1, 80, 'a')
srv3 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 81, 'a')
srv4 = r.DNSService('irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'ab')

record_set = set([srv1, srv2, srv3, srv4])

assert len(record_set) == 4

record_set.add(srv1)
assert len(record_set) == 4

srv1_dupe = r.DNSService(
'irrelevant', const._TYPE_SRV, const._CLASS_IN, const._DNS_HOST_TTL, 0, 0, 80, 'a'
)

record_set.add(srv1_dupe)
assert len(record_set) == 4
39 changes: 30 additions & 9 deletions zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def __init__(self, name: str, type_: int, class_: int) -> None:
self.class_ = class_ & _CLASS_MASK
self.unique = (class_ & _CLASS_UNIQUE) != 0

def _entry_tuple(self) -> Tuple[str, int, int]:
"""Entry Tuple for DNSEntry."""
return (self.key, self.type, self.class_)

def __eq__(self, other: Any) -> bool:
"""Equality test on key (lowercase name), type, and class"""
return (
Expand Down Expand Up @@ -105,9 +109,6 @@ class DNSQuestion(DNSEntry):

"""A DNS question entry"""

def __init__(self, name: str, type_: int, class_: int) -> None:
DNSEntry.__init__(self, name, type_, class_)

def answered_by(self, rec: 'DNSRecord') -> bool:
"""Returns true if the question is answered by the record"""
return (
Expand Down Expand Up @@ -141,7 +142,7 @@ class DNSRecord(DNSEntry):

# TODO: Switch to just int ttl
def __init__(self, name: str, type_: int, class_: int, ttl: Union[float, int]) -> None:
DNSEntry.__init__(self, name, type_, class_)
super().__init__(name, type_, class_)
self.ttl = ttl
self.created = current_time_millis()
self._expiration_time = self.get_expiration_time(_EXPIRE_FULL_TIME_PERCENT)
Expand Down Expand Up @@ -205,7 +206,7 @@ class DNSAddress(DNSRecord):
"""A DNS address record"""

def __init__(self, name: str, type_: int, class_: int, ttl: int, address: bytes) -> None:
DNSRecord.__init__(self, name, type_, class_, ttl)
super().__init__(name, type_, class_, ttl)
self.address = address

def write(self, out: 'DNSOutgoing') -> None:
Expand All @@ -218,6 +219,10 @@ def __eq__(self, other: Any) -> bool:
isinstance(other, DNSAddress) and DNSEntry.__eq__(self, other) and self.address == other.address
)

def __hash__(self) -> int:
"""Hash to compare like DNSAddresses."""
return hash((*self._entry_tuple(), self.address))

def __repr__(self) -> str:
"""String representation"""
try:
Expand All @@ -235,7 +240,7 @@ class DNSHinfo(DNSRecord):
"""A DNS host information record"""

def __init__(self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str) -> None:
DNSRecord.__init__(self, name, type_, class_, ttl)
super().__init__(name, type_, class_, ttl)
self.cpu = cpu
self.os = os

Expand All @@ -253,6 +258,10 @@ def __eq__(self, other: Any) -> bool:
and self.os == other.os
)

def __hash__(self) -> int:
"""Hash to compare like DNSHinfo."""
return hash((*self._entry_tuple(), self.cpu, self.os))

def __repr__(self) -> str:
"""String representation"""
return self.to_string(self.cpu + " " + self.os)
Expand All @@ -263,7 +272,7 @@ class DNSPointer(DNSRecord):
"""A DNS pointer record"""

def __init__(self, name: str, type_: int, class_: int, ttl: int, alias: str) -> None:
DNSRecord.__init__(self, name, type_, class_, ttl)
super().__init__(name, type_, class_, ttl)
self.alias = alias

def write(self, out: 'DNSOutgoing') -> None:
Expand All @@ -274,6 +283,10 @@ def __eq__(self, other: Any) -> bool:
"""Tests equality on alias"""
return isinstance(other, DNSPointer) and self.alias == other.alias and DNSEntry.__eq__(self, other)

def __hash__(self) -> int:
"""Hash to compare like DNSPointer."""
return hash((*self._entry_tuple(), self.alias))

def __repr__(self) -> str:
"""String representation"""
return self.to_string(self.alias)
Expand All @@ -285,13 +298,17 @@ class DNSText(DNSRecord):

def __init__(self, name: str, type_: int, class_: int, ttl: int, text: bytes) -> None:
assert isinstance(text, (bytes, type(None)))
DNSRecord.__init__(self, name, type_, class_, ttl)
super().__init__(name, type_, class_, ttl)
self.text = text

def write(self, out: 'DNSOutgoing') -> None:
"""Used in constructing an outgoing packet"""
out.write_string(self.text)

def __hash__(self) -> int:
"""Hash to compare like DNSText."""
return hash((*self._entry_tuple(), self.text))

def __eq__(self, other: Any) -> bool:
"""Tests equality on text"""
return isinstance(other, DNSText) and self.text == other.text and DNSEntry.__eq__(self, other)
Expand All @@ -318,7 +335,7 @@ def __init__(
port: int,
server: str,
) -> None:
DNSRecord.__init__(self, name, type_, class_, ttl)
super().__init__(name, type_, class_, ttl)
self.priority = priority
self.weight = weight
self.port = port
Expand All @@ -342,6 +359,10 @@ def __eq__(self, other: Any) -> bool:
and DNSEntry.__eq__(self, other)
)

def __hash__(self) -> int:
"""Hash to compare like DNSService."""
return hash((*self._entry_tuple(), self.priority, self.weight, self.port, self.server))

def __repr__(self) -> str:
"""String representation"""
return self.to_string("%s:%s" % (self.server, self.port))
Expand Down