Skip to content

Commit cda7a47

Browse files
committed
Protect the network against excessive packet flooding
- Implements RFC6762 sec 14 https://datatracker.ietf.org/doc/html/rfc6762#section-14 Closes #395
1 parent b6365aa commit cda7a47

3 files changed

Lines changed: 40 additions & 9 deletions

File tree

tests/test_handlers.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from zeroconf import ServiceInfo, Zeroconf, current_time_millis
1616
from zeroconf import const
1717

18-
from . import _clear_cache
18+
from . import _clear_cache, _inject_response
1919

2020
log = logging.getLogger('zeroconf')
2121
original_logging_level = logging.NOTSET
@@ -227,7 +227,19 @@ def test_ptr_optimization():
227227
# register
228228
zc.register_service(info)
229229

230-
# query
230+
# Verify we won't respond for 1s with the same multicast
231+
query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
232+
query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
233+
unicast_out, multicast_out = zc.query_handler.response(
234+
r.DNSIncoming(query.packets()[0]), None, const._MDNS_PORT
235+
)
236+
assert unicast_out is None
237+
assert multicast_out is None
238+
239+
# Clear the cache to allow responding again
240+
_clear_cache(zc)
241+
242+
# Verify we will now respond
231243
query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
232244
query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
233245
unicast_out, multicast_out = zc.query_handler.response(
@@ -320,6 +332,7 @@ def test_unicast_response():
320332
)
321333
# register
322334
zc.register_service(info)
335+
_clear_cache(zc)
323336

324337
# query
325338
query = r.DNSOutgoing(const._FLAGS_QR_QUERY | const._FLAGS_AA)
@@ -329,8 +342,8 @@ def test_unicast_response():
329342
assert out.id == query.id
330343
has_srv = has_txt = has_a = False
331344
nbr_additionals = 0
332-
nbr_answers = len(multicast_out.answers)
333-
nbr_authorities = len(multicast_out.authorities)
345+
nbr_answers = len(out.answers)
346+
nbr_authorities = len(out.authorities)
334347
for answer in out.additionals:
335348
nbr_additionals += 1
336349
if answer.type == const._TYPE_SRV:
@@ -360,7 +373,7 @@ def test_known_answer_supression():
360373
zc.register_service(info)
361374

362375
now = current_time_millis()
363-
376+
_clear_cache(zc)
364377
# Test PTR supression
365378
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
366379
question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)

zeroconf/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,8 @@ def __init__(
264264
self._notify_listeners: List[NotifyListener] = []
265265
self.browsers: Dict[ServiceListener, ServiceBrowser] = {}
266266
self.registry = ServiceRegistry()
267-
self.query_handler = QueryHandler(self.registry)
268267
self.cache = DNSCache()
268+
self.query_handler = QueryHandler(self.registry, self.cache)
269269
self.record_manager = RecordManager(self)
270270

271271
self.condition = threading.Condition()

zeroconf/_handlers.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import itertools
2525
from typing import Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union
2626

27+
from ._cache import DNSCache
2728
from ._dns import DNSAddress, DNSIncoming, DNSOutgoing, DNSPointer, DNSQuestion, DNSRecord
2829
from ._logger import log
2930
from ._services import RecordUpdateListener
@@ -65,12 +66,13 @@ class RecordSetKeys(enum.Enum):
6566
class _QueryResponse:
6667
"""A pair for unicast and multicast DNSOutgoing responses."""
6768

68-
def __init__(self, msg: DNSIncoming, ucast_source: bool) -> None:
69+
def __init__(self, cache: DNSCache, msg: DNSIncoming, ucast_source: bool) -> None:
6970
"""Build a query response."""
7071
self._msg = msg
7172
self._ucast_source = ucast_source
7273
self._is_probe = msg.num_authorities > 0
7374
self._now = current_time_millis()
75+
self._cache = cache
7476
self._ucast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()}
7577
self._mcast: _RecordSetType = {RecordSetKeys.Answers: set(), RecordSetKeys.Additionals: set()}
7678

@@ -97,6 +99,9 @@ def outgoing_unicast(self) -> Optional[DNSOutgoing]:
9799

98100
def outgoing_multicast(self) -> Optional[DNSOutgoing]:
99101
"""Build the outgoing multicast response."""
102+
if not self._is_probe:
103+
self._suppress_mcasts_from_last_second(self._mcast[RecordSetKeys.Answers])
104+
self._suppress_mcasts_from_last_second(self._mcast[RecordSetKeys.Additionals])
100105
return self._construct_outgoing_from_record_set(self._mcast, True)
101106

102107
def _construct_outgoing_from_record_set(
@@ -116,13 +121,26 @@ def _construct_outgoing_from_record_set(
116121
out.add_additional_answer(additional)
117122
return out
118123

124+
def _suppress_mcasts_from_last_second(self, records: Set[DNSRecord]) -> None:
125+
"""Remove any records that were already sent in the last second."""
126+
records -= set(record for record in records if self._has_mcast_record_in_last_second(record))
127+
128+
def _has_mcast_record_in_last_second(self, record: DNSRecord) -> bool:
129+
"""Remove answers that were just broadcast
130+
Protect the network against excessive packet flooding
131+
https://datatracker.ietf.org/doc/html/rfc6762#section-14
132+
"""
133+
maybe_entry = self._cache.get(record)
134+
return bool(maybe_entry and self._now - maybe_entry.created < 1000)
135+
119136

120137
class QueryHandler:
121138
"""Query the ServiceRegistry."""
122139

123-
def __init__(self, registry: ServiceRegistry) -> None:
140+
def __init__(self, registry: ServiceRegistry, cache: DNSCache) -> None:
124141
"""Init the query handler."""
125142
self.registry = registry
143+
self.cache = cache
126144

127145
def _answer_service_type_enumeration_query(
128146
self,
@@ -204,7 +222,7 @@ def response( # pylint: disable=unused-argument
204222
) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]:
205223
"""Deal with incoming query packets. Provides a response if possible."""
206224
ucast_source = port != _MDNS_PORT
207-
query_res = _QueryResponse(msg, ucast_source)
225+
query_res = _QueryResponse(self.cache, msg, ucast_source)
208226

209227
for question in msg.questions:
210228
all_answers = self._answer_any_question(msg, question)

0 commit comments

Comments
 (0)