Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/zeroconf/_handlers/answers.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@

import cython

from .._dns cimport DNSRecord
from .._protocol.outgoing cimport DNSOutgoing


Expand All @@ -10,7 +11,8 @@ cdef object NAME_GETTER
cpdef construct_outgoing_multicast_answers(cython.dict answers)

cpdef construct_outgoing_unicast_answers(
cython.dict answers, object ucast_source, cython.list questions, object id_
cython.dict answers, bint ucast_source, cython.list questions, object id_
)

@cython.locals(answer=DNSRecord, additionals=cython.set, additional=DNSRecord)
cdef _add_answers_additionals(DNSOutgoing out, cython.dict answers)
3 changes: 2 additions & 1 deletion src/zeroconf/_handlers/answers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def _add_answers_additionals(out: DNSOutgoing, answers: _AnswerWithAdditionalsTy
# overall size of the outgoing response via name compression
for answer in sorted(answers, key=NAME_GETTER):
out.add_answer_at_time(answer, 0)
for additional in answers[answer]:
additionals = answers[answer]
for additional in additionals:
if additional not in sending:
out.add_additional_answer(additional)
sending.add(additional)
25 changes: 20 additions & 5 deletions src/zeroconf/_protocol/outgoing.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,25 @@ cdef object PACK_BYTE
cdef object PACK_SHORT
cdef object PACK_LONG

cdef object STATE_INIT
cdef object STATE_FINISHED

cdef object LOGGING_IS_ENABLED_FOR
cdef object LOGGING_DEBUG

cdef cython.tuple BYTE_TABLE

cdef class DNSOutgoing:

cdef public unsigned int flags
cdef public object finished
cdef public bint finished
cdef public object id
cdef public bint multicast
cdef public cython.list packets_data
cdef public cython.dict names
cdef public cython.list data
cdef public unsigned int size
cdef public object allow_long
cdef public bint allow_long
cdef public object state
cdef public cython.list questions
cdef public cython.list answers
Expand All @@ -48,18 +56,21 @@ cdef class DNSOutgoing:

cdef _write_int(self, object value)

cdef _write_question(self, DNSQuestion question)
cdef cython.bint _write_question(self, DNSQuestion question)

@cython.locals(
d=cython.bytes,
data_view=cython.list,
length=cython.uint
)
cdef _write_record(self, DNSRecord record, object now)
cdef cython.bint _write_record(self, DNSRecord record, object now)

cdef _write_record_class(self, DNSEntry record)

cdef _check_data_limit_or_rollback(self, object start_data_length, object start_size)
@cython.locals(
start_size_int=object
)
cdef cython.bint _check_data_limit_or_rollback(self, cython.uint start_data_length, cython.uint start_size)

cdef _write_questions_from_offset(self, object questions_offset)

Expand All @@ -74,6 +85,9 @@ cdef class DNSOutgoing:
@cython.locals(
labels=cython.list,
label=cython.str,
index=cython.uint,
start_size=cython.uint,
name_length=cython.uint,
)
cpdef write_name(self, cython.str name)

Expand Down Expand Up @@ -103,6 +117,7 @@ cdef class DNSOutgoing:

cpdef add_answer(self, DNSIncoming inp, DNSRecord record)

@cython.locals(now_float=cython.float)
cpdef add_answer_at_time(self, DNSRecord record, object now)

cpdef add_authorative_answer(self, DNSPointer record)
Expand Down
40 changes: 27 additions & 13 deletions src/zeroconf/_protocol/outgoing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,21 @@
PACK_SHORT = Struct('>H').pack
PACK_LONG = Struct('>L').pack

BYTE_TABLE = tuple(PACK_BYTE(i) for i in range(256))


class State(enum.Enum):
init = 0
finished = 1


STATE_INIT = State.init
STATE_FINISHED = State.finished

LOGGING_IS_ENABLED_FOR = log.isEnabledFor
LOGGING_DEBUG = logging.DEBUG


class DNSOutgoing:

"""Object representation of an outgoing packet"""
Expand Down Expand Up @@ -93,7 +102,7 @@ def __init__(self, flags: int, multicast: bool = True, id_: int = 0) -> None:
self.size: int = _DNS_PACKET_HEADER_LEN
self.allow_long: bool = True

self.state = State.init
self.state = STATE_INIT

self.questions: List[DNSQuestion] = []
self.answers: List[Tuple[DNSRecord, float]] = []
Expand Down Expand Up @@ -137,7 +146,8 @@ def add_answer(self, inp: DNSIncoming, record: DNSRecord) -> None:

def add_answer_at_time(self, record: Optional[DNSRecord], now: Union[float, int]) -> None:
"""Adds an answer if it does not expire by a certain time"""
if record is not None and (now == 0 or not record.is_expired(now)):
now_float = now
if record is not None and (now_float == 0 or not record.is_expired(now_float)):
self.answers.append((record, now))

def add_authorative_answer(self, record: DNSPointer) -> None:
Expand Down Expand Up @@ -207,7 +217,7 @@ def add_question_or_all_cache(

def _write_byte(self, value: int_) -> None:
"""Writes a single byte to the packet"""
self.data.append(PACK_BYTE(value))
self.data.append(BYTE_TABLE[value])
self.size += 1

def _insert_short_at_start(self, value: int_) -> None:
Expand Down Expand Up @@ -267,7 +277,7 @@ def write_name(self, name: str_) -> None:
"""

# split name into each label
name_length = None
name_length = 0
if name.endswith('.'):
name = name[: len(name) - 1]
labels = name.split('.')
Expand All @@ -276,14 +286,14 @@ def write_name(self, name: str_) -> None:
start_size = self.size
for count in range(len(labels)):
label = name if count == 0 else '.'.join(labels[count:])
index = self.names.get(label)
index = self.names.get(label, 0)
if index:
# If part of the name already exists in the packet,
# create a pointer to it
self._write_byte((index >> 8) | 0xC0)
self._write_byte(index & 0xFF)
return
if name_length is None:
if name_length == 0:
name_length = len(name.encode('utf-8'))
self.names[label] = start_size + name_length - len(label.encode('utf-8'))
self._write_utf(labels[count])
Expand All @@ -293,7 +303,8 @@ def write_name(self, name: str_) -> None:

def _write_question(self, question: DNSQuestion_) -> bool:
"""Writes a question to the packet"""
start_data_length, start_size = len(self.data), self.size
start_data_length = len(self.data)
start_size = self.size
self.write_name(question.name)
self.write_short(question.type)
self._write_record_class(question)
Expand All @@ -314,7 +325,8 @@ def _write_record(self, record: DNSRecord_, now: float_) -> bool:
"""Writes a record (answer, authoritative answer, additional) to
the packet. Returns True on success, or False if we did not
because the packet because the record does not fit."""
start_data_length, start_size = len(self.data), self.size
start_data_length = len(self.data)
start_size = self.size
self.write_name(record.name)
self.write_short(record.type)
self._write_record_class(record)
Expand All @@ -339,11 +351,13 @@ def _check_data_limit_or_rollback(self, start_data_length: int_, start_size: int
if self.size <= len_limit:
return True

log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
if LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG): # pragma: no branch
log.debug("Reached data limit (size=%d) > (limit=%d) - rolling back", self.size, len_limit)
del self.data[start_data_length:]
self.size = start_size

rollback_names = [name for name, idx in self.names.items() if idx >= start_size]
start_size_int = start_size
rollback_names = [name for name, idx in self.names.items() if idx >= start_size_int]
for name in rollback_names:
del self.names[name]
return False
Expand Down Expand Up @@ -395,7 +409,7 @@ def packets(self) -> List[bytes]:
return self._packets()

def _packets(self) -> List[bytes]:
if self.state == State.finished:
if self.state == STATE_FINISHED:
return self.packets_data

questions_offset = 0
Expand All @@ -404,7 +418,7 @@ def _packets(self) -> List[bytes]:
additional_offset = 0
# we have to at least write out the question
first_time = True
debug_enable = log.isEnabledFor(logging.DEBUG)
debug_enable = LOGGING_IS_ENABLED_FOR(LOGGING_DEBUG)

while first_time or self._has_more_to_add(
questions_offset, answer_offset, authority_offset, additional_offset
Expand Down Expand Up @@ -476,5 +490,5 @@ def _packets(self) -> List[bytes]:
):
log.warning("packets() made no progress adding records; returning")
break
self.state = State.finished
self.state = STATE_FINISHED
return self.packets_data