Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 107 additions & 52 deletions zeroconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
import time
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType # noqa # used in type hints
from typing import Dict, Iterable, List, Optional, Type, Union, cast
from typing import Dict, Generator, Iterable, List, Optional, Type, Union, cast
from typing import Any, Callable, Set, Tuple # noqa # used in type hints

import ifaddr
Expand Down Expand Up @@ -1424,8 +1425,8 @@ def run(self) -> None:
now = current_time_millis()
if now - self._last_cache_cleanup >= self.cache_cleanup_interval_ms:
self._last_cache_cleanup = now
for record in self.zc.cache.expire(now):
self.zc.update_record(now, record)
with self.zc.update_records(now, list(self.zc.cache.expire(now))):
pass

self.socketpair[0].close()
self.socketpair[1].close()
Expand Down Expand Up @@ -1548,8 +1549,37 @@ def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistratio


class RecordUpdateListener:
def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
raise NotImplementedError()
def update_record( # pylint: disable=no-self-use
self, zc: 'Zeroconf', now: float, record: DNSRecord
) -> None:
"""Update a single record.

This method is deprecated and will be removed in a future version.
update_records should be implemented instead.
"""
raise RuntimeError("update_record is deprecated and will be removed in a future version.")

def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None:
"""Update multiple records in one shot.

All records that are received in a single packet are passed
to update_records.

This implementation is a compatiblity shim to ensure older code
that uses RecordUpdateListener as a base class will continue to
get calls to update_record. This method will raise
NotImplementedError in a future version.

At this point the cache will not have the new records
"""
for record in records:
self.update_record(zc, now, record)

def update_records_complete(self) -> None:
"""Called when a record update has completed for all handlers.

At this point the cache will have the new records.
"""


class ServiceListener:
Expand Down Expand Up @@ -1601,6 +1631,7 @@ def __init__(
current_time = current_time_millis()
self._next_time = {check_type_: current_time for check_type_ in self.types}
self._delay = {check_type_: delay for check_type_ in self.types}
self._pending_handlers = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange]
self._handlers_to_call = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange]

self._service_state_changed = Signal()
Expand Down Expand Up @@ -1649,30 +1680,32 @@ def _record_matching_type(self, record: DNSRecord) -> Optional[str]:
"""Return the type if the record matches one of the types we are browsing."""
return next((type_ for type_ in self.types if record.name.endswith(type_)), None)

def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
"""Callback invoked by Zeroconf when new information arrives.

Updates information required by browser in the Zeroconf cache.

Ensures that there is are no unecessary duplicates in the list

"""

def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> None:

# Code to ensure we only do a single update message
# Precedence is; Added, Remove, Update
key = (name, type_)
if (
state_change is ServiceStateChange.Added
or (
state_change is ServiceStateChange.Removed
and self._handlers_to_call.get(key) != ServiceStateChange.Added
)
or (state_change is ServiceStateChange.Updated and key not in self._handlers_to_call)
):
self._handlers_to_call[key] = state_change
def _enqueue_callback(
self,
state_change: ServiceStateChange,
type_: str,
name: str,
) -> None:
# Code to ensure we only do a single update message
# Precedence is; Added, Remove, Update
key = (name, type_)
if (
state_change is ServiceStateChange.Added
or (
state_change is ServiceStateChange.Removed
and self._pending_handlers.get(key) != ServiceStateChange.Added
)
or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers)
):
self._pending_handlers[key] = state_change

def _process_record_update(
self,
zc: 'Zeroconf',
now: float,
record: DNSRecord,
) -> None:
"""Process a single record update from a batch of updates."""
expired = record.is_expired(now)

if isinstance(record, DNSPointer):
Expand All @@ -1683,10 +1716,10 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) ->
old_record = services_by_type.get(service_key)
if old_record is None:
services_by_type[service_key] = record
enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
elif expired:
del services_by_type[service_key]
enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
else:
old_record.reset_ttl(record)
expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
Expand All @@ -1711,14 +1744,32 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) ->
for service in self.zc.cache.entries_with_server(record.name):
type_ = self._record_matching_type(service)
if type_:
enqueue_callback(ServiceStateChange.Updated, type_, service.name)
self._enqueue_callback(ServiceStateChange.Updated, type_, service.name)
break

return

type_ = self._record_matching_type(record)
if type_:
enqueue_callback(ServiceStateChange.Updated, type_, record.name)
self._enqueue_callback(ServiceStateChange.Updated, type_, record.name)

def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None:
"""Callback invoked by Zeroconf when new information arrives.

Updates information required by browser in the Zeroconf cache.

Ensures that there is are no unecessary duplicates in the list.
"""
for record in records:
self._process_record_update(zc, now, record)

def update_records_complete(self) -> None:
"""Called when a record update has completed for all handlers.

At this point the cache will have the new records.
"""
self._handlers_to_call.update(self._pending_handlers)
self._pending_handlers.clear()

