Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/zeroconf/_protocol/incoming.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ cdef class DNSIncoming:
link_py_int=object,
linked_labels=cython.list
)
cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers)
cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers, unsigned int depth)

@cython.locals(offset="unsigned int")
cdef void _read_header(self)
Expand Down
14 changes: 10 additions & 4 deletions src/zeroconf/_protocol/incoming.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
MAX_DNS_LABELS = 128
MAX_NAME_LENGTH = 253

DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError)
DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError, RecursionError)


_seen_logs: dict[str, int | tuple] = {}
Expand Down Expand Up @@ -409,7 +409,7 @@ def _read_name(self) -> str:
labels: list[str] = []
seen_pointers: set[int] = set()
original_offset = self.offset
self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers)
self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers, 0)
self._name_cache[original_offset] = labels
name = ".".join(labels) + "."
if len(name) > MAX_NAME_LENGTH:
Expand All @@ -418,8 +418,14 @@ def _read_name(self) -> str:
)
return name

def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers: set[int]) -> int:
def _decode_labels_at_offset(
self, off: _int, labels: list[str], seen_pointers: set[int], depth: _int
) -> int:
# This is a tight loop that is called frequently, small optimizations can make a difference.
if depth > MAX_DNS_LABELS:
raise IncomingDecodeError(
f"DNS compression pointer chain exceeds {MAX_DNS_LABELS} at {off} from {self.source}"
)
view = self.view
while off < self._data_len:
length = view[off]
Expand Down Expand Up @@ -457,7 +463,7 @@ def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers:
if not linked_labels:
linked_labels = []
seen_pointers.add(link_py_int)
self._decode_labels_at_offset(link, linked_labels, seen_pointers)
self._decode_labels_at_offset(link, linked_labels, seen_pointers, depth + 1)
self._name_cache[link_py_int] = linked_labels
labels.extend(linked_labels)
if len(labels) > MAX_DNS_LABELS:
Expand Down
22 changes: 22 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,28 @@ def test_label_compression_attack():
assert len(parsed.answers()) == 1


def test_dns_compression_pointer_chain_depth_attack() -> None:
"""Test our wire parser rejects deeply chained compression pointers without recursing."""
# Build a packet with one question whose name is a 1500-deep chain of forward
# compression pointers, ending in a root label. Each pointer is 2 bytes,
# so chain length easily exceeds CPython's default recursion limit.
header = b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00"
# Question at offset 12: pointer to offset 18 (past the question's type/class).
question_name = bytes([0xC0, 18])
question_type_class = b"\x00\x01\x00\x01"
chain_depth = 1500
chain = bytearray()
for i in range(chain_depth):
target = 18 + 2 * (i + 1)
chain.append(0xC0 | (target >> 8))
chain.append(target & 0xFF)
chain.append(0x00)
packet = header + question_name + question_type_class + bytes(chain)
parsed = r.DNSIncoming(packet, ("1.2.3.4", 5353))
assert parsed.valid is False
assert parsed.questions == []


def test_dns_compression_loop_attack():
"""Test our wire parser does not loop forever when dns compression is in a loop."""
packet = (
Expand Down
Loading