Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions Lib/asyncio/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,22 +86,27 @@ def run(self):
global return_code

try:
banner = (
f'asyncio REPL {sys.version} on {sys.platform}\n'
f'Use "await" directly instead of "asyncio.run()".\n'
f'Type "help", "copyright", "credits" or "license" '
f'for more information.\n'
)
if not sys.flags.quiet:
banner = (
f'asyncio REPL {sys.version} on {sys.platform}\n'
f'Use "await" directly instead of "asyncio.run()".\n'
f'Type "help", "copyright", "credits" or "license" '
f'for more information.\n'
)

console.write(banner)
console.write(banner)

if startup_path := os.getenv("PYTHONSTARTUP"):
if not sys.flags.isolated and (startup_path := os.getenv("PYTHONSTARTUP")):
sys.audit("cpython.run_startup", startup_path)

import tokenize
with tokenize.open(startup_path) as f:
startup_code = compile(f.read(), startup_path, "exec")
try:
import tokenize
with tokenize.open(startup_path) as f:
startup_code = compile(f.read(), startup_path, "exec")
exec(startup_code, console.locals)
except SystemExit:
raise
except BaseException:
console.showtraceback()

ps1 = getattr(sys, "ps1", ">>> ")
if CAN_USE_PYREPL:
Expand Down Expand Up @@ -236,4 +241,5 @@ def interrupt(self) -> None:
break

