3535import time
3636import warnings
3737from collections import OrderedDict
38+ from contextlib import contextmanager
3839from types import TracebackType # noqa # used in type hints
39- from typing import Dict , Iterable , List , Optional , Type , Union , cast
40+ from typing import Dict , Generator , Iterable , List , Optional , Type , Union , cast
4041from typing import Any , Callable , Set , Tuple # noqa # used in type hints
4142
4243import ifaddr
@@ -1424,8 +1425,8 @@ def run(self) -> None:
14241425 now = current_time_millis ()
14251426 if now - self ._last_cache_cleanup >= self .cache_cleanup_interval_ms :
14261427 self ._last_cache_cleanup = now
1427- for record in self .zc .cache .expire (now ):
1428- self . zc . update_record ( now , record )
1428+ with self . zc . update_records ( now , list ( self .zc .cache .expire (now )) ):
1429+ pass
14291430
14301431 self .socketpair [0 ].close ()
14311432 self .socketpair [1 ].close ()
@@ -1548,8 +1549,37 @@ def unregister_handler(self, handler: Callable[..., None]) -> 'SignalRegistratio
15481549
15491550
15501551class RecordUpdateListener :
1551- def update_record (self , zc : 'Zeroconf' , now : float , record : DNSRecord ) -> None :
1552- raise NotImplementedError ()
1552+ def update_record ( # pylint: disable=no-self-use
1553+ self , zc : 'Zeroconf' , now : float , record : DNSRecord
1554+ ) -> None :
1555+ """Update a single record.
1556+
1557+ This method is deprecated and will be removed in a future version.
1558+ update_records should be implemented instead.
1559+ """
1560+ raise RuntimeError ("update_record is deprecated and will be removed in a future version." )
1561+
1562+ def update_records (self , zc : 'Zeroconf' , now : float , records : List [DNSRecord ]) -> None :
1563+ """Update multiple records in one shot.
1564+
1565+ All records that are received in a single packet are passed
1566+ to update_records.
1567+
1568+ This implementation is a compatiblity shim to ensure older code
1569+ that uses RecordUpdateListener as a base class will continue to
1570+ get calls to update_record. This method will raise
1571+ NotImplementedError in a future version.
1572+
1573+ At this point the cache will not have the new records
1574+ """
1575+ for record in records :
1576+ self .update_record (zc , now , record )
1577+
1578+ def update_records_complete (self ) -> None :
1579+ """Called when a record update has completed for all handlers.
1580+
1581+ At this point the cache will have the new records.
1582+ """
15531583
15541584
15551585class ServiceListener :
@@ -1601,6 +1631,7 @@ def __init__(
16011631 current_time = current_time_millis ()
16021632 self ._next_time = {check_type_ : current_time for check_type_ in self .types }
16031633 self ._delay = {check_type_ : delay for check_type_ in self .types }
1634+ self ._pending_handlers = OrderedDict () # type: OrderedDict[Tuple[str, str], ServiceStateChange]
16041635 self ._handlers_to_call = OrderedDict () # type: OrderedDict[Tuple[str, str], ServiceStateChange]
16051636
16061637 self ._service_state_changed = Signal ()
@@ -1649,30 +1680,32 @@ def _record_matching_type(self, record: DNSRecord) -> Optional[str]:
16491680 """Return the type if the record matches one of the types we are browsing."""
16501681 return next ((type_ for type_ in self .types if record .name .endswith (type_ )), None )
16511682
1652- def update_record (self , zc : 'Zeroconf' , now : float , record : DNSRecord ) -> None :
1653- """Callback invoked by Zeroconf when new information arrives.
1654-
1655- Updates information required by browser in the Zeroconf cache.
1656-
1657- Ensures that there is are no unecessary duplicates in the list
1658-
1659- """
1660-
1661- def enqueue_callback (state_change : ServiceStateChange , type_ : str , name : str ) -> None :
1662-
1663- # Code to ensure we only do a single update message
1664- # Precedence is; Added, Remove, Update
1665- key = (name , type_ )
1666- if (
1667- state_change is ServiceStateChange .Added
1668- or (
1669- state_change is ServiceStateChange .Removed
1670- and self ._handlers_to_call .get (key ) != ServiceStateChange .Added
1671- )
1672- or (state_change is ServiceStateChange .Updated and key not in self ._handlers_to_call )
1673- ):
1674- self ._handlers_to_call [key ] = state_change
1683+ def _enqueue_callback (
1684+ self ,
1685+ state_change : ServiceStateChange ,
1686+ type_ : str ,
1687+ name : str ,
1688+ ) -> None :
1689+ # Code to ensure we only do a single update message
1690+ # Precedence is; Added, Remove, Update
1691+ key = (name , type_ )
1692+ if (
1693+ state_change is ServiceStateChange .Added
1694+ or (
1695+ state_change is ServiceStateChange .Removed
1696+ and self ._pending_handlers .get (key ) != ServiceStateChange .Added
1697+ )
1698+ or (state_change is ServiceStateChange .Updated and key not in self ._pending_handlers )
1699+ ):
1700+ self ._pending_handlers [key ] = state_change
16751701
1702+ def _process_record_update (
1703+ self ,
1704+ zc : 'Zeroconf' ,
1705+ now : float ,
1706+ record : DNSRecord ,
1707+ ) -> None :
1708+ """Process a single record update from a batch of updates."""
16761709 expired = record .is_expired (now )
16771710
16781711 if isinstance (record , DNSPointer ):
@@ -1683,10 +1716,10 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) ->
16831716 old_record = services_by_type .get (service_key )
16841717 if old_record is None :
16851718 services_by_type [service_key ] = record
1686- enqueue_callback (ServiceStateChange .Added , record .name , record .alias )
1719+ self . _enqueue_callback (ServiceStateChange .Added , record .name , record .alias )
16871720 elif expired :
16881721 del services_by_type [service_key ]
1689- enqueue_callback (ServiceStateChange .Removed , record .name , record .alias )
1722+ self . _enqueue_callback (ServiceStateChange .Removed , record .name , record .alias )
16901723 else :
16911724 old_record .reset_ttl (record )
16921725 expires = record .get_expiration_time (_EXPIRE_REFRESH_TIME_PERCENT )
@@ -1711,14 +1744,32 @@ def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) ->
17111744 for service in self .zc .cache .entries_with_server (record .name ):
17121745 type_ = self ._record_matching_type (service )
17131746 if type_ :
1714- enqueue_callback (ServiceStateChange .Updated , type_ , service .name )
1747+ self . _enqueue_callback (ServiceStateChange .Updated , type_ , service .name )
17151748 break
17161749
17171750 return
17181751
17191752 type_ = self ._record_matching_type (record )
17201753 if type_ :
1721- enqueue_callback (ServiceStateChange .Updated , type_ , record .name )
1754+ self ._enqueue_callback (ServiceStateChange .Updated , type_ , record .name )
1755+
1756+ def update_records (self , zc : 'Zeroconf' , now : float , records : List [DNSRecord ]) -> None :
1757+ """Callback invoked by Zeroconf when new information arrives.
1758+
1759+ Updates information required by browser in the Zeroconf cache.
1760+
1761+ Ensures that there is are no unecessary duplicates in the list.
1762+ """
1763+ for record in records :
1764+ self ._process_record_update (zc , now , record )
1765+
1766+ def update_records_complete (self ) -> None :
1767+ """Called when a record update has completed for all handlers.
1768+
1769+ At this point the cache will have the new records.
1770+ """
1771+ self ._handlers_to_call .update (self ._pending_handlers )
1772+ self ._pending_handlers .clear ()
17221773
17231774 def cancel (self ) -> None :
17241775 """Cancel the browser."""
@@ -1825,9 +1876,7 @@ def run(self) -> None:
18251876 if not self ._handlers_to_call :
18261877 continue
18271878
1828- with self .zc ._handlers_lock : # pylint: disable=protected-access
1829- (name_type , state_change ) = self ._handlers_to_call .popitem (False )
1830-
1879+ (name_type , state_change ) = self ._handlers_to_call .popitem (False )
18311880 self ._service_state_changed .fire (
18321881 zeroconf = self .zc ,
18331882 service_type = name_type [1 ],
@@ -2689,11 +2738,6 @@ def __init__(
26892738
26902739 self .condition = threading .Condition ()
26912740
2692- # Ensure we create the lock before
2693- # we add the listener as we could get
2694- # a message before the lock is created.
2695- self ._handlers_lock = threading .Lock () # ensure we process a full message in one go
2696-
26972741 self .engine = Engine (self )
26982742 self .listener = Listener (self )
26992743 if not unicast :
@@ -2902,12 +2946,17 @@ def add_listener(
29022946 answer the question(s)."""
29032947 now = current_time_millis ()
29042948 self .listeners .append (listener )
2949+ records = []
29052950 if question is not None :
29062951 questions = [question ] if isinstance (question , DNSQuestion ) else question
29072952 for single_question in questions :
29082953 for record in self .cache .entries_with_name (single_question .name ):
29092954 if single_question .answered_by (record ) and not record .is_expired (now ):
2910- listener .update_record (self , now , record )
2955+ records .append (record )
2956+
2957+ if records :
2958+ listener .update_records (self , now , records )
2959+ listener .update_records_complete ()
29112960 self .notify_all ()
29122961
29132962 def remove_listener (self , listener : RecordUpdateListener ) -> None :
@@ -2918,14 +2967,23 @@ def remove_listener(self, listener: RecordUpdateListener) -> None:
29182967 except Exception as e : # pylint: disable=broad-except # TODO stop catching all Exceptions
29192968 log .exception ('Unknown error, possibly benign: %r' , e )
29202969
2921- def update_record (self , now : float , rec : DNSRecord ) -> None :
2970+ @contextmanager
2971+ def update_records (self , now : float , rec : List [DNSRecord ]) -> Generator :
29222972 """Used to notify listeners of new information that has updated
2923- a record."""
2924- for listener in self .listeners :
2925- listener .update_record (self , now , rec )
2926- self .notify_all ()
2973+ a record.
2974+
2975+ This method must be called before the cache is updated.
2976+ """
2977+ try :
2978+ for listener in self .listeners :
2979+ listener .update_records (self , now , rec )
2980+ yield
2981+ finally :
2982+ for listener in self .listeners :
2983+ listener .update_records_complete ()
2984+ self .notify_all ()
29272985
2928- def handle_response (self , msg : DNSIncoming ) -> None : # pylint: disable=too-many-branches
2986+ def handle_response (self , msg : DNSIncoming ) -> None :
29292987 """Deal with incoming response packets. All answers
29302988 are held in the cache, and listeners are notified."""
29312989 updates = [] # type: List[DNSRecord]
@@ -2967,10 +3025,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many
29673025 if not updates and not address_adds and not other_adds and not removes :
29683026 return
29693027
2970- # Only hold the lock if we have updates
2971- with self ._handlers_lock :
2972- for record in updates :
2973- self .update_record (now , record )
3028+ with self .update_records (now , updates ):
29743029 # The cache adds must be processed AFTER we trigger
29753030 # the updates since we compare existing data
29763031 # with the new data and updating the cache
@@ -2981,7 +3036,7 @@ def handle_response(self, msg: DNSIncoming) -> None: # pylint: disable=too-many
29813036 # otherwise a fetch of ServiceInfo may miss an address
29823037 # because it thinks the cache is complete
29833038 #
2984- # The cache is processed under the lock to ensure
3039+ # The cache is processed under the context manager to ensure
29853040 # that any ServiceBrowser that is going to call
29863041 # zc.get_service_info will see the cached value
29873042 # but ONLY after all the record updates have been
0 commit comments