2020USA
2121"""
2222
23+ from heapq import heapify , heappop , heappush
2324from typing import Dict , Iterable , List , Optional , Set , Tuple , Union , cast
2425
2526from ._dns import (
4344_float = float
4445_int = int
4546
47+ # The minimum number of scheduled record expirations before we start cleaning up
48+ # the expiration heap. This is a performance optimization to avoid cleaning up the
49+ # heap too often when there are only a few scheduled expirations.
50+ _MIN_SCHEDULED_RECORD_EXPIRATION = 100
51+
4652
4753def _remove_key (cache : _DNSRecordCacheType , key : _str , record : _DNSRecord ) -> None :
4854 """Remove a key from a DNSRecord cache
@@ -60,6 +66,8 @@ class DNSCache:
6066
6167 def __init__ (self ) -> None :
6268 self .cache : _DNSRecordCacheType = {}
69+ self ._expire_heap : List [Tuple [float , DNSRecord ]] = []
70+ self ._expirations : Dict [DNSRecord , float ] = {}
6371 self .service_cache : _DNSRecordCacheType = {}
6472
6573 # Functions prefixed with async_ are NOT threadsafe and must
@@ -81,6 +89,12 @@ def _async_add(self, record: _DNSRecord) -> bool:
8189 store = self .cache .setdefault (record .key , {})
8290 new = record not in store and not isinstance (record , DNSNsec )
8391 store [record ] = record
92+ when = record .created + (record .ttl * 1000 )
93+ if self ._expirations .get (record ) != when :
94+ # Avoid adding duplicates to the heap
95+ heappush (self ._expire_heap , (when , record ))
96+ self ._expirations [record ] = when
97+
8498 if isinstance (record , DNSService ):
8599 service_record = record
86100 self .service_cache .setdefault (record .server_key , {})[service_record ] = service_record
@@ -108,6 +122,7 @@ def _async_remove(self, record: _DNSRecord) -> None:
108122 service_record = record
109123 _remove_key (self .service_cache , service_record .server_key , service_record )
110124 _remove_key (self .cache , record .key , record )
125+ self ._expirations .pop (record , None )
111126
112127 def async_remove_records (self , entries : Iterable [DNSRecord ]) -> None :
113128 """Remove multiple records.
@@ -121,8 +136,44 @@ def async_expire(self, now: _float) -> List[DNSRecord]:
121136 """Purge expired entries from the cache.
122137
123138 This function must be run in from event loop.
139+
140+ :param now: The current time in milliseconds.
124141 """
125- expired = [record for records in self .cache .values () for record in records if record .is_expired (now )]
142+ if not (expire_heap_len := len (self ._expire_heap )):
143+ return []
144+
145+ expired : List [DNSRecord ] = []
146+ # Find any expired records and add them to the to-delete list
147+ while self ._expire_heap :
148+ when , record = self ._expire_heap [0 ]
149+ if when > now :
150+ break
151+ heappop (self ._expire_heap )
152+ # Check if the record hasn't been re-added to the heap
153+ # with a different expiration time as it will be removed
154+ # later when it reaches the top of the heap and its
155+ # expiration time is met.
156+ if self ._expirations .get (record ) == when :
157+ expired .append (record )
158+
159+ # If the expiration heap grows larger than the number expirations
160+ # times two, we clean it up to avoid keeping expired entries in
161+ # the heap and consuming memory. We guard this with a minimum
162+ # threshold to avoid cleaning up the heap too often when there are
163+ # only a few scheduled expirations.
164+ if (
165+ expire_heap_len > _MIN_SCHEDULED_RECORD_EXPIRATION
166+ and expire_heap_len > len (self ._expirations ) * 2
167+ ):
168+ # Remove any expired entries from the expiration heap
169+ # that do not match the expiration time in the expirations
170+ # as it means the record has been re-added to the heap
171+ # with a different expiration time.
172+ self ._expire_heap = [
173+ entry for entry in self ._expire_heap if self ._expirations .get (entry [1 ]) == entry [0 ]
174+ ]
175+ heapify (self ._expire_heap )
176+
126177 self .async_remove_records (expired )
127178 return expired
128179
@@ -256,4 +307,11 @@ def async_mark_unique_records_older_than_1s_to_expire(
256307 created_double = record .created
257308 if (now - created_double > _ONE_SECOND ) and record not in answers_rrset :
258309 # Expire in 1s
259- record .set_created_ttl (now , 1 )
310+ self ._async_set_created_ttl (record , now , 1 )
311+
312+ def _async_set_created_ttl (self , record : DNSRecord , now : _float , ttl : _float ) -> None :
313+ """Set the created time and ttl of a record."""
314+ # It would be better if we made a copy instead of mutating the record
315+ # in place, but records currently don't have a copy method.
316+ record ._set_created_ttl (now , ttl )
317+ self ._async_add (record )
0 commit comments