Skip to content

Commit f24fdaf

Browse files
committed
feat: validate and autodetect ip_version for multicast_addresses, dedup against interfaces
1 parent b1aac6e commit f24fdaf

4 files changed

Lines changed: 93 additions & 11 deletions

File tree

src/zeroconf/_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def __init__(
195195
receive queries arriving on them.
196196
"""
197197
if ip_version is None:
198-
ip_version = autodetect_ip_version(interfaces)
198+
ip_version = autodetect_ip_version(interfaces, multicast_addresses)
199199

200200
self.done = False
201201

src/zeroconf/_utils/net.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -458,15 +458,34 @@ def create_sockets(
458458
if multicast_addresses and unicast:
459459
raise ValueError("multicast_addresses is incompatible with unicast=True")
460460

461+
# Reject IP-version-incompatible entries up front so callers get a clear
462+
# error instead of a confusing adapter-lookup or socket-syscall failure.
463+
if multicast_addresses:
464+
if ip_version == IPVersion.V4Only:
465+
for entry in multicast_addresses:
466+
if isinstance(entry, (int, tuple)) or (
467+
isinstance(entry, str) and ipaddress.ip_address(entry).version == 6
468+
):
469+
raise ValueError("multicast_addresses contains IPv6 entries but ip_version is V4Only")
470+
elif ip_version == IPVersion.V6Only:
471+
for entry in multicast_addresses:
472+
if isinstance(entry, str) and ipaddress.ip_address(entry).version == 4:
473+
raise ValueError("multicast_addresses contains IPv4 entries but ip_version is V6Only")
474+
461475
if unicast:
462476
listen_socket = None
463477
else:
464478
listen_socket = new_socket(bind_addr=("",), ip_version=ip_version, apple_p2p=apple_p2p)
465479

466480
normalized_interfaces = normalize_interface_choice(interfaces, ip_version)
467-
extra_multicast_members = (
468-
normalize_interface_choice(list(multicast_addresses), ip_version) if multicast_addresses else []
469-
)
481+
if multicast_addresses:
482+
extra_multicast_members = normalize_interface_choice(list(multicast_addresses), ip_version)
483+
# Strip entries already covered by ``interfaces`` so add_multicast_member
484+
# is not called twice for the same membership.
485+
interface_set = set(normalized_interfaces)
486+
extra_multicast_members = [m for m in extra_multicast_members if m not in interface_set]
487+
else:
488+
extra_multicast_members = []
470489

471490
# If we are using InterfaceChoice.Default with only IPv4 or only IPv6, we can use
472491
# a single socket to listen and respond.
@@ -510,17 +529,29 @@ def can_send_to(ipv6_socket: bool, address: str) -> bool:
510529
return ":" in address if ipv6_socket else ":" not in address
511530

512531

513-
def autodetect_ip_version(interfaces: InterfacesType) -> IPVersion:
532+
def autodetect_ip_version(
533+
interfaces: InterfacesType,
534+
multicast_addresses: Sequence[str | int | tuple[tuple[str, int, int], int]] | None = None,
535+
) -> IPVersion:
514536
"""Auto detect the IP version when it is not provided."""
537+
has_v6 = False
538+
has_v4 = False
515539
if isinstance(interfaces, list):
516540
has_v6 = any(
517541
isinstance(i, int) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6)
518542
for i in interfaces
519543
)
520544
has_v4 = any(isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in interfaces)
521-
if has_v4 and has_v6:
522-
return IPVersion.All
523-
if has_v6:
524-
return IPVersion.V6Only
525-
545+
if multicast_addresses:
546+
has_v6 = has_v6 or any(
547+
isinstance(i, (int, tuple)) or (isinstance(i, str) and ipaddress.ip_address(i).version == 6)
548+
for i in multicast_addresses
549+
)
550+
has_v4 = has_v4 or any(
551+
isinstance(i, str) and ipaddress.ip_address(i).version == 4 for i in multicast_addresses
552+
)
553+
if has_v4 and has_v6:
554+
return IPVersion.All
555+
if has_v6:
556+
return IPVersion.V6Only
526557
return IPVersion.V4Only

tests/test_core.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ def test_multicast_addresses_forwarded_to_create_sockets(self):
190190
zc = r.Zeroconf(
191191
interfaces=["127.0.0.1"],
192192
multicast_addresses=["192.168.1.5"],
193-
unicast=True,
194193
)
195194
try:
196195
_, kwargs = mock_create.call_args

tests/utils/test_net.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,55 @@ def test_create_sockets_multicast_addresses_default_path() -> None:
524524
joined = [c.args[1] for c in mock_add.call_args_list if c.args[0] is listen_mock]
525525
assert "0.0.0.0" in joined
526526
assert "192.168.1.5" in joined
527+
528+
529+
def test_create_sockets_multicast_addresses_v4_rejects_v6_entry() -> None:
530+
"""V4Only listen socket rejects IPv6 multicast_addresses entries."""
531+
with pytest.raises(ValueError, match="IPv6"):
532+
r.create_sockets(
533+
interfaces=["127.0.0.1"],
534+
multicast_addresses=["2001:db8::"],
535+
ip_version=r.IPVersion.V4Only,
536+
)
537+
538+
539+
def test_create_sockets_multicast_addresses_v6_rejects_v4_entry() -> None:
540+
"""V6Only listen socket rejects IPv4 multicast_addresses entries."""
541+
with pytest.raises(ValueError, match="IPv4"):
542+
r.create_sockets(
543+
interfaces=[1],
544+
multicast_addresses=["192.168.1.5"],
545+
ip_version=r.IPVersion.V6Only,
546+
)
547+
548+
549+
def test_create_sockets_multicast_addresses_deduped_against_interfaces() -> None:
550+
"""Addresses present in both interfaces and multicast_addresses join only once."""
551+
listen_mock = Mock(spec=socket.socket)
552+
respond_mock = Mock(spec=socket.socket)
553+
554+
def _new_socket(bind_addr, **kwargs):
555+
return listen_mock if bind_addr == ("",) else respond_mock
556+
557+
with (
558+
patch("zeroconf._utils.net.new_socket", side_effect=_new_socket),
559+
patch("zeroconf._utils.net.add_multicast_member", return_value=True) as mock_add,
560+
patch("zeroconf._utils.net.set_respond_socket_multicast_options"),
561+
patch("zeroconf._utils.net.socket.socket.setsockopt"),
562+
):
563+
r.create_sockets(
564+
interfaces=["127.0.0.1"],
565+
multicast_addresses=["127.0.0.1", "192.168.1.5"],
566+
ip_version=r.IPVersion.V4Only,
567+
)
568+
569+
joined_127 = [c for c in mock_add.call_args_list if c.args[0] is listen_mock and c.args[1] == "127.0.0.1"]
570+
assert len(joined_127) == 1
571+
572+
573+
def test_autodetect_ip_version_includes_multicast_addresses() -> None:
574+
"""autodetect_ip_version sees IPv6 entries from multicast_addresses."""
575+
assert (
576+
netutils.autodetect_ip_version(["127.0.0.1"], multicast_addresses=["2001:db8::"]) is r.IPVersion.All
577+
)
578+
assert netutils.autodetect_ip_version([], multicast_addresses=["2001:db8::"]) is r.IPVersion.V6Only

0 commit comments

Comments
 (0)