Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,30 @@ def test_ptr_optimization():
zc.close()


def test_aaaa_query():
"""Test that queries for AAAA records work."""
zc = Zeroconf(interfaces=['127.0.0.1'])
type_ = "_knownservice._tcp.local."
name = "knownname"
registration_name = "%s.%s" % (name, type_)
desc = {'path': '/~paulsm/'}
server_name = "ash-2.local."
ipv6_address = socket.inet_pton(socket.AF_INET6, "2001:db8::1")
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, server_name, addresses=[ipv6_address])
zc.register_service(info)

_clear_cache(zc)
generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
_, multicast_out = zc.query_handler.response(r.DNSIncoming(packets[0]), "1.2.3.4", const._MDNS_PORT)
assert multicast_out.answers[0][0].address == ipv6_address
# unregister
zc.unregister_service(info)
zc.close()


def test_unicast_response():
"""Ensure we send a unicast response when the source port is not the MDNS port."""
# instantiate a zeroconf instance
Expand Down
13 changes: 8 additions & 5 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ._logger import log
from ._services import RecordUpdateListener
from ._services.registry import ServiceRegistry
from ._utils.net import IPVersion
from ._utils.time import current_time_millis
from .const import (
_CLASS_IN,
Expand All @@ -37,12 +38,14 @@
_MDNS_PORT,
_SERVICE_TYPE_ENUMERATION_NAME,
_TYPE_A,
_TYPE_AAAA,
_TYPE_ANY,
_TYPE_PTR,
_TYPE_SRV,
_TYPE_TXT,
)

_TYPE_TO_IP_VERSION = {_TYPE_A: IPVersion.V4Only, _TYPE_AAAA: IPVersion.V6Only, _TYPE_ANY: IPVersion.All}

if TYPE_CHECKING:
# https://github.com/PyCQA/pylint/issues/3525
Expand Down Expand Up @@ -146,10 +149,10 @@ def _add_pointer_answers(
additionals.add(service.dns_text())
additionals.update(service.dns_addresses())

def _add_address_answers(self, name: str, msg: DNSIncoming, answers: Set[DNSRecord]) -> None:
"""Answer address question."""
def _add_address_answers(self, name: str, msg: DNSIncoming, answers: Set[DNSRecord], type_: int) -> None:
"""Answer A/AAAA/ANY question."""
for service in self.registry.get_infos_server(name):
for dns_address in service.dns_addresses():
for dns_address in service.dns_addresses(version=_TYPE_TO_IP_VERSION[type_]):
if not dns_address.suppressed_by(msg):
answers.add(dns_address)

Expand All @@ -163,8 +166,8 @@ def _answer_question(
if type_ == _TYPE_PTR:
self._add_pointer_answers(question.name, msg, answers, additionals)

if type_ in (_TYPE_A, _TYPE_ANY):
self._add_address_answers(question.name, msg, answers)
if type_ in (_TYPE_A, _TYPE_AAAA, _TYPE_ANY):
self._add_address_answers(question.name, msg, answers, type_)

if type_ in (_TYPE_SRV, _TYPE_TXT, _TYPE_ANY):
service = self.registry.get_info_name(question.name) # type: ignore
Expand Down