Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _process_outgoing_packet(out):
query.add_question(r.DNSQuestion(info.name, const._TYPE_SRV, const._CLASS_IN))
query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN))
query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN))
multicast_out = zc.query_handler.response(
multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT
)[1]
_process_outgoing_packet(multicast_out)
Expand Down Expand Up @@ -134,7 +134,7 @@ def _process_outgoing_packet(out):
query.add_question(r.DNSQuestion(info.name, const._TYPE_TXT, const._CLASS_IN))
query.add_question(r.DNSQuestion(info.server, const._TYPE_A, const._CLASS_IN))
_process_outgoing_packet(
zc.query_handler.response(
zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT
)[1]
)
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_ptr_optimization():
# Verify we won't respond for 1s with the same multicast
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -244,7 +244,7 @@ def test_ptr_optimization():
# Verify we will now respond
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], None, const._MDNS_PORT
)
assert multicast_out.id == query.id
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_any_query_for_ptr():
question = r.DNSQuestion(type_, const._TYPE_ANY, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
_, multicast_out = zc.query_handler.response(
_, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert multicast_out.answers[0][0].name == type_
Expand All @@ -313,7 +313,7 @@ def test_aaaa_query():
question = r.DNSQuestion(server_name, const._TYPE_AAAA, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
_, multicast_out = zc.query_handler.response(
_, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert multicast_out.answers[0][0].address == ipv6_address
Expand Down Expand Up @@ -342,7 +342,7 @@ def test_unicast_response():
# query
query = r.DNSOutgoing(const._FLAGS_QR_QUERY)
query.add_question(r.DNSQuestion(info.type, const._TYPE_PTR, const._CLASS_IN))
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", 1234
)
for out in (unicast_out, multicast_out):
Expand Down Expand Up @@ -419,7 +419,7 @@ def _validate_complete_response(query, out):
assert question.unicast is True
query.add_question(question)

unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
)
assert multicast_out is None
Expand All @@ -432,7 +432,7 @@ def _validate_complete_response(query, out):
question.unicast = True # Set the QU bit
assert question.unicast is True
query.add_question(question)
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -445,7 +445,7 @@ def _validate_complete_response(query, out):
assert question.unicast is True
query.add_question(question)
query.add_authorative_answer(info2.dns_pointer())
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
)
_validate_complete_response(query, unicast_out)
Expand All @@ -458,7 +458,7 @@ def _validate_complete_response(query, out):
question.unicast = True # Set the QU bit
assert question.unicast is True
query.add_question(question)
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
)
assert multicast_out is None
Expand Down Expand Up @@ -487,7 +487,7 @@ def test_known_answer_supression():
question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -498,7 +498,7 @@ def test_known_answer_supression():
generated.add_question(question)
generated.add_answer_at_time(info.dns_pointer(), now)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -510,7 +510,7 @@ def test_known_answer_supression():
question = r.DNSQuestion(server_name, const._TYPE_A, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -522,7 +522,7 @@ def test_known_answer_supression():
for dns_address in info.dns_addresses():
generated.add_answer_at_time(dns_address, now)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -533,7 +533,7 @@ def test_known_answer_supression():
question = r.DNSQuestion(registration_name, const._TYPE_SRV, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -544,7 +544,7 @@ def test_known_answer_supression():
generated.add_question(question)
generated.add_answer_at_time(info.dns_service(), now)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -556,7 +556,7 @@ def test_known_answer_supression():
question = r.DNSQuestion(registration_name, const._TYPE_TXT, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand All @@ -567,7 +567,7 @@ def test_known_answer_supression():
generated.add_question(question)
generated.add_answer_at_time(info.dns_text(), now)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand Down Expand Up @@ -620,7 +620,7 @@ def test_multi_packet_known_answer_supression():
generated.add_answer_at_time(info3.dns_pointer(), now)
packets = generated.packets()
assert len(packets) > 1
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand Down Expand Up @@ -661,7 +661,7 @@ def test_known_answer_supression_service_type_enumeration_query():
question = r.DNSQuestion(const._SERVICE_TYPE_ENUMERATION_NAME, const._TYPE_PTR, const._CLASS_IN)
generated.add_question(question)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand Down Expand Up @@ -691,7 +691,7 @@ def test_known_answer_supression_service_type_enumeration_query():
now,
)
packets = generated.packets()
unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in packets], "1.2.3.4", const._MDNS_PORT
)
assert unicast_out is None
Expand Down Expand Up @@ -747,7 +747,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
assert question.unicast is True
query.add_question(question)

unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
)
assert multicast_out is None
Expand All @@ -767,7 +767,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
assert question.unicast is True
query.add_question(question)

unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
)
assert multicast_out is None
Expand All @@ -787,7 +787,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
assert question.unicast is True
query.add_question(question)

unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
)
assert multicast_out.answers[0][0] == ptr_record
Expand All @@ -813,7 +813,7 @@ def test_qu_response_only_sends_additionals_if_sends_answer():
query.add_question(question)
zc.cache.add(info2.dns_pointer()) # Add 100% TTL for info2 to the cache

unicast_out, multicast_out = zc.query_handler.response(
unicast_out, multicast_out = zc.query_handler.async_response(
[r.DNSIncoming(packet) for packet in query.packets()], "1.2.3.4", const._MDNS_PORT
)
assert multicast_out.answers[0][0] == info.dns_pointer()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,7 +1219,7 @@ def mock_incoming_msg(records) -> r.DNSIncoming:
zc,
mock_incoming_msg([info.dns_pointer(), info.dns_service(), info.dns_text(), *info.dns_addresses()]),
)
zc.wait(100)
time.sleep(0.1)

