2424import itertools
2525from typing import Dict , List , Optional , Set , TYPE_CHECKING , Tuple , Union
2626
27+ from ._cache import DNSCache
2728from ._dns import DNSAddress , DNSIncoming , DNSOutgoing , DNSPointer , DNSQuestion , DNSRecord
2829from ._logger import log
2930from ._services import RecordUpdateListener
@@ -65,12 +66,13 @@ class RecordSetKeys(enum.Enum):
6566class _QueryResponse :
6667 """A pair for unicast and multicast DNSOutgoing responses."""
6768
68- def __init__ (self , msg : DNSIncoming , ucast_source : bool ) -> None :
69+ def __init__ (self , cache : DNSCache , msg : DNSIncoming , ucast_source : bool ) -> None :
6970 """Build a query response."""
7071 self ._msg = msg
7172 self ._ucast_source = ucast_source
7273 self ._is_probe = msg .num_authorities > 0
7374 self ._now = current_time_millis ()
75+ self ._cache = cache
7476 self ._ucast : _RecordSetType = {RecordSetKeys .Answers : set (), RecordSetKeys .Additionals : set ()}
7577 self ._mcast : _RecordSetType = {RecordSetKeys .Answers : set (), RecordSetKeys .Additionals : set ()}
7678
@@ -97,6 +99,9 @@ def outgoing_unicast(self) -> Optional[DNSOutgoing]:
9799
98100 def outgoing_multicast (self ) -> Optional [DNSOutgoing ]:
99101 """Build the outgoing multicast response."""
102+ if not self ._is_probe :
103+ self ._suppress_mcasts_from_last_second (self ._mcast [RecordSetKeys .Answers ])
104+ self ._suppress_mcasts_from_last_second (self ._mcast [RecordSetKeys .Additionals ])
100105 return self ._construct_outgoing_from_record_set (self ._mcast , True )
101106
102107 def _construct_outgoing_from_record_set (
@@ -116,13 +121,26 @@ def _construct_outgoing_from_record_set(
116121 out .add_additional_answer (additional )
117122 return out
118123
124+ def _suppress_mcasts_from_last_second (self , records : Set [DNSRecord ]) -> None :
125+ """Remove any records that were already sent in the last second."""
126+ records -= set (record for record in records if self ._has_mcast_record_in_last_second (record ))
127+
128+ def _has_mcast_record_in_last_second (self , record : DNSRecord ) -> bool :
129+ """Remove answers that were just broadcast
130+ Protect the network against excessive packet flooding
131+ https://datatracker.ietf.org/doc/html/rfc6762#section-14
132+ """
133+ maybe_entry = self ._cache .get (record )
134+ return bool (maybe_entry and self ._now - maybe_entry .created < 1000 )
135+
119136
120137class QueryHandler :
121138 """Query the ServiceRegistry."""
122139
123- def __init__ (self , registry : ServiceRegistry ) -> None :
140+ def __init__ (self , registry : ServiceRegistry , cache : DNSCache ) -> None :
124141 """Init the query handler."""
125142 self .registry = registry
143+ self .cache = cache
126144
127145 def _answer_service_type_enumeration_query (
128146 self ,
@@ -204,7 +222,7 @@ def response( # pylint: disable=unused-argument
204222 ) -> Tuple [Optional [DNSOutgoing ], Optional [DNSOutgoing ]]:
205223 """Deal with incoming query packets. Provides a response if possible."""
206224 ucast_source = port != _MDNS_PORT
207- query_res = _QueryResponse (msg , ucast_source )
225+ query_res = _QueryResponse (self . cache , msg , ucast_source )
208226
209227 for question in msg .questions :
210228 all_answers = self ._answer_any_question (msg , question )
0 commit comments