Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions src/zeroconf/_handlers/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""


from typing import TYPE_CHECKING, List, Set, cast
from typing import TYPE_CHECKING, List, Optional, Set, cast

from .._cache import DNSCache, _UniqueRecordsType
from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSRecord, DNSRRSet
Expand Down Expand Up @@ -109,19 +109,20 @@ def add_mcast_question_response(self, answers: _AnswerWithAdditionalsType) -> No
else:
self._mcast_aggregate.add(answer)

def _generate_answers_with_additionals(self, rrset: Set[DNSRecord]) -> _AnswerWithAdditionalsType:
"""Create answers with additionals from an rrset."""
return {record: self._additionals[record] for record in rrset}

def answers(
self,
) -> QuestionAnswers:
"""Return answer sets that will be queued."""
return QuestionAnswers(
self._generate_answers_with_additionals(self._ucast),
self._generate_answers_with_additionals(self._mcast_now),
self._generate_answers_with_additionals(self._mcast_aggregate),
self._generate_answers_with_additionals(self._mcast_aggregate_last_second),
*(
{record: self._additionals[record] for record in rrset}
for rrset in (
self._ucast,
self._mcast_now,
self._mcast_aggregate,
self._mcast_aggregate_last_second,
)
)
)

def _has_mcast_within_one_quarter_ttl(self, record: DNSRecord) -> bool:
Expand Down Expand Up @@ -224,17 +225,16 @@ def _answer_question(
self,
question: DNSQuestion,
known_answers: DNSRRSet,
now: float,
) -> _AnswerWithAdditionalsType:
"""Answer a question."""
answer_set: _AnswerWithAdditionalsType = {}
question_lower_name = question.name.lower()
type_ = question.type

if question.type == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME:
if type_ == _TYPE_PTR and question_lower_name == _SERVICE_TYPE_ENUMERATION_NAME:
self._add_service_type_enumeration_query_answers(answer_set, known_answers)
return answer_set

type_ = question.type

if type_ in (_TYPE_PTR, _TYPE_ANY):
self._add_pointer_answers(question_lower_name, answer_set, known_answers)

Expand Down Expand Up @@ -267,12 +267,15 @@ def async_response( # pylint: disable=unused-argument
"""
known_answers = DNSRRSet([msg.answers for msg in msgs if not msg.is_probe])
query_res = _QueryResponse(self.cache, msgs)
known_answers_set: Optional[Set[DNSRecord]] = None

for msg in msgs:
for question in msg.questions:
if not question.unicast:
self.question_history.add_question_at_time(question, msg.now, set(known_answers.lookup))
answer_set = self._answer_question(question, known_answers, msg.now)
if not known_answers_set: # pragma: no branch
known_answers_set = set(known_answers.lookup)
self.question_history.add_question_at_time(question, msg.now, known_answers_set)
answer_set = self._answer_question(question, known_answers)
if not ucast_source and question.unicast:
query_res.add_qu_question_response(answer_set)
continue
Expand Down