Skip to content

Commit 2e07f25

Browse files
committed
fix: bound DNS compression-pointer chain depth in DNSIncoming
Closes #1713. A crafted mDNS packet can chain thousands of forward compression pointers into a single ~3 kB datagram. ``DNSIncoming._decode_labels_at_offset`` recurses once per pointer it follows (RFC 1035 §4.1.4 name compression), so a deep chain blows through CPython's default recursion limit (``sys.getrecursionlimit() == 1000``). ``RecursionError`` was *not* in ``DECODE_EXCEPTIONS``, so it escaped ``DNSIncoming.__init__`` and bubbled up to asyncio's default exception handler — turning a single small multicast packet from any host on the local link (a guest on the same Wi-Fi, a compromised IoT device, a container on a shared bridge) into sustained CPU burn (each crash unwinds ~1000 frames + the asyncio exception machinery) and debug-log flooding. Home Assistant deployments on Raspberry-Pi-class hardware are the canonical victim. ``seen_pointers`` already blocked cycles and ``MAX_DNS_LABELS = 128`` already capped the *number of labels*, but nothing capped the *chain length of unique forward pointers*. A ``_MAX_MSG_ABSOLUTE`` (8966 B) packet can carry ~4000 2-byte pointers, each pointing to the next. Thread an explicit ``depth`` counter through ``_decode_labels_at_offset`` and raise ``IncomingDecodeError`` when it exceeds ``MAX_DNS_LABELS`` — same bound as the existing label cap, so no new constant. Belt-and- braces, add ``RecursionError`` to ``DECODE_EXCEPTIONS`` so any future regression is contained as an invalid packet rather than logged by asyncio's default handler. ``incoming.pxd`` updated in the same commit so the Cython build picks up the new ``unsigned int depth`` parameter. ``cython -a`` confirms the depth check and ``depth + 1`` recursive-call argument both compile to score-0 (pure C) — direct ``unsigned int`` compare and a C add — with no Python interaction added to the hot path. Function entry score is unchanged (score-11, same boilerplate). Wall-clock smoke test on a typical compressed mDNS response (best-of-10 over 50k iters, REQUIRE_CYTHON=1): 1237.2 ns/parse pre-fix vs 1234.1 ns/parse post-fix — within noise. CWE-674 (Uncontrolled Recursion). LAN-local attack surface only.
1 parent 64d143d commit 2e07f25

3 files changed

Lines changed: 33 additions & 5 deletions

File tree

src/zeroconf/_protocol/incoming.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ cdef class DNSIncoming:
8383
link_py_int=object,
8484
linked_labels=cython.list
8585
)
86-
cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)
86+
cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers, unsigned int depth)
8787

8888
@cython.locals(offset="unsigned int")
8989
cdef void _read_header(self)

src/zeroconf/_protocol/incoming.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
MAX_DNS_LABELS = 128
6161
MAX_NAME_LENGTH = 253
6262

63-
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError)
63+
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError, RecursionError)
6464

6565

6666
_seen_logs: dict[str, int | tuple] = {}
@@ -409,7 +409,7 @@ def _read_name(self) -> str:
409409
labels: list[str] = []
410410
seen_pointers: set[int] = set()
411411
original_offset = self.offset
412-
self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers)
412+
self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers, 0)
413413
self._name_cache[original_offset] = labels
414414
name = ".".join(labels) + "."
415415
if len(name) > MAX_NAME_LENGTH:
@@ -418,8 +418,14 @@ def _read_name(self) -> str:
418418
)
419419
return name
420420

421-
def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers: set[int]) -> int:
421+
def _decode_labels_at_offset(
422+
self, off: _int, labels: list[str], seen_pointers: set[int], depth: _int
423+
) -> int:
422424
# This is a tight loop that is called frequently, small optimizations can make a difference.
425+
if depth > MAX_DNS_LABELS:
426+
raise IncomingDecodeError(
427+
f"DNS compression pointer chain exceeds {MAX_DNS_LABELS} at {off} from {self.source}"
428+
)
423429
view = self.view
424430
while off < self._data_len:
425431
length = view[off]
@@ -457,7 +463,7 @@ def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers:
457463
if not linked_labels:
458464
linked_labels = []
459465
seen_pointers.add(link_py_int)
460-
self._decode_labels_at_offset(link, linked_labels, seen_pointers)
466+
self._decode_labels_at_offset(link, linked_labels, seen_pointers, depth + 1)
461467
self._name_cache[link_py_int] = linked_labels
462468
labels.extend(linked_labels)
463469
if len(labels) > MAX_DNS_LABELS:

tests/test_protocol.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,28 @@ def test_label_compression_attack():
10111011
assert len(parsed.answers()) == 1
10121012

10131013

1014+
def test_dns_compression_pointer_chain_depth_attack() -> None:
1015+
"""Test our wire parser rejects deeply chained compression pointers without recursing."""
1016+
# Build a packet with one question whose name is a 1500-deep chain of forward
1017+
# compression pointers, ending in a root label. Each pointer is 2 bytes,
1018+
# so chain length easily exceeds CPython's default recursion limit.
1019+
header = b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00"
1020+
# Question at offset 12: pointer to offset 18 (past the question's type/class).
1021+
question_name = bytes([0xC0, 18])
1022+
question_type_class = b"\x00\x01\x00\x01"
1023+
chain_depth = 1500
1024+
chain = bytearray()
1025+
for i in range(chain_depth):
1026+
target = 18 + 2 * (i + 1)
1027+
chain.append(0xC0 | (target >> 8))
1028+
chain.append(target & 0xFF)
1029+
chain.append(0x00)
1030+
packet = header + question_name + question_type_class + bytes(chain)
1031+
parsed = r.DNSIncoming(packet, ("1.2.3.4", 5353))
1032+
assert parsed.valid is False
1033+
assert parsed.questions == []
1034+
1035+
10141036
def test_dns_compression_loop_attack():
10151037
"""Test our wire parser does not loop forever when dns compression is in a loop."""
10161038
packet = (

0 commit comments

Comments
 (0)