Skip to content

Commit f9e2359

Browse files
authored
fix: bound DNS compression-pointer chain depth in DNSIncoming (#1719)
1 parent 64d143d commit f9e2359

3 files changed

Lines changed: 33 additions & 5 deletions

File tree

src/zeroconf/_protocol/incoming.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ cdef class DNSIncoming:
8383
link_py_int=object,
8484
linked_labels=cython.list
8585
)
86-
cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)
86+
cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers, unsigned int depth)
8787

8888
@cython.locals(offset="unsigned int")
8989
cdef void _read_header(self)

src/zeroconf/_protocol/incoming.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
MAX_DNS_LABELS = 128
6161
MAX_NAME_LENGTH = 253
6262

63-
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError)
63+
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError, RecursionError)
6464

6565

6666
_seen_logs: dict[str, int | tuple] = {}
@@ -409,7 +409,7 @@ def _read_name(self) -> str:
409409
labels: list[str] = []
410410
seen_pointers: set[int] = set()
411411
original_offset = self.offset
412-
self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers)
412+
self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers, 0)
413413
self._name_cache[original_offset] = labels
414414
name = ".".join(labels) + "."
415415
if len(name) > MAX_NAME_LENGTH:
@@ -418,8 +418,14 @@ def _read_name(self) -> str:
418418
)
419419
return name
420420

421-
def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers: set[int]) -> int:
421+
def _decode_labels_at_offset(
422+
self, off: _int, labels: list[str], seen_pointers: set[int], depth: _int
423+
) -> int:
422424
# This is a tight loop that is called frequently, small optimizations can make a difference.
425+
if depth > MAX_DNS_LABELS:
426+
raise IncomingDecodeError(
427+
f"DNS compression pointer chain exceeds {MAX_DNS_LABELS} at {off} from {self.source}"
428+
)
423429
view = self.view
424430
while off < self._data_len:
425431
length = view[off]
@@ -457,7 +463,7 @@ def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers:
457463
if not linked_labels:
458464
linked_labels = []
459465
seen_pointers.add(link_py_int)
460-
self._decode_labels_at_offset(link, linked_labels, seen_pointers)
466+
self._decode_labels_at_offset(link, linked_labels, seen_pointers, depth + 1)
461467
self._name_cache[link_py_int] = linked_labels
462468
labels.extend(linked_labels)
463469
if len(labels) > MAX_DNS_LABELS:

tests/test_protocol.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,28 @@ def test_label_compression_attack():
10111011
assert len(parsed.answers()) == 1
10121012

10131013

1014+
def test_dns_compression_pointer_chain_depth_attack() -> None:
1015+
"""Test our wire parser rejects deeply chained compression pointers without recursing."""
1016+
# Build a packet with one question whose name is a 1500-deep chain of forward
1017+
# compression pointers, ending in a root label. Each pointer is 2 bytes,
1018+
# so chain length easily exceeds CPython's default recursion limit.
1019+
header = b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00"
1020+
# Question at offset 12: pointer to offset 18 (past the question's type/class).
1021+
question_name = bytes([0xC0, 18])
1022+
question_type_class = b"\x00\x01\x00\x01"
1023+
chain_depth = 1500
1024+
chain = bytearray()
1025+
for i in range(chain_depth):
1026+
target = 18 + 2 * (i + 1)
1027+
chain.append(0xC0 | (target >> 8))
1028+
chain.append(target & 0xFF)
1029+
chain.append(0x00)
1030+
packet = header + question_name + question_type_class + bytes(chain)
1031+
parsed = r.DNSIncoming(packet, ("1.2.3.4", 5353))
1032+
assert parsed.valid is False
1033+
assert parsed.questions == []
1034+
1035+
10141036
def test_dns_compression_loop_attack():
10151037
"""Test our wire parser does not loop forever when dns compression is in a loop."""
10161038
packet = (

0 commit comments

Comments
 (0)