Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions tests/services/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,41 @@ def test_lookups(self):
assert registry.get_infos_type(type_) == [info]
assert registry.get_infos_server("ash-2.local.") == [info]
assert registry.get_types() == [type_]

def test_lookups_upper_case_by_lower_case(self):
type_ = "_test-SRVC-type._tcp.local."
name = "Xxxyyy"
registration_name = "%s.%s" % (name, type_)

desc = {'path': '/~paulsm/'}
info = ServiceInfo(
type_, registration_name, 80, 0, 0, desc, "ASH-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
)

registry = r.ServiceRegistry()
registry.add(info)

assert registry.get_service_infos() == [info]
assert registry.get_info_name(registration_name.lower()) == info
assert registry.get_infos_type(type_.lower()) == [info]
assert registry.get_infos_server("ash-2.local.") == [info]
assert registry.get_types() == [type_.lower()]

def test_lookups_lower_case_by_upper_case(self):
type_ = "_test-srvc-type._tcp.local."
name = "xxxyyy"
registration_name = "%s.%s" % (name, type_)

desc = {'path': '/~paulsm/'}
info = ServiceInfo(
type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[socket.inet_aton("10.0.1.2")]
)

registry = r.ServiceRegistry()
registry.add(info)

assert registry.get_service_infos() == [info]
assert registry.get_info_name(registration_name.upper()) == info
assert registry.get_infos_type(type_.upper()) == [info]
assert registry.get_infos_server("ASH-2.local.") == [info]
assert registry.get_types() == [type_]
7 changes: 3 additions & 4 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def _answer_service_type_enumeration_query(self, msg: DNSIncoming, out: DNSOutgo

def _answer_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DNSQuestion) -> None:
"""Answer a PTR query."""
for service in self.registry.get_infos_type(question.name.lower()):
for service in self.registry.get_infos_type(question.name):
out.add_answer(msg, service.dns_pointer())
# Add recommended additional answers according to
# https://tools.ietf.org/html/rfc6763#section-12.1.
Expand All @@ -87,14 +87,13 @@ def _answer_non_ptr_query(self, msg: DNSIncoming, out: DNSOutgoing, question: DN

Add answer(s) for A, AAAA, SRV, or TXT queries.
"""
name_to_find = question.name.lower()
# Answer A record queries for any service addresses we know
if question.type in (_TYPE_A, _TYPE_ANY):
for service in self.registry.get_infos_server(name_to_find):
for service in self.registry.get_infos_server(question.name):
for dns_address in service.dns_addresses():
out.add_answer(msg, dns_address)

service = self.registry.get_info_name(name_to_find) # type: ignore
service = self.registry.get_info_name(question.name) # type: ignore
if service is None:
return

Expand Down
12 changes: 6 additions & 6 deletions zeroconf/_services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_service_infos(self) -> List[ServiceInfo]:

def get_info_name(self, name: str) -> Optional[ServiceInfo]:
"""Return all ServiceInfo for the name."""
return self._services.get(name)
return self._services.get(name.lower())

def get_types(self) -> List[str]:
"""Return all types."""
Expand All @@ -88,7 +88,7 @@ def _get_by_index(self, attr: str, key: str) -> List[ServiceInfo]:
"""Return all ServiceInfo matching the index."""
service_infos = []

for name in getattr(self, attr).get(key, [])[:]:
for name in getattr(self, attr).get(key.lower(), [])[:]:
info = self._services.get(name)
# Since we do not get under a lock since it would be
# a performance issue, its possible
Expand All @@ -106,13 +106,13 @@ def _add(self, info: ServiceInfo) -> None:
raise ServiceNameAlreadyRegistered

self._services[lower_name] = info
self.types.setdefault(info.type, []).append(lower_name)
self.servers.setdefault(info.server, []).append(lower_name)
self.types.setdefault(info.type.lower(), []).append(lower_name)
self.servers.setdefault(info.server.lower(), []).append(lower_name)

def _remove(self, info: ServiceInfo) -> None:
"""Remove a service under the lock."""
lower_name = info.name.lower()
old_service_info = self._services[lower_name]
self.types[old_service_info.type].remove(lower_name)
self.servers[old_service_info.server].remove(lower_name)
self.types[old_service_info.type.lower()].remove(lower_name)
self.servers[old_service_info.server.lower()].remove(lower_name)
del self._services[lower_name]