Skip to content

Commit 0a69aa0

Browse files
authored
RecordUpdateListener now uses update_records instead of update_record (#419)
1 parent 9606936 commit 0a69aa0

3 files changed

Lines changed: 167 additions & 53 deletions

File tree

zeroconf/__init__.py

Lines changed: 107 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@
3535
import time
3636
import warnings
3737
from collections import OrderedDict
38+
from contextlib import contextmanager
3839
from types import TracebackType # noqa # used in type hints
39-
from typing import Dict, Iterable, List, Optional, Type, Union, cast
40+
from typing import Dict, Generator, Iterable, List, Optional, Type, Union, cast
4041
from typing import Any, Callable, Set, Tuple # noqa # used in type hints
4142

4243
import ifaddr
@@ -1424,8 +1425,8 @@ def run(self) -> None:
14241425
now = current_time_millis()
14251426
if now - self._last_cache_cleanup >= self.cache_cleanup_interval_ms:
14261427
self._last_cache_cleanup = now
1427-
for record in self.zc.cache.expire(now):
1428-
self.zc.update_record(now, record)
1428+
with self.zc.update_records(now, list(self.zc.cache.expire(now))):
1429+
pass
14291430

14301431
self.socketpair[0].close()
14311432
self.socketpair[1].close()
@@ -1548,8 +1549,37 @@ def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistratio
15481549

15491550

15501551
class RecordUpdateListener:
1551-
def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
1552-
raise NotImplementedError()
1552+
def update_record( # pylint: disable=no-self-use
1553+
self, zc: 'Zeroconf', now: float, record: DNSRecord
1554+
) -> None:
1555+
"""Update a single record.
1556+
1557+
This method is deprecated and will be removed in a future version.
1558+
update_records should be implemented instead.
1559+
"""
1560+
raise RuntimeError("update_record is deprecated and will be removed in a future version.")
1561+
1562+
def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None:
1563+
"""Update multiple records in one shot.
1564+
1565+
All records that are received in a single packet are passed
1566+
to update_records.
1567+
1568+
This implementation is a compatiblity shim to ensure older code
1569+
that uses RecordUpdateListener as a base class will continue to
1570+
get calls to update_record. This method will raise
1571+
NotImplementedError in a future version.
1572+
1573+
At this point the cache will not have the new records
1574+
"""
1575+
for record in records:
1576+
self.update_record(zc, now, record)
1577+
1578+
def update_records_complete(self) -> None:
1579+
"""Called when a record update has completed for all handlers.
1580+
1581+
At this point the cache will have the new records.
1582+
"""
15531583

15541584

15551585
class ServiceListener:
@@ -1601,6 +1631,7 @@ def __init__(
16011631
current_time = current_time_millis()
16021632
self._next_time = {check_type_: current_time for check_type_ in self.types}
16031633
self._delay = {check_type_: delay for check_type_ in self.types}
1634+
self._pending_handlers = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange]
16041635
self._handlers_to_call = OrderedDict() # type: OrderedDict[Tuple[str, str], ServiceStateChange]
16051636

16061637
self._service_state_changed = Signal()
@@ -1649,30 +1680,32 @@ def _record_matching_type(self, record: DNSRecord) -> Optional[str]:
16491680
"""Return the type if the record matches one of the types we are browsing."""
16501681
return next((type_ for type_ in self.types if record.name.endswith(type_)), None)
16511682

1652-
def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
1653-
"""Callback invoked by Zeroconf when new information arrives.
1654-
1655-
Updates information required by browser in the Zeroconf cache.
1656-
1657-
Ensures that there is are no unecessary duplicates in the list
1658-
1659-
"""
1660-
1661-
def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> None:
1662-
1663-
# Code to ensure we only do a single update message
1664-
# Precedence is; Added, Remove, Update
1665-
key = (name, type_)
1666-
if (
1667-
state_change is ServiceStateChange.Added
1668-
or (
1669-
state_change is ServiceStateChange.Removed
1670-
and self._handlers_to_call.get(key) != ServiceStateChange.Added
1671-
)
1672-
or (state_change is ServiceStateChange.Updated and key not in self._handlers_to_call)
1673-
):
1674-
self._handlers_to_call[key] = state_change
1683+
def _enqueue_callback(
1684+
self,
1685+
state_change: ServiceStateChange,
1686+
type_: str,
1687+
name: str,
1688+
) -> None:
1689+
# Code to ensure we only do a single update message
1690+
# Precedence is; Added, Remove, Update
1691+
key = (name, type_)
1692+
if (
1693+
state_change is ServiceStateChange.Added
1694+
or (
1695+
state_change is ServiceStateChange.Removed
1696+
and self._pending_handlers.get(key) != ServiceStateChange.Added
1697+
)
1698+
or (state_change is ServiceStateChange.Updated and key not in self._pending_handlers)
1699+
):
1700+
self._pending_handlers[key] = state_change
16751701

1702+
def _process_record_update(
1703+
self,
1704+
zc: 'Zeroconf',
1705+
now: float,
1706+
record: DNSRecord,
1707+
) -> None:
1708+
"""Process a single record update from a batch of updates."""
16761709
expired = record.is_expired(now)
16771710

16781711
if isinstance(record, DNSPointer):
@@ -1683,10 +1716,10 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) ->
16831716
old_record = services_by_type.get(service_key)
16841717
if old_record is None:
16851718
services_by_type[service_key] = record
1686-
enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
1719+
self._enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
16871720
elif expired:
16881721
del services_by_type[service_key]
1689-
enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
1722+
self._enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
16901723
else:
16911724
old_record.reset_ttl(record)
16921725
expires = record.get_expiration_time(_EXPIRE_REFRESH_TIME_PERCENT)
@@ -1711,14 +1744,32 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) ->
17111744
for service in self.zc.cache.entries_with_server(record.name):
17121745
type_ = self._record_matching_type(service)
17131746
if type_:
1714-
enqueue_callback(ServiceStateChange.Updated, type_, service.name)
1747+
self._enqueue_callback(ServiceStateChange.Updated, type_, service.name)
17151748
break
17161749

