Skip to content

Commit a6ad100

Browse files
bdracojstasiak
authored andcommitted
Add support for multiple types to ServiceBrowsers
As each ServiceBrowser runs in its own thread there is a scale problem when listening for many types. ServiceBrowser can now accept a list of types in addition to a single type.
1 parent 24a0619 commit a6ad100

3 files changed

Lines changed: 126 additions & 40 deletions

File tree

examples/browser.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def on_service_state_change(
5555

5656
zeroconf = Zeroconf(ip_version=ip_version)
5757
print("\nBrowsing services, press Ctrl-C to exit...\n")
58-
browser = ServiceBrowser(zeroconf, "_http._tcp.local.", handlers=[on_service_state_change])
58+
browser = ServiceBrowser(
59+
zeroconf, ["_http._tcp.local.", "_hap._tcp.local."], handlers=[on_service_state_change]
60+
)
5961

6062
try:
6163
while True:

zeroconf/__init__.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,7 @@ class ServiceBrowser(RecordUpdateListener, threading.Thread):
13791379
def __init__(
13801380
self,
13811381
zc: 'Zeroconf',
1382-
type_: str,
1382+
type_: Union[str, list],
13831383
# NOTE: Callable quoting needed on Python 3.5.2, see
13841384
# https://github.com/jstasiak/python-zeroconf/issues/208 for details.
13851385
handlers: Optional[Union[ServiceListener, List['Callable[..., None]']]] = None,
@@ -1390,19 +1390,23 @@ def __init__(
13901390
) -> None:
13911391
"""Creates a browser for a specific type"""
13921392
assert handlers or listener, 'You need to specify at least one handler'
1393-
if not type_.endswith(service_type_name(type_, allow_underscores=True)):
1394-
raise BadTypeInNameException
1395-
threading.Thread.__init__(self, name='zeroconf-ServiceBrowser_' + type_)
1393+
self.types = set(type_ if isinstance(type_, list) else [type_])
1394+
for check_type_ in self.types:
1395+
if not check_type_.endswith(service_type_name(check_type_, allow_underscores=True)):
1396+
raise BadTypeInNameException
1397+
threading.Thread.__init__(self, name='zeroconf-ServiceBrowser_' + '-'.join(self.types))
13961398
self.daemon = True
13971399
self.zc = zc
1398-
self.type = type_
13991400
self.addr = addr
14001401
self.port = port
14011402
self.multicast = self.addr in (None, _MDNS_ADDR, _MDNS_ADDR6)
1402-
self.services = {} # type: Dict[str, DNSRecord]
1403-
self.next_time = current_time_millis()
1404-
self.delay = delay
1405-
self._handlers_to_call = OrderedDict() # type: OrderedDict[str, ServiceStateChange]
1403+
self._services = {
1404+
check_type_: {} for check_type_ in self.types
1405+
} # type: Dict[str, Dict[str, DNSRecord]]
1406+
current_time = current_time_millis()
1407+
self._next_time = {check_type_: current_time for check_type_ in self.types}
1408+
self._delay = {check_type_: delay for check_type_ in self.types}
1409+
self._handlers_to_call = OrderedDict() # type: OrderedDict[str, Tuple[str, ServiceStateChange]]
14061410

14071411
self._service_state_changed = Signal()
14081412

@@ -1453,7 +1457,7 @@ def update_record(self, zc: 'Zeroconf', now: float, record: DNSRecord) -> None:
14531457
14541458
"""
14551459

1456-
def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
1460+
def enqueue_callback(state_change: ServiceStateChange, type_: str, name: str) -> None:
14571461

14581462
# Code to ensure we only do a single update message
14591463
# Precedence is; Added, Remove, Update
@@ -1470,29 +1474,29 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
14701474
)
14711475
or (state_change is ServiceStateChange.Updated and name not in self._handlers_to_call)
14721476
):
1473-
self._handlers_to_call[name] = state_change
1477+
self._handlers_to_call[name] = (type_, state_change)
14741478

1475-
if record.type == _TYPE_PTR and record.name == self.type:
1479+
if record.type == _TYPE_PTR and record.name in self.types:
14761480
assert isinstance(record, DNSPointer)
14771481
expired = record.is_expired(now)
14781482
service_key = record.alias.lower()
14791483
try:
1480-
old_record = self.services[service_key]
1484+
old_record = self._services[record.name][service_key]
14811485
except KeyError:
14821486
if not expired:
1483-
self.services[service_key] = record
1484-
enqueue_callback(ServiceStateChange.Added, record.alias)
1487+
self._services[record.name][service_key] = record
1488+
enqueue_callback(ServiceStateChange.Added, record.name, record.alias)
14851489
else:
14861490
if not expired:
14871491
old_record.reset_ttl(record)
14881492
else:
1489-
del self.services[service_key]
1490-
enqueue_callback(ServiceStateChange.Removed, record.alias)
1493+
del self._services[record.name][service_key]
1494+
enqueue_callback(ServiceStateChange.Removed, record.name, record.alias)
14911495
return
14921496

14931497
expires = record.get_expiration_time(75)
1494-
if expires < self.next_time:
1495-
self.next_time = expires
1498+
if expires < self._next_time[record.name]:
1499+
self._next_time[record.name] = expires
14961500

14971501
elif record.type == _TYPE_A or record.type == _TYPE_AAAA:
14981502
assert isinstance(record, DNSAddress)
@@ -1513,49 +1517,56 @@ def enqueue_callback(state_change: ServiceStateChange, name: str) -> None:
15131517

15141518
# Iterate through the DNSCache and callback any services that use this address
15151519
for service in zc.cache.entries():
1516-
if (
1517-
isinstance(service, DNSService)
1518-
and service.name.endswith(self.type)
1519-
and service.server == record.name
1520-
):
1521-
enqueue_callback(ServiceStateChange.Updated, service.name)
1520+
if not isinstance(service, DNSService) or not service.server == record.name:
1521+
continue
1522+
for type_ in self.types:
1523+
if service.name.endswith(type_):
1524+
enqueue_callback(ServiceStateChange.Updated, type_, service.name)
15221525

1523-
elif record.name.endswith(self.type):
1524-
expired = record.is_expired(now)
1525-
if not expired:
1526-
enqueue_callback(ServiceStateChange.Updated, record.name)
1526+
elif not record.is_expired(now):
1527+
for type_ in self.types:
1528+
if record.name.endswith(type_):
1529+
enqueue_callback(ServiceStateChange.Updated, type_, record.name)
15271530

15281531
def cancel(self) -> None:
15291532
self.done = True
15301533
self.zc.remove_listener(self)
15311534
self.join()
15321535

15331536
def run(self) -> None:
1534-
self.zc.add_listener(self, DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
1537+
for type_ in self.types:
1538+
self.zc.add_listener(self, DNSQuestion(type_, _TYPE_PTR, _CLASS_IN))
15351539

15361540
while True:
15371541
now = current_time_millis()
1538-
if len(self._handlers_to_call) == 0 and self.next_time > now:
1539-
self.zc.wait(self.next_time - now)
1542+
# Wait for the type has the smallest next time
1543+
next_time = min(self._next_time.values())
1544+
if len(self._handlers_to_call) == 0 and next_time > now:
1545+
self.zc.wait(next_time - now)
15401546
if self.zc.done or self.done:
15411547
return
15421548
now = current_time_millis()
1543-
if self.next_time <= now:
1549+
for type_ in self.types:
1550+
if self._next_time[type_] > now:
1551+
continue
15441552
out = DNSOutgoing(_FLAGS_QR_QUERY, multicast=self.multicast)
1545-
out.add_question(DNSQuestion(self.type, _TYPE_PTR, _CLASS_IN))
1546-
for record in self.services.values():
1553+
out.add_question(DNSQuestion(type_, _TYPE_PTR, _CLASS_IN))
1554+
for record in self._services[type_].values():
15471555
if not record.is_stale(now):
15481556
out.add_answer_at_time(record, now)
15491557

15501558
self.zc.send(out, addr=self.addr, port=self.port)
1551-
self.next_time = now + self.delay
1552-
self.delay = min(_BROWSER_BACKOFF_LIMIT * 1000, self.delay * 2)
1559+
self._next_time[type_] = now + self._delay[type_]
1560+
self._delay[type_] = min(_BROWSER_BACKOFF_LIMIT * 1000, self._delay[type_] * 2)
15531561

15541562
if len(self._handlers_to_call) > 0 and not self.zc.done:
15551563
with self.zc._handlers_lock:
1556-
handler = self._handlers_to_call.popitem(False)
1564+
(name, service_type_state_change) = self._handlers_to_call.popitem(False)
15571565
self._service_state_changed.fire(
1558-
zeroconf=self.zc, service_type=self.type, name=handler[0], state_change=handler[1]
1566+
zeroconf=self.zc,
1567+
service_type=service_type_state_change[0],
1568+
name=name,
1569+
state_change=service_type_state_change[1],
15591570
)
15601571

15611572

zeroconf/test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,79 @@ def mock_incoming_msg(service_state_change: r.ServiceStateChange) -> r.DNSIncomi
11851185
zeroconf.close()
11861186

11871187

1188+
class TestServiceBrowserMultipleTypes(unittest.TestCase):
1189+
def test_update_record(self):
1190+
1191+
service_names = ['name._type._tcp.local.', 'name._type._udp.local']
1192+
service_types = ['_type._tcp.local.', '_type._udp.local.']
1193+
1194+
service_added_count = 0
1195+
service_removed_count = 0
1196+
service_add_event = Event()
1197+
service_removed_event = Event()
1198+
1199+
class MyServiceListener(r.ServiceListener):
1200+
def add_service(self, zc, type_, name) -> None:
1201+
nonlocal service_added_count
1202+
service_added_count += 1
1203+
if service_added_count == 2:
1204+
service_add_event.set()
1205+
1206+
def remove_service(self, zc, type_, name) -> None:
1207+
nonlocal service_removed_count
1208+
service_removed_count += 1
1209+
if service_removed_count == 2:
1210+
service_removed_event.set()
1211+
1212+
def mock_incoming_msg(
1213+
service_state_change: r.ServiceStateChange, service_type: str, service_name: str
1214+
) -> r.DNSIncoming:
1215+
generated = r.DNSOutgoing(r._FLAGS_QR_RESPONSE)
1216+
1217+
if service_state_change == r.ServiceStateChange.Removed:
1218+
ttl = 0
1219+
else:
1220+
ttl = 120
1221+
1222+
generated.add_answer_at_time(
1223+
r.DNSPointer(service_type, r._TYPE_PTR, r._CLASS_IN, ttl, service_name), 0
1224+
)
1225+
return r.DNSIncoming(generated.packet())
1226+
1227+
zeroconf = r.Zeroconf(interfaces=['127.0.0.1'])
1228+
service_browser = r.ServiceBrowser(zeroconf, service_types, listener=MyServiceListener())
1229+
1230+
try:
1231+
wait_time = 3
1232+
1233+
# both services added
1234+
zeroconf.handle_response(
1235+
mock_incoming_msg(r.ServiceStateChange.Added, service_types[0], service_names[0])
1236+
)
1237+
zeroconf.handle_response(
1238+
mock_incoming_msg(r.ServiceStateChange.Added, service_types[1], service_names[1])
1239+
)
1240+
service_add_event.wait(wait_time)
1241+
assert service_added_count == 2
1242+
assert service_removed_count == 0
1243+
1244+
# both services removed
1245+
zeroconf.handle_response(
1246+
mock_incoming_msg(r.ServiceStateChange.Removed, service_types[0], service_names[0])
1247+
)
1248+
zeroconf.handle_response(
1249+
mock_incoming_msg(r.ServiceStateChange.Removed, service_types[1], service_names[1])
1250+
)
1251+
service_removed_event.wait(wait_time)
1252+
assert service_added_count == 2
1253+
assert service_removed_count == 2
1254+
1255+
finally:
1256+
service_browser.cancel()
1257+
zeroconf.remove_all_service_listeners()
1258+
zeroconf.close()
1259+
1260+
11881261
def test_backoff():
11891262
got_query = Event()
11901263

0 commit comments

Comments
 (0)