Skip to content

Commit efce325

Browse files
authored
Fix asyncio related compiler/library issues (RustPython#6837)
* Fix socket bytes support * fix unwind_fblock * fix posix.sendfile * fix ssl_write * Fix SSL ZeroReturn * fix context * fix generator * Enable unittest test_async_case again
1 parent 9b56aa5 commit efce325

13 files changed

Lines changed: 175 additions & 67 deletions

File tree

Lib/test/test_context.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,6 @@ def fun():
217217

218218
ctx.run(fun)
219219

220-
# TODO: RUSTPYTHON
221-
@unittest.expectedFailure
222220
@isolated_context
223221
def test_context_getset_1(self):
224222
c = contextvars.ContextVar('c')
@@ -317,8 +315,6 @@ def test_context_getset_4(self):
317315
with self.assertRaisesRegex(ValueError, 'different Context'):
318316
c.reset(tok)
319317

320-
# TODO: RUSTPYTHON
321-
@unittest.expectedFailure
322318
@isolated_context
323319
def test_context_getset_5(self):
324320
c = contextvars.ContextVar('c', default=42)
@@ -332,8 +328,6 @@ def fun():
332328
contextvars.copy_context().run(fun)
333329
self.assertEqual(c.get(), [])
334330

335-
# TODO: RUSTPYTHON
336-
@unittest.expectedFailure
337331
def test_context_copy_1(self):
338332
ctx1 = contextvars.Context()
339333
c = contextvars.ContextVar('c', default=42)

Lib/test/test_inspect/test_inspect.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2797,7 +2797,6 @@ def test_easy_debugging(self):
27972797
self.assertIn(name, repr(state))
27982798
self.assertIn(name, str(state))
27992799

2800-
@unittest.expectedFailure # TODO: RUSTPYTHON
28012800
def test_getgeneratorlocals(self):
28022801
def each(lst, a=None):
28032802
b=(1, 2, 3)
@@ -2985,7 +2984,6 @@ def test_easy_debugging(self):
29852984
self.assertIn(name, repr(state))
29862985
self.assertIn(name, str(state))
29872986

2988-
@unittest.expectedFailure # TODO: RUSTPYTHON
29892987
async def test_getasyncgenlocals(self):
29902988
async def each(lst, a=None):
29912989
b=(1, 2, 3)

Lib/test/test_ssl.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3525,7 +3525,6 @@ def test_starttls(self):
35253525
else:
35263526
s.close()
35273527

3528-
@unittest.expectedFailureIfWindows("TODO: RUSTPYTHON")
35293528
def test_socketserver(self):
35303529
"""Using socketserver to create and manage SSL connections."""
35313530
server = make_https_server(self, certfile=SIGNED_CERTFILE)

Lib/test/test_unittest/test_async_case.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,7 @@ class MyException(Exception):
1313

1414

1515
def tearDownModule():
16-
# XXX: RUSTPYTHON; asyncio.events._set_event_loop_policy is not implemented
17-
# asyncio.events._set_event_loop_policy(None)
18-
pass
16+
asyncio.events._set_event_loop_policy(None)
1917

2018

2119
class TestCM:
@@ -52,7 +50,6 @@ def setUp(self):
5250
# starting a new event loop
5351
self.addCleanup(support.gc_collect)
5452

55-
@unittest.expectedFailure # TODO: RUSTPYTHON
5653
def test_full_cycle(self):
5754
expected = ['setUp',
5855
'asyncSetUp',

crates/codegen/src/compile.rs

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,27 +1528,30 @@ impl Compiler {
15281528
// Otherwise, if an exception occurs during the finally body, the stack
15291529
// will be unwound to the wrong depth and the return value will be lost.
15301530
if preserve_tos {
1531-
// Get the handler info from the saved fblock (or current handler)
1532-
// and create a new handler with stack_depth + 1
1533-
let (handler, stack_depth, preserve_lasti) =
1534-
if let Some(handler) = saved_fblock.fb_handler {
1535-
(
1536-
Some(handler),
1537-
saved_fblock.fb_stack_depth + 1, // +1 for return value
1538-
saved_fblock.fb_preserve_lasti,
1539-
)
1540-
} else {
1541-
// No handler in saved_fblock, check current handler
1542-
if let Some(current_handler) = self.current_except_handler() {
1543-
(
1544-
Some(current_handler.handler_block),
1545-
current_handler.stack_depth + 1, // +1 for return value
1546-
current_handler.preserve_lasti,
1547-
)
1548-
} else {
1549-
(None, 1, false) // No handler, but still track the return value
1531+
// Find the outer handler for exceptions during finally body execution.
1532+
// CRITICAL: Only search fblocks with index < fblock_idx (= outer fblocks).
1533+
// Inner FinallyTry blocks may have been restored after their unwind
1534+
// processing, and we must NOT use their handlers - that would cause
1535+
// the inner finally body to execute again on exception.
1536+
let (handler, stack_depth, preserve_lasti) = {
1537+
let code = self.code_stack.last().unwrap();
1538+
let mut found = None;
1539+
// Only search fblocks at indices 0..fblock_idx (outer fblocks)
1540+
// After removal, fblock_idx now points to where saved_fblock was,
1541+
// so indices 0..fblock_idx are the outer fblocks
1542+
for i in (0..fblock_idx).rev() {
1543+
let fblock = &code.fblock[i];
1544+
if let Some(handler) = fblock.fb_handler {
1545+
found = Some((
1546+
Some(handler),
1547+
fblock.fb_stack_depth + 1, // +1 for return value
1548+
fblock.fb_preserve_lasti,
1549+
));
1550+
break;
15501551
}
1551-
};
1552+
}
1553+
found.unwrap_or((None, 1, false))
1554+
};
15521555

15531556
self.push_fblock_with_handler(
15541557
FBlockType::PopValue,

crates/stdlib/src/contextvars.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,15 @@ mod _contextvars {
168168
}
169169

170170
#[pymethod]
171-
fn copy(&self) -> Self {
171+
fn copy(&self, vm: &VirtualMachine) -> Self {
172+
// Deep copy the vars - clone the underlying Hamt data, not just the PyRef
173+
let vars_copy = HamtObject {
174+
hamt: RefCell::new(self.inner.vars.hamt.borrow().clone()),
175+
};
172176
Self {
173177
inner: ContextInner {
174178
idx: Cell::new(usize::MAX),
175-
vars: self.inner.vars.clone(),
179+
vars: vars_copy.into_ref(&vm.ctx),
176180
entered: Cell::new(false),
177181
},
178182
}
@@ -630,7 +634,7 @@ mod _contextvars {
630634

631635
#[pyfunction]
632636
fn copy_context(vm: &VirtualMachine) -> PyContext {
633-
PyContext::current(vm).copy()
637+
PyContext::current(vm).copy(vm)
634638
}
635639

636640
// Set Token.MISSING attribute

crates/stdlib/src/socket.rs

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ mod _socket {
1515
},
1616
common::os::ErrorExt,
1717
convert::{IntoPyException, ToPyObject, TryFromBorrowedObject, TryFromObject},
18-
function::{ArgBytesLike, ArgMemoryBuffer, Either, FsPath, OptionalArg, OptionalOption},
18+
function::{
19+
ArgBytesLike, ArgMemoryBuffer, ArgStrOrBytesLike, Either, FsPath, OptionalArg,
20+
OptionalOption,
21+
},
1922
types::{Constructor, DefaultConstructor, Initializer, Representable},
2023
utils::ToCString,
2124
};
@@ -2783,9 +2786,9 @@ mod _socket {
27832786
#[derive(FromArgs)]
27842787
struct GAIOptions {
27852788
#[pyarg(positional)]
2786-
host: Option<PyStrRef>,
2789+
host: Option<ArgStrOrBytesLike>,
27872790
#[pyarg(positional)]
2788-
port: Option<Either<PyStrRef, i32>>,
2791+
port: Option<Either<ArgStrOrBytesLike, i32>>,
27892792

27902793
#[pyarg(positional, default = c::AF_UNSPEC)]
27912794
family: i32,
@@ -2809,9 +2812,9 @@ mod _socket {
28092812
flags: opts.flags,
28102813
};
28112814

2812-
// Encode host using IDNA encoding
2815+
// Encode host: str uses IDNA encoding, bytes must be valid UTF-8
28132816
let host_encoded: Option<String> = match opts.host.as_ref() {
2814-
Some(s) => {
2817+
Some(ArgStrOrBytesLike::Str(s)) => {
28152818
let encoded =
28162819
vm.state
28172820
.codec_registry
@@ -2820,19 +2823,43 @@ mod _socket {
28202823
.map_err(|_| vm.new_runtime_error("idna output is not utf8".to_owned()))?;
28212824
Some(host_str.to_owned())
28222825
}
2826+
Some(ArgStrOrBytesLike::Buf(b)) => {
2827+
let bytes = b.borrow_buf();
2828+
let host_str = core::str::from_utf8(&bytes).map_err(|_| {
2829+
vm.new_unicode_decode_error("host bytes is not utf8".to_owned())
2830+
})?;
2831+
Some(host_str.to_owned())
2832+
}
28232833
None => None,
28242834
};
28252835
let host = host_encoded.as_deref();
28262836

2827-
// Encode port using UTF-8
2828-
let port: Option<alloc::borrow::Cow<'_, str>> = match opts.port.as_ref() {
2829-
Some(Either::A(s)) => Some(alloc::borrow::Cow::Borrowed(s.to_str().ok_or_else(
2830-
|| vm.new_unicode_encode_error("surrogates not allowed".to_owned()),
2831-
)?)),
2832-
Some(Either::B(i)) => Some(alloc::borrow::Cow::Owned(i.to_string())),
2837+
// Encode port: str/bytes as service name, int as port number
2838+
let port_encoded: Option<String> = match opts.port.as_ref() {
2839+
Some(Either::A(sb)) => {
2840+
let port_str = match sb {
2841+
ArgStrOrBytesLike::Str(s) => {
2842+
// For str, check for surrogates and raise UnicodeEncodeError if found
2843+
s.to_str()
2844+
.ok_or_else(|| vm.new_unicode_encode_error("surrogates not allowed"))?
2845+
.to_owned()
2846+
}
2847+
ArgStrOrBytesLike::Buf(b) => {
2848+
// For bytes, check if it's valid UTF-8
2849+
let bytes = b.borrow_buf();
2850+
core::str::from_utf8(&bytes)
2851+
.map_err(|_| {
2852+
vm.new_unicode_decode_error("port is not utf8".to_owned())
2853+
})?
2854+
.to_owned()
2855+
}
2856+
};
2857+
Some(port_str)
2858+
}
2859+
Some(Either::B(i)) => Some(i.to_string()),
28332860
None => None,
28342861
};
2835-
let port = port.as_ref().map(|p| p.as_ref());
2862+
let port = port_encoded.as_deref();
28362863

28372864
let addrs = dns_lookup::getaddrinfo(host, port, Some(hints))
28382865
.map_err(|err| convert_socket_error(vm, err, SocketError::GaiError))?;

crates/stdlib/src/ssl.rs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ mod _ssl {
5353
// Import error types used in this module (others are exposed via pymodule(with(...)))
5454
use super::error::{
5555
PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error,
56+
create_ssl_zero_return_error,
5657
};
5758
use alloc::sync::Arc;
5859
use core::{
@@ -3593,7 +3594,7 @@ mod _ssl {
35933594
let mut conn_guard = self.connection.lock();
35943595
let conn = match conn_guard.as_mut() {
35953596
Some(conn) => conn,
3596-
None => return return_data(vec![], &buffer, vm),
3597+
None => return Err(create_ssl_zero_return_error(vm).upcast()),
35973598
};
35983599
use std::io::BufRead;
35993600
let mut reader = conn.reader();
@@ -3613,8 +3614,20 @@ mod _ssl {
36133614
return return_data(buf, &buffer, vm);
36143615
}
36153616
}
3616-
// Clean closure with close_notify - return empty data
3617-
return_data(vec![], &buffer, vm)
3617+
// Clean closure with close_notify
3618+
// CPython behavior depends on whether we've sent our close_notify:
3619+
// - If we've already sent close_notify (unwrap was called): raise SSLZeroReturnError
3620+
// - If we haven't sent close_notify yet: return empty bytes
3621+
let our_shutdown_state = *self.shutdown_state.lock();
3622+
if our_shutdown_state == ShutdownState::SentCloseNotify
3623+
|| our_shutdown_state == ShutdownState::Completed
3624+
{
3625+
// We already sent close_notify, now receiving peer's → SSLZeroReturnError
3626+
Err(create_ssl_zero_return_error(vm).upcast())
3627+
} else {
3628+
// We haven't sent close_notify yet → return empty bytes
3629+
return_data(vec![], &buffer, vm)
3630+
}
36183631
}
36193632
Err(crate::ssl::compat::SslError::WantRead) => {
36203633
// Non-blocking mode: would block

crates/stdlib/src/ssl/compat.rs

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1552,6 +1552,11 @@ pub(super) fn ssl_read(
15521552

15531553
// Try to read plaintext from rustls buffer
15541554
if let Some(n) = try_read_plaintext(conn, buf)? {
1555+
if n == 0 {
1556+
// EOF from TLS - close_notify received
1557+
// Return ZeroReturn so Python raises SSLZeroReturnError
1558+
return Err(SslError::ZeroReturn);
1559+
}
15551560
return Ok(n);
15561561
}
15571562

@@ -1740,17 +1745,40 @@ pub(super) fn ssl_write(
17401745
let already_buffered = *socket.write_buffered_len.lock();
17411746

17421747
// Only write plaintext if not already buffered
1748+
// Track how much we wrote for partial write handling
1749+
let mut bytes_written_to_rustls = 0usize;
1750+
17431751
if already_buffered == 0 {
17441752
// Write plaintext to rustls (= SSL_write_ex internal buffer write)
1745-
{
1753+
bytes_written_to_rustls = {
17461754
let mut writer = conn.writer();
17471755
use std::io::Write;
1748-
writer
1749-
.write_all(data)
1750-
.map_err(|e| SslError::Syscall(format!("Write failed: {e}")))?;
1751-
}
1752-
// Mark data as buffered
1753-
*socket.write_buffered_len.lock() = data.len();
1756+
// Use write() instead of write_all() to support partial writes.
1757+
// In BIO mode (asyncio), when the internal buffer is full,
1758+
// we want to write as much as possible and return that count,
1759+
// rather than failing completely.
1760+
match writer.write(data) {
1761+
Ok(0) if !data.is_empty() => {
1762+
// Buffer is full and nothing could be written.
1763+
// In BIO mode, return WantWrite so the caller can
1764+
// drain the outgoing BIO and retry.
1765+
if is_bio {
1766+
return Err(SslError::WantWrite);
1767+
}
1768+
return Err(SslError::Syscall("Write failed: buffer full".to_string()));
1769+
}
1770+
Ok(n) => n,
1771+
Err(e) => {
1772+
if is_bio {
1773+
// In BIO mode, treat write errors as WantWrite
1774+
return Err(SslError::WantWrite);
1775+
}
1776+
return Err(SslError::Syscall(format!("Write failed: {e}")));
1777+
}
1778+
}
1779+
};
1780+
// Mark data as buffered (only the portion we actually wrote)
1781+
*socket.write_buffered_len.lock() = bytes_written_to_rustls;
17541782
} else if already_buffered != data.len() {
17551783
// Caller is retrying with different data - this is a protocol error
17561784
// Clear the buffer state and return an SSL error (bad write retry)
@@ -1790,13 +1818,23 @@ pub(super) fn ssl_write(
17901818
}
17911819
Err(SslError::WantWrite) => {
17921820
// Non-blocking socket would block - return WANT_WRITE
1821+
// If we had a partial write to rustls, return partial success
1822+
// instead of error to match OpenSSL partial-write semantics
1823+
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
1824+
*socket.write_buffered_len.lock() = 0;
1825+
return Ok(bytes_written_to_rustls);
1826+
}
17931827
// Keep write_buffered_len set so we don't re-buffer on retry
17941828
return Err(SslError::WantWrite);
17951829
}
17961830
Err(SslError::WantRead) => {
17971831
// Need to read before write can complete (e.g., renegotiation)
1798-
// This matches CPython's handling of SSL_ERROR_WANT_READ in write
17991832
if is_bio {
1833+
// If we had a partial write to rustls, return partial success
1834+
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
1835+
*socket.write_buffered_len.lock() = 0;
1836+
return Ok(bytes_written_to_rustls);
1837+
}
18001838
// Keep write_buffered_len set so we don't re-buffer on retry
18011839
return Err(SslError::WantRead);
18021840
}
@@ -1807,6 +1845,11 @@ pub(super) fn ssl_write(
18071845
// Continue loop
18081846
}
18091847
Err(e @ SslError::Timeout(_)) => {
1848+
// If we had a partial write to rustls, return partial success
1849+
if bytes_written_to_rustls > 0 && bytes_written_to_rustls < data.len() {
1850+
*socket.write_buffered_len.lock() = 0;
1851+
return Ok(bytes_written_to_rustls);
1852+
}
18101853
// Preserve buffered state so retry doesn't duplicate data
18111854
// (send_all_bytes saved unsent TLS bytes to pending_tls_output)
18121855
return Err(e);
@@ -1826,10 +1869,21 @@ pub(super) fn ssl_write(
18261869
.map_err(SslError::Py)?;
18271870
}
18281871

1872+
// Determine how many bytes we actually wrote
1873+
let actual_written = if bytes_written_to_rustls > 0 {
1874+
// Fresh write: return what we wrote to rustls
1875+
bytes_written_to_rustls
1876+
} else if already_buffered > 0 {
1877+
// Retry of previous write: return the full buffered amount
1878+
already_buffered
1879+
} else {
1880+
data.len()
1881+
};
1882+
18291883
// Write completed successfully - clear buffer state
18301884
*socket.write_buffered_len.lock() = 0;
18311885

1832-
Ok(data.len())
1886+
Ok(actual_written)
18331887
}
18341888

18351889
// Helper functions (private-ish, used by public SSL functions)

0 commit comments

Comments
 (0)