Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,7 @@ async def test_async_unregister_all_services() -> None:
assert results[1] is not None

await aiozc.async_unregister_all_services()
_clear_cache(aiozc.zeroconf)

tasks = []
tasks.append(aiozc.async_get_service_info(type_, registration_name))
Expand Down
27 changes: 27 additions & 0 deletions tests/test_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,33 @@ def test_dns_record_is_recent(self):
assert record.is_recent(now + (8 * 1000)) is False


def test_dns_question_hashablity():
"""Test DNSQuestions are hashable."""

record1 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)
record2 = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)

record_set = {record1, record2}
assert len(record_set) == 1

record_set.add(record1)
assert len(record_set) == 1

record3_dupe = r.DNSQuestion('irrelevant', const._TYPE_A, const._CLASS_IN)
assert record2 == record3_dupe
assert record2.__hash__() == record3_dupe.__hash__()

record_set.add(record3_dupe)
assert len(record_set) == 1

record4_dupe = r.DNSQuestion('notsame', const._TYPE_A, const._CLASS_IN)
assert record2 != record4_dupe
assert record2.__hash__() != record4_dupe.__hash__()

record_set.add(record4_dupe)
assert len(record_set) == 2


def test_dns_record_hashablity_does_not_consider_ttl():
"""Test DNSRecord are hashable."""

Expand Down
48 changes: 30 additions & 18 deletions zeroconf/_dns.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import enum
import socket
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Tuple, Union, cast
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union, cast

from ._exceptions import AbstractMethodException
from ._utils.net import _is_v6_address
Expand Down Expand Up @@ -81,10 +81,6 @@ def __init__(self, name: str, type_: int, class_: int) -> None:
self.class_ = class_ & _CLASS_MASK
self.unique = (class_ & _CLASS_UNIQUE) != 0

def _entry_tuple(self) -> Tuple[str, int, int]:
"""Entry Tuple for DNSEntry."""
return (self.key, self.type, self.class_)

def __eq__(self, other: Any) -> bool:
"""Equality test on key (lowercase name), type, and class"""
return dns_entry_matches(other, self.key, self.type, self.class_) and isinstance(other, DNSEntry)
Expand Down Expand Up @@ -115,12 +111,22 @@ class DNSQuestion(DNSEntry):

"""A DNS question entry"""

__slots__ = ('_hash',)

def __init__(self, name: str, type_: int, class_: int) -> None:
super().__init__(name, type_, class_)
self._hash = hash((self.key, type_, class_))

def answered_by(self, rec: 'DNSRecord') -> bool:
"""Returns true if the question is answered by the record"""
return self.class_ == rec.class_ and self.type in (rec.type, _TYPE_ANY) and self.name == rec.name

def __hash__(self) -> int:
return hash((self.name, self.class_, self.type))
return self._hash

def __eq__(self, other: Any) -> bool:
"""Tests equality on dns question."""
return isinstance(other, DNSQuestion) and DNSEntry.__eq__(self, other)

@property
def max_size(self) -> int:
Expand Down Expand Up @@ -225,7 +231,7 @@ class DNSAddress(DNSRecord):

"""A DNS address record"""

__slots__ = ('address', 'scope_id')
__slots__ = ('_hash', 'address', 'scope_id')

