Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions tests/services/test_browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,3 +1017,86 @@ async def test_query_scheduler():
assert set(query_scheduler.process_ready_types(now + delay * 20)) == set()

assert set(query_scheduler.process_ready_types(now + delay * 31)) == {"_http._tcp.local."}


def test_service_browser_matching():
"""Test that the ServiceBrowser matching does not match partial names."""

# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])
# start a browser
type_ = "_http._tcp.local."
registration_name = "xxxyyy.%s" % type_
not_match_type_ = "_asustor-looksgood_http._tcp.local."
not_match_registration_name = "xxxyyy.%s" % not_match_type_
callbacks = []

class MyServiceListener(r.ServiceListener):
def add_service(self, zc, type_, name) -> None:
nonlocal callbacks
if name == registration_name:
callbacks.append(("add", type_, name))

def remove_service(self, zc, type_, name) -> None:
nonlocal callbacks
if name == registration_name:
callbacks.append(("remove", type_, name))

def update_service(self, zc, type_, name) -> None:
nonlocal callbacks
if name == registration_name:
callbacks.append(("update", type_, name))

listener = MyServiceListener()

browser = r.ServiceBrowser(zc, type_, None, listener)

desc = {'path': '/~paulsm/'}
address_parsed = "10.0.1.2"
address = socket.inet_aton(address_parsed)
info = ServiceInfo(type_, registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address])
should_not_match = ServiceInfo(
not_match_type_, not_match_registration_name, 80, 0, 0, desc, "ash-2.local.", addresses=[address]
)

def mock_incoming_msg(records) -> r.DNSIncoming:
generated = r.DNSOutgoing(const._FLAGS_QR_RESPONSE)
for record in records:
generated.add_answer_at_time(record, 0)
return r.DNSIncoming(generated.packets()[0])

_inject_response(
zc,
mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
)
_inject_response(
zc,
mock_incoming_msg(
[
should_not_match.dns_pointer(),
should_not_match.dns_service(),
should_not_match.dns_text(),
*should_not_match.dns_addresses(),
]
),
)
time.sleep(0.2)
info.port = 400
_inject_response(
zc,
mock_incoming_msg([info.dns_service()]),
)
should_not_match.port = 400
_inject_response(
zc,
mock_incoming_msg([should_not_match.dns_service()]),
)
time.sleep(0.2)

assert callbacks == [
('add', type_, registration_name),
('update', type_, registration_name),
]
browser.cancel()

zc.close()
36 changes: 28 additions & 8 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ def teardown_module():
class ListenerTest(unittest.TestCase):
def test_integration_with_listener_class(self):

sub_service_added = Event()
service_added = Event()
service_removed = Event()
service_updated = Event()
service_updated2 = Event()
sub_service_updated = Event()
duplicate_service_added = Event()

subtype_name = "My special Subtype"
type_ = "_http._tcp.local."
Expand All @@ -58,21 +59,32 @@ def remove_service(self, zeroconf, type, name):
service_removed.set()

def update_service(self, zeroconf, type, name):
service_updated2.set()
pass

class DuplicateListener(r.ServiceListener):
def add_service(self, zeroconf, type, name):
duplicate_service_added.set()

def remove_service(self, zeroconf, type, name):
pass

def update_service(self, zeroconf, type, name):
pass

class MySubListener(r.ServiceListener):
def add_service(self, zeroconf, type, name):
sub_service_added.set()
pass

def remove_service(self, zeroconf, type, name):
pass

def update_service(self, zeroconf, type, name):
service_updated.set()
sub_service_updated.set()

listener = MyListener()
zeroconf_browser = Zeroconf(interfaces=['127.0.0.1'])
zeroconf_browser.add_service_listener(subtype, listener)
zeroconf_browser.add_service_listener(type_, listener)

