Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/zeroconf/_services/registry.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ cdef class ServiceRegistry:
cdef public bint has_entries

@cython.locals(
record_list=cython.list,
record_keys=cython.dict,
)
cdef cython.list _async_get_by_index(self, cython.dict records, str key)

cdef _add(self, ServiceInfo info)

@cython.locals(
info=ServiceInfo,
old_service_info=ServiceInfo
old_service_info=ServiceInfo,
type_bucket=cython.dict,
server_bucket=cython.dict,
type_key=str,
server_key=str,
)
cdef _remove(self, cython.list infos)

Expand Down
31 changes: 21 additions & 10 deletions src/zeroconf/_services/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def __init__(
) -> None:
"""Create the ServiceRegistry class."""
self._services: dict[str, ServiceInfo] = {}
self.types: dict[str, list] = {}
self.servers: dict[str, list] = {}
self.types: dict[str, dict[str, None]] = {}
self.servers: dict[str, dict[str, None]] = {}
self.has_entries: bool = False

def async_add(self, info: ServiceInfo) -> None:
Expand Down Expand Up @@ -79,12 +79,12 @@ def async_get_infos_server(self, server: str) -> list[ServiceInfo]:
"""Return all ServiceInfo matching server."""
return self._async_get_by_index(self.servers, server)

def _async_get_by_index(self, records: dict[str, list], key: _str) -> list[ServiceInfo]:
def _async_get_by_index(self, records: dict[str, dict[str, None]], key: _str) -> list[ServiceInfo]:
"""Return all ServiceInfo matching the index."""
record_list = records.get(key)
if record_list is None:
record_keys = records.get(key)
if record_keys is None:
return []
return [self._services[name] for name in record_list]
return [self._services[name] for name in record_keys]

def _add(self, info: ServiceInfo) -> None:
"""Add a new service under the lock."""
Expand All @@ -94,8 +94,11 @@ def _add(self, info: ServiceInfo) -> None:

info.async_clear_cache()
self._services[info.key] = info
self.types.setdefault(info.type.lower(), []).append(info.key)
self.servers.setdefault(info.server_key, []).append(info.key)
# dict[str, None] gives O(1) add/remove while preserving insertion order
# so async_get_infos_type / async_get_infos_server return entries in the
# order they were registered.
self.types.setdefault(info.type.lower(), {})[info.key] = None
self.servers.setdefault(info.server_key, {})[info.key] = None
self.has_entries = True

def _remove(self, infos: list[ServiceInfo]) -> None:
Expand All @@ -105,8 +108,16 @@ def _remove(self, infos: list[ServiceInfo]) -> None:
if old_service_info is None:
continue
assert old_service_info.server_key is not None
self.types[old_service_info.type.lower()].remove(info.key)
self.servers[old_service_info.server_key].remove(info.key)
type_key = old_service_info.type.lower()
server_key = old_service_info.server_key
type_bucket = self.types[type_key]
del type_bucket[info.key]
if not type_bucket:
del self.types[type_key]
server_bucket = self.servers[server_key]
del server_bucket[info.key]
if not server_bucket:
del self.servers[server_key]
del self._services[info.key]

self.has_entries = bool(self._services)
77 changes: 77 additions & 0 deletions tests/services/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,80 @@ def test_lookups_upper_case_by_lower_case(self):
assert registry.async_get_infos_type(type_.lower()) == [info]
assert registry.async_get_infos_server("ash-2.local.") == [info]
assert registry.async_get_types() == [type_.lower()]

def test_empty_buckets_are_removed_when_last_entry_is_removed(self):
type_ = "_test-srvc-type._tcp.local."
registration_name = f"xxxyyy.{type_}"
desc = {"path": "/~paulsm/"}
info = ServiceInfo(
type_,
registration_name,
80,
0,
0,
desc,
"ash-2.local.",
addresses=[socket.inet_aton("10.0.1.2")],
)

registry = r.ServiceRegistry()
registry.async_add(info)
registry.async_remove(info)

assert type_.lower() not in registry.types
assert "ash-2.local." not in registry.servers
assert registry.async_get_types() == []

def test_bulk_remove_preserves_remaining_insertion_order(self):
type_ = "_test-srvc-type._tcp.local."
server = "shared.local."
desc = {"path": "/~paulsm/"}
infos = [
ServiceInfo(
type_,
f"svc{i}.{type_}",
80,
0,
0,
desc,
server,
addresses=[socket.inet_aton("10.0.1.2")],
)
for i in range(20)
]

registry = r.ServiceRegistry()
for info in infos:
registry.async_add(info)

# Remove every other entry in one bulk call.
to_remove = [infos[i] for i in range(0, 20, 2)]
registry.async_remove(to_remove)

expected = [infos[i] for i in range(1, 20, 2)]
assert registry.async_get_infos_type(type_) == expected
assert registry.async_get_infos_server(server) == expected

def test_bulk_remove_then_readd_under_same_key(self):
"""Re-adding after the bucket was deleted must rebuild it cleanly."""
type_ = "_test-srvc-type._tcp.local."
server = "ash-2.local."
desc = {"path": "/~paulsm/"}
info = ServiceInfo(
type_,
f"only.{type_}",
80,
0,
0,
desc,
server,
addresses=[socket.inet_aton("10.0.1.2")],
)

registry = r.ServiceRegistry()
registry.async_add(info)
registry.async_remove(info)
assert type_.lower() not in registry.types
registry.async_add(info)
assert registry.async_get_infos_type(type_) == [info]
assert registry.async_get_infos_server(server) == [info]
Loading