Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from zeroconf import ServiceInfo, Zeroconf, current_time_millis
from zeroconf import const

from . import _clear_cache
from . import _clear_cache, _inject_response

log = logging.getLogger('zeroconf')
original_logging_level = logging.NOTSET
Expand Down Expand Up @@ -227,7 +227,19 @@ def test_ptr_optimization():
# register
zc.register_service(info)

# query
# Verify we won't respond for 1s with the same multicast
query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
unicast_out, multicast_out = zc.query_handler.response(
r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT
)
assert unicast_out is None
assert multicast_out is None

# Clear the cache to allow responding again
_clear_cache(zc)

# Verify we will now respond
query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
unicast_out, multicast_out = zc.query_handler.response(
Expand Down Expand Up @@ -320,6 +332,7 @@ def test_unicast_response():
)
# register
zc.register_service(info)
_clear_cache(zc)

# query
query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
Expand All @@ -329,8 +342,8 @@ def test_unicast_response():
assert out.id == query.id
has_srv = has_txt = has_a = False
nbr_additionals = 0
nbr_answers = len(multicast_out.answers)
nbr_authorities = len(multicast_out.authorities)
nbr_answers = len(out.answers)
nbr_authorities = len(out.authorities)
for answer in out.additionals:
nbr_additionals += 1
if answer.type == const._TYPE_SRV:
Expand Down Expand Up @@ -360,7 +373,7 @@ def test_known_answer_supression():
zc.register_service(info)

now = current_time_millis()

_clear_cache(zc)
# Test PTR supression
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
Expand Down
2 changes: 1 addition & 1 deletion zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def __init__(
self._notify_listeners: List[NotifyListener] = []
self.browsers: Dict[ServiceListener, ServiceBrowser] = {}
self.registry = ServiceRegistry()
self.query_handler = QueryHandler(self.registry)
self.cache = DNSCache()
self.query_handler = QueryHandler(self.registry, self.cache)
self.record_manager = RecordManager(self)

self.condition = threading.Condition()
Expand Down
24 changes: 21 additions & 3 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import itertools
from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union

from ._cache import DNSCache
from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord
from ._logger import log
from ._services import RecordUpdateListener
Expand Down Expand Up @@ -65,12 +66,13 @@ class RecordSetKeys(enum.Enum):
class _QueryResponse:
"""A pair for unicast and multicast DNSOutgoing responses."""

def __init__(self, msg: DNSIncoming, ucast_source: bool) -> None:
def __init__(self, cache: DNSCache, msg: DNSIncoming, ucast_source: bool) -> None:
"""Build a query response."""
self._msg = msg
self._ucast_source = ucast_source
self._is_probe = msg.num_authorities > 0
self._now = current_time_millis()
self._cache = cache
self._ucast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()}
self._mcast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()}

Expand All @@ -97,6 +99,9 @@ def outgoing_unicast(self) -> Optional[DNSOutgoing]:

def outgoing_multicast(self) -> Optional[DNSOutgoing]:
"""Build the outgoing multicast response."""
if not self._is_probe:
self._suppress_mcasts_from_last_second(self._mcast[RecordSetKeys.Answers])
self._suppress_mcasts_from_last_second(self._mcast[RecordSetKeys.Additionals])
return self._construct_outgoing_from_record_set(self._mcast, True)

def _construct_outgoing_from_record_set(
Expand All @@ -116,13 +121,26 @@ def _construct_outgoing_from_record_set(
out.add_additional_answer(additional)
return out

def _suppress_mcasts_from_last_second(self, records: Set[DNSRecord]) -> None:
"""Remove any records that were already sent in the last second."""
records -= set(record for record in records if self._has_mcast_record_in_last_second(record))

def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
"""Remove answers that were just broadcast
Protect the network against excessive packet flooding
https://datatracker.ietf.org/doc/html/rfc6762#section-14
"""
maybe_entry = self._cache.get(record)
return bool(maybe_entry and self._now - maybe_entry.created < 1000)


class QueryHandler:
"""Query the ServiceRegistry."""

def __init__(self, registry: ServiceRegistry) -> None:
def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None:
"""Init the query handler."""
self.registry = registry
self.cache = cache

def _answer_service_type_enumeration_query(
self,
Expand Down Expand Up @@ -204,7 +222,7 @@ def response( # pylint: disable=unused-argument
) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]:
"""Deal with incoming query packets. Provides a response if possible."""
ucast_source = port != _MDNS_PORT
query_res = _QueryResponse(msg, ucast_source)
query_res = _QueryResponse(self.cache, msg, ucast_source)

for question in msg.questions:
all_answers = self._answer_any_question(msg, question)
Expand Down