--- a/src/zeroconf/_protocol/incoming.pxd +++ b/src/zeroconf/_protocol/incoming.pxd @@ -83,7 +83,7 @@ link_py_int=object, linked_labels=cython.list ) - cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers) + cdef unsigned int _decode_labels_at_offset(self, unsigned int off, cython.list labels, cython.set seen_pointers, unsigned int depth) @cython.locals(offset="unsigned int") cdef void _read_header(self) --- a/src/zeroconf/_protocol/incoming.py +++ b/src/zeroconf/_protocol/incoming.py @@ -60,7 +60,7 @@ MAX_DNS_LABELS = 128 MAX_NAME_LENGTH = 253 -DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError) +DECODE_EXCEPTIONS = (IndexError, struct.error, IncomingDecodeError, RecursionError) _seen_logs: dict[str, int | tuple] = {} @@ -409,7 +409,7 @@ labels: list[str] = [] seen_pointers: set[int] = set() original_offset = self.offset - self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers) + self.offset = self._decode_labels_at_offset(original_offset, labels, seen_pointers, 0) self._name_cache[original_offset] = labels name = ".".join(labels) + "." if len(name) > MAX_NAME_LENGTH: @@ -418,8 +418,14 @@ ) return name - def _decode_labels_at_offset(self, off: _int, labels: list[str], seen_pointers: set[int]) -> int: + def _decode_labels_at_offset( + self, off: _int, labels: list[str], seen_pointers: set[int], depth: _int + ) -> int: # This is a tight loop that is called frequently, small optimizations can make a difference. + if depth > MAX_DNS_LABELS: + raise IncomingDecodeError( + f"DNS compression pointer chain exceeds {MAX_DNS_LABELS} at {off} from {self.source}" + ) view = self.view while off < self._data_len: length = view[off] @@ -457,7 +463,7 @@ if not linked_labels: linked_labels = [] seen_pointers.add(link_py_int) - self._decode_labels_at_offset(link, linked_labels, seen_pointers) + self._decode_labels_at_offset(link, linked_labels, seen_pointers, depth + 1) self._name_cache[link_py_int] = linked_labels labels.extend(linked_labels) if len(labels) > MAX_DNS_LABELS: --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1011,6 +1011,28 @@ assert len(parsed.answers()) == 1 +def test_dns_compression_pointer_chain_depth_attack() -> None: + """Test our wire parser rejects deeply chained compression pointers without recursing.""" + # Build a packet with one question whose name is a 1500-deep chain of forward + # compression pointers, ending in a root label. Each pointer is 2 bytes, + # so chain length easily exceeds CPython's default recursion limit. + header = b"\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00" + # Question at offset 12: pointer to offset 18 (past the question's type/class). + question_name = bytes([0xC0, 18]) + question_type_class = b"\x00\x01\x00\x01" + chain_depth = 1500 + chain = bytearray() + for i in range(chain_depth): + target = 18 + 2 * (i + 1) + chain.append(0xC0 | (target >> 8)) + chain.append(target & 0xFF) + chain.append(0x00) + packet = header + question_name + question_type_class + bytes(chain) + parsed = r.DNSIncoming(packet, ("1.2.3.4", 5353)) + assert parsed.valid is False + assert parsed.questions == [] + + def test_dns_compression_loop_attack(): """Test our wire parser does not loop forever when dns compression is in a loop.""" packet = (