def __init__(
self,
Expand All @@ -241,6 +247,7 @@ def __init__(
super().__init__(name, type_, class_, ttl, created)
self.address = address
self.scope_id = scope_id
self._hash = hash((self.key, type_, class_, address, scope_id))

def write(self, out: 'DNSOutgoing') -> None:
"""Used in constructing an outgoing packet"""
Expand All @@ -257,7 +264,7 @@ def __eq__(self, other: Any) -> bool:

def __hash__(self) -> int:
"""Hash to compare like DNSAddresses."""
return hash((*self._entry_tuple(), self.address, self.scope_id))
return self._hash

def __repr__(self) -> str:
"""String representation"""
Expand All @@ -275,14 +282,15 @@ class DNSHinfo(DNSRecord):

"""A DNS host information record"""

__slots__ = ('cpu', 'os')
__slots__ = ('_hash', 'cpu', 'os')

def __init__(
self, name: str, type_: int, class_: int, ttl: int, cpu: str, os: str, created: Optional[float] = None
) -> None:
super().__init__(name, type_, class_, ttl, created)
self.cpu = cpu
self.os = os
self._hash = hash((self.key, type_, class_, cpu, os))

def write(self, out: 'DNSOutgoing') -> None:
"""Used in constructing an outgoing packet"""
Expand All @@ -300,7 +308,7 @@ def __eq__(self, other: Any) -> bool:

def __hash__(self) -> int:
"""Hash to compare like DNSHinfo."""
return hash((*self._entry_tuple(), self.cpu, self.os))
return self._hash

def __repr__(self) -> str:
"""String representation"""
Expand All @@ -311,13 +319,14 @@ class DNSPointer(DNSRecord):

"""A DNS pointer record"""

__slots__ = ('alias',)
__slots__ = ('_hash', 'alias')

def __init__(
self, name: str, type_: int, class_: int, ttl: int, alias: str, created: Optional[float] = None
) -> None:
super().__init__(name, type_, class_, ttl, created)
self.alias = alias
self._hash = hash((self.key, type_, class_, alias))

@property
def max_size_compressed(self) -> int:
Expand All @@ -339,7 +348,7 @@ def __eq__(self, other: Any) -> bool:

def __hash__(self) -> int:
"""Hash to compare like DNSPointer."""
return hash((*self._entry_tuple(), self.alias))
return self._hash

def __repr__(self) -> str:
"""String representation"""
Expand All @@ -350,22 +359,23 @@ class DNSText(DNSRecord):

"""A DNS text record"""

__slots__ = ('text',)
__slots__ = ('_hash', 'text')

def __init__(
self, name: str, type_: int, class_: int, ttl: int, text: bytes, created: Optional[float] = None
) -> None:
assert isinstance(text, (bytes, type(None)))
super().__init__(name, type_, class_, ttl, created)
self.text = text
self._hash = hash((self.key, type_, class_, text))

def write(self, out: 'DNSOutgoing') -> None:
"""Used in constructing an outgoing packet"""
out.write_string(self.text)

def __hash__(self) -> int:
"""Hash to compare like DNSText."""
return hash((*self._entry_tuple(), self.text))
return self._hash

def __eq__(self, other: Any) -> bool:
"""Tests equality on text"""
Expand All @@ -382,7 +392,7 @@ class DNSService(DNSRecord):

"""A DNS service record"""

__slots__ = ('priority', 'weight', 'port', 'server')
__slots__ = ('_hash', 'priority', 'weight', 'port', 'server')

def __init__(
self,
Expand All @@ -401,6 +411,7 @@ def __init__(
self.weight = weight
self.port = port
self.server = server
self._hash = hash((self.key, type_, class_, priority, weight, port, server))

def write(self, out: 'DNSOutgoing') -> None:
"""Used in constructing an outgoing packet"""
Expand All @@ -422,7 +433,7 @@ def __eq__(self, other: Any) -> bool:

def __hash__(self) -> int:
"""Hash to compare like DNSService."""
return hash((*self._entry_tuple(), self.priority, self.weight, self.port, self.server))
return self._hash

def __repr__(self) -> str:
"""String representation"""
Expand All @@ -433,7 +444,7 @@ class DNSNsec(DNSRecord):

"""A DNS NSEC record"""

__slots__ = ('next_name', 'rdtypes')
__slots__ = ('_hash', 'next_name', 'rdtypes')

def __init__(
self,
Expand All @@ -448,6 +459,7 @@ def __init__(
super().__init__(name, type_, class_, ttl, created)
self.next_name = next_name
self.rdtypes = rdtypes
self._hash = hash((self.key, type_, class_, next_name, *self.rdtypes))

def __eq__(self, other: Any) -> bool:
"""Tests equality on cpu and os"""
Expand All @@ -460,7 +472,7 @@ def __eq__(self, other: Any) -> bool:

def __hash__(self) -> int:
"""Hash to compare like DNSNSec."""
return hash((*self._entry_tuple(), self.next_name, *self.rdtypes))
return self._hash

def __repr__(self) -> str:
"""String representation"""
Expand Down