properties = dict(
prop_none=None,
Expand Down Expand Up @@ -107,6 +119,11 @@ def update_service(self, zeroconf, type, name):
# short pause to allow multicast timers to expire
time.sleep(3)

zeroconf_browser.add_service_listener(type_, DuplicateListener())
duplicate_service_added.wait(
1
) # Ensure a listener for the same type calls back right away from cache

# clear the answer cache to force query
_clear_cache(zeroconf_browser)

Expand Down Expand Up @@ -160,7 +177,9 @@ def update_service(self, zeroconf, type, name):

# test TXT record update
sublistener = MySubListener()
zeroconf_browser.add_service_listener(registration_name, sublistener)

zeroconf_browser.add_service_listener(subtype, sublistener)

properties['prop_blank'] = b'an updated string'
desc.update(properties)
info_service = ServiceInfo(
Expand All @@ -174,8 +193,9 @@ def update_service(self, zeroconf, type, name):
addresses=[socket.inet_aton("10.0.1.2")],
)
zeroconf_registrar.update_service(info_service)
service_updated.wait(1)
assert service_updated.is_set()

sub_service_added.wait(1) # we cleared the cache above
assert sub_service_added.is_set()

info = zeroconf_browser.get_service_info(type_, registration_name)
assert info is not None
Expand Down
12 changes: 6 additions & 6 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,12 +515,12 @@ def _async_update_matching_records(
This function must be run from the event loop.
"""
now = current_time_millis()
records: List[RecordUpdate] = []
for question in questions:
for record in self.cache.async_entries_with_name(question.name):
if not record.is_expired(now) and question.answered_by(record):
records.append(RecordUpdate(record, None))

records: List[RecordUpdate] = [
RecordUpdate(record, None)
for question in questions
for record in self.cache.async_entries_with_name(question.name)
if not record.is_expired(now) and question.answered_by(record)
]
if not records:
return
listener.async_update_records(self.zc, now, records)
Expand Down
43 changes: 22 additions & 21 deletions zeroconf/_services/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import threading
import warnings
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast
from typing import Callable, Dict, Iterable, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast

from .._dns import DNSAddress, DNSPointer, DNSQuestion, DNSQuestionType, DNSRecord
from .._logger import log
Expand Down Expand Up @@ -324,9 +324,9 @@ def _async_start(self) -> None:
def service_state_changed(self) -> SignalRegistrationInterface:
return self._service_state_changed.registration_interface

def _record_matching_type(self, record: DNSRecord) -> Optional[str]:
"""Return the type if the record matches one of the types we are browsing."""
return next((type_ for type_ in self.types if record.name.endswith(type_)), None)
def _names_matching_types(self, names: Iterable[str]) -> List[Tuple[str, str]]:
"""Return the type and name for records matching the types we are browsing."""
return [(type_, name) for type_ in self.types for name in names if name.endswith(f".{type_}")]

def _enqueue_callback(
self,
Expand All @@ -352,14 +352,18 @@ def _async_process_record_update(
) -> None:
"""Process a single record update from a batch of updates."""
if isinstance(record, DNSPointer):
if record.name not in self.types:
return
if old_record is None:
self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
elif record.is_expired(now):
self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
else:
self.reschedule_type(record.name, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT))
name = record.name
alias = record.alias
matches = self._names_matching_types((alias,))
if name in self.types:
matches.append((name, alias))
for type_, name in matches:
if old_record is None:
self._enqueue_callback(ServiceStateChange.Added, type_, name)
elif record.is_expired(now):
self._enqueue_callback(ServiceStateChange.Removed, type_, name)
else:
self.reschedule_type(type_, record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT))
return

# If its expired or already exists in the cache it cannot be updated.
Expand All @@ -368,17 +372,14 @@ def _async_process_record_update(

if isinstance(record, DNSAddress):
# Iterate through the DNSCache and callback any services that use this address
for service in self.zc.cache.async_entries_with_server(record.name):
type_ = self._record_matching_type(service)
if type_:
self._enqueue_callback(ServiceStateChange.Updated, type_, service.name)
break

for type_, name in self._names_matching_types(
{service.name for service in self.zc.cache.async_entries_with_server(record.name)}
):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)
return

type_ = self._record_matching_type(record)
if type_:
self._enqueue_callback(ServiceStateChange.Updated, type_, record.name)
for type_, name in self._names_matching_types((record.name,)):
self._enqueue_callback(ServiceStateChange.Updated, type_, name)

def async_update_records(self, zc: 'Zeroconf', now: float, records: List[RecordUpdate]) -> None:
"""Callback invoked by Zeroconf when new information arrives.
Expand Down