console.write('exiting asyncio REPL...\n')
loop.close()
sys.exit(return_code)
11 changes: 11 additions & 0 deletions Lib/asyncio/base_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,17 @@ async def start_tls(self, transport, protocol, sslcontext, *,
# have a chance to get called before "ssl_protocol.connection_made()".
transport.pause_reading()

# gh-142352: move buffered StreamReader data to SSLProtocol
if server_side:
from .streams import StreamReaderProtocol
if isinstance(protocol, StreamReaderProtocol):
stream_reader = getattr(protocol, '_stream_reader', None)
if stream_reader is not None:
buffer = stream_reader._buffer
if buffer:
ssl_protocol._incoming.write(buffer)
buffer.clear()

transport.set_protocol(ssl_protocol)
conmade_cb = self.call_soon(ssl_protocol.connection_made, transport)
resume_cb = self.call_soon(transport.resume_reading)
Expand Down
4 changes: 2 additions & 2 deletions Lib/asyncio/base_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def _try_finish(self):
# to avoid hanging forever in self._wait as otherwise _exit_waiters
# would never be woken up, we wake them up here.
for waiter in self._exit_waiters:
if not waiter.cancelled():
if not waiter.done():
waiter.set_result(self._returncode)
if all(p is not None and p.disconnected
for p in self._pipes.values()):
Expand All @@ -278,7 +278,7 @@ def _call_connection_lost(self, exc):
finally:
# wake up futures waiting for wait()
for waiter in self._exit_waiters:
if not waiter.cancelled():
if not waiter.done():
waiter.set_result(self._returncode)
self._exit_waiters = None
self._loop = None
Expand Down
4 changes: 2 additions & 2 deletions Lib/asyncio/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def _set_state(future, other):

def _call_check_cancel(destination):
if destination.cancelled():
if source_loop is None or source_loop is dest_loop:
if source_loop is None or source_loop is events._get_running_loop():
source.cancel()
else:
source_loop.call_soon_threadsafe(source.cancel)
Expand All @@ -401,7 +401,7 @@ def _call_set_state(source):
if (destination.cancelled() and
dest_loop is not None and dest_loop.is_closed()):
return
if dest_loop is None or dest_loop is source_loop:
if dest_loop is None or dest_loop is events._get_running_loop():
_set_state(destination, source)
else:
if dest_loop.is_closed():
Expand Down
2 changes: 1 addition & 1 deletion Lib/asyncio/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Queue(mixins._LoopBoundMixin):
is an integer greater than 0, then "await put()" will block when the
queue reaches maxsize, until an item is removed by get().

Unlike the standard library Queue, you can reliably know this Queue's size
Unlike queue.Queue, you can reliably know this Queue's size
with qsize(), since your single-threaded asyncio application won't be
interrupted between calling qsize() and doing an operation on the Queue.
"""
Expand Down
26 changes: 17 additions & 9 deletions Lib/asyncio/windows_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import msvcrt
import os
import subprocess
import tempfile
import warnings


Expand All @@ -24,17 +23,14 @@
PIPE = subprocess.PIPE
STDOUT = subprocess.STDOUT
_mmap_counter = itertools.count()
_MAX_PIPE_ATTEMPTS = 20


# Replacement for os.pipe() using handles instead of fds


def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):
"""Like os.pipe() but with overlapped support and using handles not fds."""
address = tempfile.mktemp(
prefix=r'\\.\pipe\python-pipe-{:d}-{:d}-'.format(
os.getpid(), next(_mmap_counter)))

if duplex:
openmode = _winapi.PIPE_ACCESS_DUPLEX
access = _winapi.GENERIC_READ | _winapi.GENERIC_WRITE
Expand All @@ -56,9 +52,20 @@ def pipe(*, duplex=False, overlapped=(True, True), bufsize=BUFSIZE):

h1 = h2 = None
try:
h1 = _winapi.CreateNamedPipe(
address, openmode, _winapi.PIPE_WAIT,
1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL)
for attempts in itertools.count():
address = r'\\.\pipe\python-pipe-{:d}-{:d}-{}'.format(
os.getpid(), next(_mmap_counter), os.urandom(8).hex())
try:
h1 = _winapi.CreateNamedPipe(
address, openmode, _winapi.PIPE_WAIT,
1, obsize, ibsize, _winapi.NMPWAIT_WAIT_FOREVER, _winapi.NULL)
break
except OSError as e:
if attempts >= _MAX_PIPE_ATTEMPTS:
raise
if e.winerror not in (_winapi.ERROR_PIPE_BUSY,
_winapi.ERROR_ACCESS_DENIED):
raise

h2 = _winapi.CreateFile(
address, access, 0, _winapi.NULL, _winapi.OPEN_EXISTING,
Expand Down Expand Up @@ -104,8 +111,9 @@ def fileno(self):

def close(self, *, CloseHandle=_winapi.CloseHandle):
if self._handle is not None:
CloseHandle(self._handle)
handle = self._handle
self._handle = None
CloseHandle(handle)

def __del__(self, _warn=warnings.warn):
if self._handle is not None:
Expand Down
25 changes: 25 additions & 0 deletions Lib/test/test_asyncio/test_sock_lowlevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,27 @@ def test_recvfrom_into(self):
self.loop.run_until_complete(
self._basetest_datagram_recvfrom_into(server_address))

async def _basetest_datagram_recvfrom_into_wrong_size(self, server_address):
# Call sock_sendto() with a size larger than the buffer
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)

buf = bytearray(5000)
data = b'\x01' * 4096
wrong_size = len(buf) + 1
await self.loop.sock_sendto(sock, data, server_address)
with self.assertRaises(ValueError):
await self.loop.sock_recvfrom_into(
sock, buf, wrong_size)

size, addr = await self.loop.sock_recvfrom_into(sock, buf)
self.assertEqual(buf[:size], data)

def test_recvfrom_into_wrong_size(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_recvfrom_into_wrong_size(server_address))

async def _basetest_datagram_sendto_blocking(self, server_address):
# Sad path, sock.sendto() raises BlockingIOError
# This involves patching sock.sendto() to raise BlockingIOError but
Expand Down Expand Up @@ -642,6 +663,10 @@ async def recvfrom_into(socket):
self._basetest_datagram_send_to_non_listening_address(
recvfrom_into))

@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON; AssertionError: ValueError not raised")
def test_recvfrom_into_wrong_size(self):
return super().test_recvfrom_into_wrong_size()

else:
import selectors

Expand Down
40 changes: 17 additions & 23 deletions Lib/test/test_asyncio/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

MACOS = (sys.platform == 'darwin')
BUF_MULTIPLIER = 1024 if not MACOS else 64
HANDSHAKE_TIMEOUT = support.LONG_TIMEOUT


def tearDownModule():
Expand Down Expand Up @@ -257,15 +258,12 @@ def prog(sock):
await fut

async def start_server():
extras = {}
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

srv = await asyncio.start_server(
handle_client,
'127.0.0.1', 0,
family=socket.AF_INET,
ssl=sslctx,
**extras)
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)

try:
srv_socks = srv.sockets
Expand Down Expand Up @@ -322,14 +320,11 @@ def server(sock):
sock.close()

async def client(addr):
extras = {}
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
**extras)
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)

writer.write(A_DATA)
self.assertEqual(await reader.readexactly(2), b'OK')
Expand All @@ -349,7 +344,8 @@ async def client_sock(addr):
reader, writer = await asyncio.open_connection(
sock=sock,
ssl=client_sslctx,
server_hostname='')
server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)

writer.write(A_DATA)
self.assertEqual(await reader.readexactly(2), b'OK')
Expand Down Expand Up @@ -448,7 +444,7 @@ async def client(addr):
*addr,
ssl=client_sslctx,
server_hostname='',
ssl_handshake_timeout=support.SHORT_TIMEOUT)
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.close()
await self.wait_closed(writer)

Expand Down Expand Up @@ -610,7 +606,7 @@ def client():

extras = {}
if server_ssl:
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)
extras = dict(ssl_handshake_timeout=HANDSHAKE_TIMEOUT)

f = loop.create_task(
loop.connect_accepted_socket(
Expand Down Expand Up @@ -659,7 +655,8 @@ async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)

self.assertEqual(await reader.readline(), b'A\n')
writer.write(b'B')
Expand Down Expand Up @@ -1153,14 +1150,11 @@ def do(func, *args):
await fut

async def start_server():
extras = {}

srv = await self.loop.create_server(
server_protocol_factory,
'127.0.0.1', 0,
family=socket.AF_INET,
ssl=sslctx_1,
**extras)
ssl=sslctx_1)

try:
srv_socks = srv.sockets
Expand Down Expand Up @@ -1210,14 +1204,11 @@ def server(sock):
sock.close()

async def client(addr):
extras = {}
extras = dict(ssl_handshake_timeout=support.SHORT_TIMEOUT)

reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
**extras)
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)

writer.write(A_DATA)
self.assertEqual(await reader.readexactly(2), b'OK')
Expand Down Expand Up @@ -1287,7 +1278,8 @@ async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
sslprotocol = writer.transport._ssl_protocol
writer.write(b'ping')
data = await reader.readexactly(4)
Expand Down Expand Up @@ -1399,7 +1391,8 @@ async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')
Expand Down Expand Up @@ -1530,7 +1523,8 @@ async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
server_hostname='',
ssl_handshake_timeout=HANDSHAKE_TIMEOUT)
writer.write(b'ping')
data = await reader.readexactly(4)
self.assertEqual(data, b'pong')
Expand Down
42 changes: 42 additions & 0 deletions Lib/test/test_asyncio/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,48 @@ async def client(addr):
self.assertEqual(msg1, b"hello world 1!\n")
self.assertEqual(msg2, b"hello world 2!\n")

@unittest.skipIf(ssl is None, 'No ssl module')
def test_start_tls_buffered_data(self):
# gh-142352: test start_tls() with buffered data

async def server_handler(client_reader, client_writer):
# Wait for TLS ClientHello to be buffered before start_tls().
await client_reader._wait_for_data('test_start_tls_buffered_data'),
self.assertTrue(client_reader._buffer)
await client_writer.start_tls(test_utils.simple_server_sslcontext())

line = await client_reader.readline()
self.assertEqual(line, b"ping\n")
client_writer.write(b"pong\n")
await client_writer.drain()
client_writer.close()
await client_writer.wait_closed()

async def client(addr):
reader, writer = await asyncio.open_connection(*addr)
await writer.start_tls(test_utils.simple_client_sslcontext())

writer.write(b"ping\n")
await writer.drain()
line = await reader.readline()
self.assertEqual(line, b"pong\n")
writer.close()
await writer.wait_closed()

async def run_test():
server = await asyncio.start_server(
server_handler, socket_helper.HOSTv4, 0)
server_addr = server.sockets[0].getsockname()

await client(server_addr)
server.close()
await server.wait_closed()

messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
self.loop.run_until_complete(run_test())
self.assertEqual(messages, [])

def test_streamreader_constructor_without_loop(self):
with self.assertRaisesRegex(RuntimeError, 'no current event loop'):
asyncio.StreamReader()
Expand Down
Loading
Loading