def cancel(self) -> None:
"""Cancel the browser."""
Expand Down Expand Up @@ -1825,9 +1876,7 @@ def run(self) -> None:
if not self._handlers_to_call:
continue

with self.zc._handlers_lock: # pylint: disable=protected-access
(name_type, state_change) = self._handlers_to_call.popitem(False)

(name_type, state_change) = self._handlers_to_call.popitem(False)
self._service_state_changed.fire(
zeroconf=self.zc,
service_type=name_type[1],
Expand Down Expand Up @@ -2689,11 +2738,6 @@ def __init__(

self.condition = threading.Condition()

# Ensure we create the lock before
# we add the listener as we could get
# a message before the lock is created.
self._handlers_lock = threading.Lock() # ensure we process a full message in one go

self.engine = Engine(self)
self.listener = Listener(self)
if not unicast:
Expand Down Expand Up @@ -2902,12 +2946,17 @@ def add_listener(
answer the question(s)."""
now = current_time_millis()
self.listeners.append(listener)
records = []
if question is not None:
questions = [question] if isinstance(question, DNSQuestion) else question
for single_question in questions:
for record in self.cache.entries_with_name(single_question.name):
if single_question.answered_by(record) and not record.is_expired(now):
listener.update_record(self, now, record)
records.append(record)

if records:
listener.update_records(self, now, records)
Comment thread
bdraco marked this conversation as resolved.
Outdated
listener.update_records_complete()
self.notify_all()

def remove_listener(self, listener: RecordUpdateListener) -> None:
Expand All @@ -2918,14 +2967,23 @@ def remove_listener(self, listener: RecordUpdateListener) -> None:
except Exception as e: # pylint: disable=broad-except # TODO stop catching all Exceptions
log.exception('Unknown error, possibly benign: %r', e)

def update_record(self, now: float, rec: DNSRecord) -> None:
@contextmanager
def update_records(self, now: float, rec: List[DNSRecord]) -> Generator:
"""Used to notify listeners of new information that has updated
a record."""
for listener in self.listeners:
listener.update_record(self, now, rec)
self.notify_all()
a record.

This method must be called before the cache is updated.
"""
try:
for listener in self.listeners:
listener.update_records(self, now, rec)
yield
finally:
for listener in self.listeners:
listener.update_records_complete()
self.notify_all()

def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many-branches
def handle_response(self, msg: DNSIncoming) -> None:
"""Deal with incoming response packets. All answers
are held in the cache, and listeners are notified."""
updates = [] # type: List[DNSRecord]
Expand Down Expand Up @@ -2967,10 +3025,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many
if not updates and not address_adds and not other_adds and not removes:
return

# Only hold the lock if we have updates
with self._handlers_lock:
for record in updates:
self.update_record(now, record)
with self.update_records(now, updates):
# The cache adds must be processed AFTER we trigger
# the updates since we compare existing data
# with the new data and updating the cache
Expand All @@ -2981,7 +3036,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many
# otherwise a fetch of ServiceInfo may miss an address
# because it thinks the cache is complete
#
# The cache is processed under the lock to ensure
# The cache is processed under the context manager to ensure
# that any ServiceBrowser that is going to call
# zc.get_service_info will see the cached value
# but ONLY after all the record updates have been
Expand Down
54 changes: 54 additions & 0 deletions zeroconf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2552,3 +2552,57 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
assert not notify_called

zc.close()


def test_legacy_record_update_listener():
"""Test a RecordUpdateListener that does not implement update_records."""

# instantiate a zeroconf instance
zc = Zeroconf(interfaces=['127.0.0.1'])

with pytest.raises(RuntimeError):
r.RecordUpdateListener().update_record(
zc, 0, r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL)
)

updates = []

class LegacyRecordUpdateListener(r.RecordUpdateListener):
"""A RecordUpdateListener that does not implement update_records."""

def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None:
nonlocal updates
updates.append(record)

zc.add_listener(LegacyRecordUpdateListener(), None)

# dummy service callback
def on_service_state_change(zeroconf, service_type, state_change, name):
pass

# start a browser
type_ = "_homeassistant._tcp.local."
name = "MyTestHome"
browser = ServiceBrowser(zc, type_, [on_service_state_change])

info_service = ServiceInfo(
type_,
'%s.%s' % (name, type_),
80,
0,
0,
{'path': '/~paulsm/'},
"ash-2.local.",
addresses=[socket.inet_aton("10.0.1.2")],
)

zc.register_service(info_service)

zc.wait(1)

browser.cancel()

assert len(updates)
assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1

zc.close()
7 changes: 6 additions & 1 deletion zeroconf/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,14 @@ def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None:
await task
task = await aiozc.async_unregister_service(new_info)
await task
await aiozc.async_wait(1)
await aiozc.async_close()

assert calls[0] == ('add', type_, registration_name)
assert calls == [
('add', type_, registration_name),
('update', type_, registration_name),
('remove', type_, registration_name),
]


@pytest.mark.asyncio
Expand Down