17171750
return
17181751

17191752
type_ = self._record_matching_type(record)
17201753
if type_:
1721-
enqueue_callback(ServiceStateChange.Updated, type_, record.name)
1754+
self._enqueue_callback(ServiceStateChange.Updated, type_, record.name)
1755+
1756+
def update_records(self, zc: 'Zeroconf', now: float, records: List[DNSRecord]) -> None:
1757+
"""Callback invoked by Zeroconf when new information arrives.
1758+
1759+
Updates information required by browser in the Zeroconf cache.
1760+
1761+
Ensures that there is are no unecessary duplicates in the list.
1762+
"""
1763+
for record in records:
1764+
self._process_record_update(zc, now, record)
1765+
1766+
def update_records_complete(self) -> None:
1767+
"""Called when a record update has completed for all handlers.
1768+
1769+
At this point the cache will have the new records.
1770+
"""
1771+
self._handlers_to_call.update(self._pending_handlers)
1772+
self._pending_handlers.clear()
17221773

17231774
def cancel(self) -> None:
17241775
"""Cancel the browser."""
@@ -1825,9 +1876,7 @@ def run(self) -> None:
18251876
if not self._handlers_to_call:
18261877
continue
18271878

1828-
with self.zc._handlers_lock: # pylint: disable=protected-access
1829-
(name_type, state_change) = self._handlers_to_call.popitem(False)
1830-
1879+
(name_type, state_change) = self._handlers_to_call.popitem(False)
18311880
self._service_state_changed.fire(
18321881
zeroconf=self.zc,
18331882
service_type=name_type[1],
@@ -2689,11 +2738,6 @@ def __init__(
26892738

26902739
self.condition = threading.Condition()
26912740

2692-
# Ensure we create the lock before
2693-
# we add the listener as we could get
2694-
# a message before the lock is created.
2695-
self._handlers_lock = threading.Lock() # ensure we process a full message in one go
2696-
26972741
self.engine = Engine(self)
26982742
self.listener = Listener(self)
26992743
if not unicast:
@@ -2902,12 +2946,17 @@ def add_listener(
29022946
answer the question(s)."""
29032947
now = current_time_millis()
29042948
self.listeners.append(listener)
2949+
records = []
29052950
if question is not None:
29062951
questions = [question] if isinstance(question, DNSQuestion) else question
29072952
for single_question in questions:
29082953
for record in self.cache.entries_with_name(single_question.name):
29092954
if single_question.answered_by(record) and not record.is_expired(now):
2910-
listener.update_record(self, now, record)
2955+
records.append(record)
2956+
2957+
if records:
2958+
listener.update_records(self, now, records)
2959+
listener.update_records_complete()
29112960
self.notify_all()
29122961

29132962
def remove_listener(self, listener: RecordUpdateListener) -> None:
@@ -2918,14 +2967,23 @@ def remove_listener(self, listener: RecordUpdateListener) -> None:
29182967
except Exception as e: # pylint: disable=broad-except # TODO stop catching all Exceptions
29192968
log.exception('Unknown error, possibly benign: %r', e)
29202969

2921-
def update_record(self, now: float, rec: DNSRecord) -> None:
2970+
@contextmanager
2971+
def update_records(self, now: float, rec: List[DNSRecord]) -> Generator:
29222972
"""Used to notify listeners of new information that has updated
2923-
a record."""
2924-
for listener in self.listeners:
2925-
listener.update_record(self, now, rec)
2926-
self.notify_all()
2973+
a record.
2974+
2975+
This method must be called before the cache is updated.
2976+
"""
2977+
try:
2978+
for listener in self.listeners:
2979+
listener.update_records(self, now, rec)
2980+
yield
2981+
finally:
2982+
for listener in self.listeners:
2983+
listener.update_records_complete()
2984+
self.notify_all()
29272985

2928-
def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many-branches
2986+
def handle_response(self, msg: DNSIncoming) -> None:
29292987
"""Deal with incoming response packets. All answers
29302988
are held in the cache, and listeners are notified."""
29312989
updates = [] # type: List[DNSRecord]
@@ -2967,10 +3025,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many
29673025
if not updates and not address_adds and not other_adds and not removes:
29683026
return
29693027

2970-
# Only hold the lock if we have updates
2971-
with self._handlers_lock:
2972-
for record in updates:
2973-
self.update_record(now, record)
3028+
with self.update_records(now, updates):
29743029
# The cache adds must be processed AFTER we trigger
29753030
# the updates since we compare existing data
29763031
# with the new data and updating the cache
@@ -2981,7 +3036,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many
29813036
# otherwise a fetch of ServiceInfo may miss an address
29823037
# because it thinks the cache is complete
29833038
#
2984-
# The cache is processed under the lock to ensure
3039+
# The cache is processed under the context manager to ensure
29853040
# that any ServiceBrowser that is going to call
29863041
# zc.get_service_info will see the cached value
29873042
# but ONLY after all the record updates have been

zeroconf/test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2552,3 +2552,57 @@ def on_service_state_change(zeroconf, service_type, state_change, name):
25522552
assert not notify_called
25532553

25542554
zc.close()
2555+
2556+
2557+
def test_legacy_record_update_listener():
2558+
"""Test a RecordUpdateListener that does not implement update_records."""
2559+
2560+
# instantiate a zeroconf instance
2561+
zc = Zeroconf(interfaces=['127.0.0.1'])
2562+
2563+
with pytest.raises(RuntimeError):
2564+
r.RecordUpdateListener().update_record(
2565+
zc, 0, r.DNSRecord('irrelevant', r._TYPE_SRV, r._CLASS_IN, r._DNS_HOST_TTL)
2566+
)
2567+
2568+
updates = []
2569+
2570+
class LegacyRecordUpdateListener(r.RecordUpdateListener):
2571+
"""A RecordUpdateListener that does not implement update_records."""
2572+
2573+
def update_record(self, zc: 'Zeroconf', now: float, record: r.DNSRecord) -> None:
2574+
nonlocal updates
2575+
updates.append(record)
2576+
2577+
zc.add_listener(LegacyRecordUpdateListener(), None)
2578+
2579+
# dummy service callback
2580+
def on_service_state_change(zeroconf, service_type, state_change, name):
2581+
pass
2582+
2583+
# start a browser
2584+
type_ = "_homeassistant._tcp.local."
2585+
name = "MyTestHome"
2586+
browser = ServiceBrowser(zc, type_, [on_service_state_change])
2587+
2588+
info_service = ServiceInfo(
2589+
type_,
2590+
'%s.%s' % (name, type_),
2591+
80,
2592+
0,
2593+
0,
2594+
{'path': '/~paulsm/'},
2595+
"ash-2.local.",
2596+
addresses=[socket.inet_aton("10.0.1.2")],
2597+
)
2598+
2599+
zc.register_service(info_service)
2600+
2601+
zc.wait(1)
2602+
2603+
browser.cancel()
2604+
2605+
assert len(updates)
2606+
assert len([isinstance(update, r.DNSPointer) and update.name == type_ for update in updates]) >= 1
2607+
2608+
zc.close()

zeroconf/test_asyncio.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,9 +435,14 @@ def update_service(self, aiozc: AsyncZeroconf, type: str, name: str) -> None:
435435
await task
436436
task = await aiozc.async_unregister_service(new_info)
437437
await task
438+
await aiozc.async_wait(1)
438439
await aiozc.async_close()
439440

440-
assert calls[0] == ('add', type_, registration_name)
441+
assert calls == [
442+
('add', type_, registration_name),
443+
('update', type_, registration_name),
444+
('remove', type_, registration_name),
445+
]
441446

442447

443448
@pytest.mark.asyncio

0 commit comments

Comments
 (0)