Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 183 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

""" Unit tests for zeroconf._core """

import asyncio
import itertools
import logging
import os
Expand All @@ -18,7 +19,7 @@
import zeroconf as r
from zeroconf import _core, const, ServiceBrowser, Zeroconf

from . import has_working_ipv6, _inject_response
from . import has_working_ipv6, _clear_cache, _inject_response

log = logging.getLogger('zeroconf')
original_logging_level = logging.NOTSET
Expand Down Expand Up @@ -423,3 +424,184 @@ def test_sending_unicast():
assert zc.cache.get(entry) is not None

zc.close()


def test_tc_bit_defers():
zc = Zeroconf(interfaces=['127.0.0.1'])
type_ = "_tcbitdefer._tcp.local."
name = "knownname"
name2 = "knownname2"
name3 = "knownname3"

registration_name = "%s.%s" % (name, type_)
registration2_name = "%s.%s" % (name2, type_)
registration3_name = "%s.%s" % (name3, type_)

desc = {'path': '/~paulsm/'}
server_name = "ash-2.local."
server_name2 = "ash-3.local."
server_name3 = "ash-4.local."

info = r.ServiceInfo(
type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
)
info2 = r.ServiceInfo(
type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
)
info3 = r.ServiceInfo(
type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")]
)
zc.registry.add(info)
zc.registry.add(info2)
zc.registry.add(info3)

def threadsafe_query(*args):
async def make_query():
zc.handle_query(*args)

asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result()

now = r.current_time_millis()
_clear_cache(zc)

generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
generated.add_question(question)
for _ in range(300):
# Add so many answers we end up with another packet
generated.add_answer_at_time(info.dns_pointer(), now)
generated.add_answer_at_time(info2.dns_pointer(), now)
generated.add_answer_at_time(info3.dns_pointer(), now)
packets = generated.packets()
assert len(packets) == 4
expected_deferred = []
source_ip = '203.0.113.13'

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert zc._deferred[source_ip] == expected_deferred
assert source_ip in zc._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert zc._deferred[source_ip] == expected_deferred
assert source_ip in zc._timers
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert zc._deferred[source_ip] == expected_deferred
assert source_ip in zc._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert zc._deferred[source_ip] == expected_deferred
assert source_ip in zc._timers

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert source_ip not in zc._deferred
assert source_ip not in zc._timers

# unregister
zc.unregister_service(info)
zc.close()


def test_tc_bit_defers_last_response_missing():
zc = Zeroconf(interfaces=['127.0.0.1'])
type_ = "_knowndefer._tcp.local."
name = "knownname"
name2 = "knownname2"
name3 = "knownname3"

registration_name = "%s.%s" % (name, type_)
registration2_name = "%s.%s" % (name2, type_)
registration3_name = "%s.%s" % (name3, type_)

desc = {'path': '/~paulsm/'}
server_name = "ash-2.local."
server_name2 = "ash-3.local."
server_name3 = "ash-4.local."

info = r.ServiceInfo(
type_, registration_name, 80, 0, 0, desc, server_name, addresses=[socket.inet_aton("10.0.1.2")]
)
info2 = r.ServiceInfo(
type_, registration2_name, 80, 0, 0, desc, server_name2, addresses=[socket.inet_aton("10.0.1.2")]
)
info3 = r.ServiceInfo(
type_, registration3_name, 80, 0, 0, desc, server_name3, addresses=[socket.inet_aton("10.0.1.2")]
)
zc.registry.add(info)
zc.registry.add(info2)
zc.registry.add(info3)

def threadsafe_query(*args):
async def make_query():
zc.handle_query(*args)

asyncio.run_coroutine_threadsafe(make_query(), zc.loop).result()

now = r.current_time_millis()
_clear_cache(zc)
source_ip = '203.0.113.12'

generated = r.DNSOutgoing(const._FLAGS_QR_QUERY)
question = r.DNSQuestion(type_, const._TYPE_PTR, const._CLASS_IN)
generated.add_question(question)
for _ in range(300):
# Add so many answers we end up with another packet
generated.add_answer_at_time(info.dns_pointer(), now)
generated.add_answer_at_time(info2.dns_pointer(), now)
generated.add_answer_at_time(info3.dns_pointer(), now)
packets = generated.packets()
assert len(packets) == 4
expected_deferred = []

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert zc._deferred[source_ip] == expected_deferred
timer1 = zc._timers[source_ip]

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert zc._deferred[source_ip] == expected_deferred
timer2 = zc._timers[source_ip]
if sys.version_info >= (3, 7):
assert timer1.cancelled()
assert timer2 != timer1

# Send the same packet again to similar multi interfaces
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert zc._deferred[source_ip] == expected_deferred
assert source_ip in zc._timers
timer3 = zc._timers[source_ip]
if sys.version_info >= (3, 7):
assert not timer3.cancelled()
assert timer3 == timer2

next_packet = r.DNSIncoming(packets.pop(0))
expected_deferred.append(next_packet)
threadsafe_query(next_packet, source_ip, const._MDNS_PORT)
assert zc._deferred[source_ip] == expected_deferred
assert source_ip in zc._timers
timer4 = zc._timers[source_ip]
if sys.version_info >= (3, 7):
assert timer3.cancelled()
assert timer4 != timer3

for _ in range(7):
time.sleep(0.1)
if source_ip not in zc._timers:
break

assert source_ip not in zc._deferred
assert source_ip not in zc._timers

# unregister
zc.registry.remove(info)
zc.close()
Loading