assert callbacks == [('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.')]
assert zc.get_service_info(type_, registration_name).port == 80
Expand All @@ -1229,7 +1229,7 @@ def mock_incoming_msg(records) -> r.DNSIncoming:
zc,
mock_incoming_msg([info.dns_service()]),
)
zc.wait(100)
time.sleep(0.1)

assert callbacks == [
('_hap._tcp.local.', ServiceStateChange.Added, 'xxxyyy._hap._tcp.local.'),
Expand Down
8 changes: 4 additions & 4 deletions zeroconf/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,8 @@ async def _async_cache_cleanup(self) -> None:
"""Periodic cache cleanup."""
while not self.zc.done:
now = current_time_millis()
self.zc.record_manager.updates(now, list(self.zc.cache.expire(now)))
self.zc.record_manager.updates_complete()
self.zc.record_manager.async_updates(now, list(self.zc.cache.expire(now)))
self.zc.record_manager.async_updates_complete()
await asyncio.sleep(millis_to_seconds(_CACHE_CLEANUP_INTERVAL))

async def _async_close(self) -> None:
Expand Down Expand Up @@ -565,7 +565,7 @@ def remove_listener(self, listener: RecordUpdateListener) -> None:
def handle_response(self, msg: DNSIncoming) -> None:
"""Deal with incoming response packets. All answers
are held in the cache, and listeners are notified."""
self.record_manager.updates_from_response(msg)
self.record_manager.async_updates_from_response(msg)

def handle_query(self, msg: DNSIncoming, addr: str, port: int) -> None:
"""Deal with incoming query packets. Provides a response if
Expand Down Expand Up @@ -594,7 +594,7 @@ def _respond_query(self, msg: Optional[DNSIncoming], addr: str, port: int) -> No
if msg:
packets.append(msg)

unicast_out, multicast_out = self.query_handler.response(packets, addr, port)
unicast_out, multicast_out = self.query_handler.async_response(packets, addr, port)
if unicast_out:
self.async_send(unicast_out, addr, port)
if multicast_out:
Expand Down
36 changes: 24 additions & 12 deletions zeroconf/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,14 @@ def _answer_question(
if not known_answers.suppresses(dns_text):
answer_set[dns_text] = set()

def response( # pylint: disable=unused-argument
def async_response( # pylint: disable=unused-argument
self, msgs: List[DNSIncoming], addr: Optional[str], port: int
) -> Tuple[Optional[DNSOutgoing], Optional[DNSOutgoing]]:
"""Deal with incoming query packets. Provides a response if possible."""
"""Deal with incoming query packets. Provides a response if possible.

This function must be run in the event loop as it is not
threadsafe.
"""
ucast_source = port != _MDNS_PORT
known_answers = DNSRRSet(itertools.chain(*[msg.answers for msg in msgs]))
query_res = _QueryResponse(self.cache, msgs[0], ucast_source)
Expand Down Expand Up @@ -272,28 +276,36 @@ def __init__(self, zeroconf: 'Zeroconf') -> None:
self.cache = zeroconf.cache
self.listeners: List[RecordUpdateListener] = []

def updates(self, now: float, rec: List[DNSRecord]) -> None:
def async_updates(self, now: float, rec: List[DNSRecord]) -> None:
"""Used to notify listeners of new information that has updated
a record.

This method must be called before the cache is updated.

This method will be run in the event loop.
"""
for listener in self.listeners:
listener.update_records(self.zc, now, rec)
listener.async_update_records(self.zc, now, rec)

def updates_complete(self) -> None:
def async_updates_complete(self) -> None:
"""Used to notify listeners of new information that has updated
a record.

This method must be called after the cache is updated.

This method will be run in the event loop.
"""
for listener in self.listeners:
listener.update_records_complete()
listener.async_update_records_complete()
self.zc.notify_all()

def updates_from_response(self, msg: DNSIncoming) -> None:
def async_updates_from_response(self, msg: DNSIncoming) -> None:
"""Deal with incoming response packets. All answers
are held in the cache, and listeners are notified."""
are held in the cache, and listeners are notified.

This function must be run in the event loop as it is not
threadsafe.
"""
updates: List[DNSRecord] = []
address_adds: List[DNSAddress] = []
other_adds: List[DNSRecord] = []
Expand Down Expand Up @@ -334,7 +346,7 @@ def updates_from_response(self, msg: DNSIncoming) -> None:
if not updates and not address_adds and not other_adds and not removes:
return

self.updates(now, updates)
self.async_updates(now, updates)
# The cache adds must be processed AFTER we trigger
# the updates since we compare existing data
# with the new data and updating the cache
Expand All @@ -355,7 +367,7 @@ def updates_from_response(self, msg: DNSIncoming) -> None:
# ServiceInfo could generate an un-needed query
# because the data was not yet populated.
self.cache.remove_records(removes)
self.updates_complete()
self.async_updates_complete()

def add_listener(
self, listener: RecordUpdateListener, question: Optional[Union[DNSQuestion, List[DNSQuestion]]]
Expand All @@ -374,8 +386,8 @@ def add_listener(
if single_question.answered_by(record) and not record.is_expired(now):
records.append(record)
if records:
listener.update_records(self.zc, now, records)
listener.update_records_complete()
listener.async_update_records(self.zc, now, records)
listener.async_update_records_complete()

self.zc.notify_all()

Expand Down
Loading