Skip to content

Commit 457ad08

Browse files
committed
ssl_write
1 parent eb50246 commit 457ad08

File tree

3 files changed

+260
-102
lines changed

3 files changed

+260
-102
lines changed

crates/stdlib/src/ssl.rs

Lines changed: 85 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ mod _ssl {
5050

5151
// Import error types used in this module (others are exposed via pymodule(with(...)))
5252
use super::error::{
53-
PySSLEOFError, PySSLError, create_ssl_want_read_error, create_ssl_want_write_error,
53+
PySSLError, create_ssl_eof_error, create_ssl_want_read_error, create_ssl_want_write_error,
5454
};
5555
use alloc::sync::Arc;
5656
use core::{
@@ -2783,6 +2783,7 @@ mod _ssl {
27832783
let is_non_blocking = socket_timeout.map(|t| t.is_zero()).unwrap_or(false);
27842784

27852785
let mut sent_total = 0;
2786+
27862787
while sent_total < pending.len() {
27872788
// Calculate timeout: use deadline if provided, otherwise use socket timeout
27882789
let timeout_to_use = if let Some(dl) = deadline {
@@ -2810,6 +2811,9 @@ mod _ssl {
28102811
if timed_out {
28112812
// Keep unsent data in pending buffer
28122813
*pending = pending[sent_total..].to_vec();
2814+
if is_non_blocking {
2815+
return Err(create_ssl_want_write_error(vm).upcast());
2816+
}
28132817
return Err(
28142818
timeout_error_msg(vm, "The write operation timed out".to_string()).upcast(),
28152819
);
@@ -2824,6 +2828,7 @@ mod _ssl {
28242828
*pending = pending[sent_total..].to_vec();
28252829
return Err(create_ssl_want_write_error(vm).upcast());
28262830
}
2831+
// Socket said ready but sent 0 bytes - retry
28272832
continue;
28282833
}
28292834
sent_total += sent;
@@ -2916,6 +2921,9 @@ mod _ssl {
29162921
pub(crate) fn blocking_flush_all_pending(&self, vm: &VirtualMachine) -> PyResult<()> {
29172922
// Get socket timeout to respect during flush
29182923
let timeout = self.get_socket_timeout(vm)?;
2924+
if timeout.map(|t| t.is_zero()).unwrap_or(false) {
2925+
return self.flush_pending_tls_output(vm, None);
2926+
}
29192927

29202928
loop {
29212929
let pending_data = {
@@ -2948,8 +2956,7 @@ mod _ssl {
29482956
let mut pending = self.pending_tls_output.lock();
29492957
pending.drain(..sent);
29502958
}
2951-
// If sent == 0, socket wasn't ready despite select() saying so
2952-
// Continue loop to retry - this avoids infinite loops
2959+
// If sent == 0, loop will retry with sock_select
29532960
}
29542961
Err(e) => {
29552962
if is_blocking_io_error(&e, vm) {
@@ -3515,16 +3522,60 @@ mod _ssl {
35153522
return_data(buf, &buffer, vm)
35163523
}
35173524
Err(crate::ssl::compat::SslError::Eof) => {
3525+
// If plaintext is still buffered, return it before EOF.
3526+
let pending = {
3527+
let mut conn_guard = self.connection.lock();
3528+
let conn = match conn_guard.as_mut() {
3529+
Some(conn) => conn,
3530+
None => return Err(create_ssl_eof_error(vm).upcast()),
3531+
};
3532+
use std::io::BufRead;
3533+
let mut reader = conn.reader();
3534+
reader.fill_buf().map(|buf| buf.len()).unwrap_or(0)
3535+
};
3536+
if pending > 0 {
3537+
let mut buf = vec![0u8; pending.min(len)];
3538+
let read_retry = {
3539+
let mut conn_guard = self.connection.lock();
3540+
let conn = conn_guard
3541+
.as_mut()
3542+
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
3543+
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
3544+
};
3545+
if let Ok(n) = read_retry {
3546+
buf.truncate(n);
3547+
return return_data(buf, &buffer, vm);
3548+
}
3549+
}
35183550
// EOF occurred in violation of protocol (unexpected closure)
3519-
Err(vm
3520-
.new_os_subtype_error(
3521-
PySSLEOFError::class(&vm.ctx).to_owned(),
3522-
None,
3523-
"EOF occurred in violation of protocol",
3524-
)
3525-
.upcast())
3551+
Err(create_ssl_eof_error(vm).upcast())
35263552
}
35273553
Err(crate::ssl::compat::SslError::ZeroReturn) => {
3554+
// If plaintext is still buffered, return it before clean EOF.
3555+
let pending = {
3556+
let mut conn_guard = self.connection.lock();
3557+
let conn = match conn_guard.as_mut() {
3558+
Some(conn) => conn,
3559+
None => return return_data(vec![], &buffer, vm),
3560+
};
3561+
use std::io::BufRead;
3562+
let mut reader = conn.reader();
3563+
reader.fill_buf().map(|buf| buf.len()).unwrap_or(0)
3564+
};
3565+
if pending > 0 {
3566+
let mut buf = vec![0u8; pending.min(len)];
3567+
let read_retry = {
3568+
let mut conn_guard = self.connection.lock();
3569+
let conn = conn_guard
3570+
.as_mut()
3571+
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
3572+
crate::ssl::compat::ssl_read(conn, &mut buf, self, vm)
3573+
};
3574+
if let Ok(n) = read_retry {
3575+
buf.truncate(n);
3576+
return return_data(buf, &buffer, vm);
3577+
}
3578+
}
35283579
// Clean closure with close_notify - return empty data
35293580
return_data(vec![], &buffer, vm)
35303581
}
@@ -3580,21 +3631,17 @@ mod _ssl {
35803631
let data_bytes = data.borrow_buf();
35813632
let data_len = data_bytes.len();
35823633

3583-
// return 0 immediately for empty write
35843634
if data_len == 0 {
35853635
return Ok(0);
35863636
}
35873637

3588-
// Ensure handshake is done - if not, complete it first
3589-
// This matches OpenSSL behavior where SSL_write() auto-completes handshake
3638+
// Ensure handshake is done (SSL_write auto-completes handshake)
35903639
if !*self.handshake_done.lock() {
35913640
self.do_handshake(vm)?;
35923641
}
35933642

3594-
// Check if connection has been shut down
3595-
// After unwrap()/shutdown(), write operations should fail with SSLError
3596-
let shutdown_state = *self.shutdown_state.lock();
3597-
if shutdown_state != ShutdownState::NotStarted {
3643+
// Check shutdown state
3644+
if *self.shutdown_state.lock() != ShutdownState::NotStarted {
35983645
return Err(vm
35993646
.new_os_subtype_error(
36003647
PySSLError::class(&vm.ctx).to_owned(),
@@ -3604,76 +3651,32 @@ mod _ssl {
36043651
.upcast());
36053652
}
36063653

3607-
{
3654+
// Call ssl_write (matches CPython's SSL_write_ex loop)
3655+
let result = {
36083656
let mut conn_guard = self.connection.lock();
36093657
let conn = conn_guard
36103658
.as_mut()
36113659
.ok_or_else(|| vm.new_value_error("Connection not established"))?;
36123660

3613-
let is_bio = self.is_bio_mode();
3614-
let data: &[u8] = data_bytes.as_ref();
3661+
crate::ssl::compat::ssl_write(conn, data_bytes.as_ref(), self, vm)
3662+
};
36153663

3616-
// CRITICAL: Flush any pending TLS data before writing new data
3617-
// This ensures TLS 1.3 Finished message reaches server before application data
3618-
// Without this, server may not be ready to process our data
3619-
if !is_bio {
3620-
self.flush_pending_tls_output(vm, None)?;
3664+
match result {
3665+
Ok(n) => {
3666+
self.check_deferred_cert_error(vm)?;
3667+
Ok(n)
36213668
}
3622-
3623-
// Write data in chunks to avoid filling the internal TLS buffer
3624-
// rustls has a limited internal buffer, so we need to flush periodically
3625-
const CHUNK_SIZE: usize = 16384; // 16KB chunks (typical TLS record size)
3626-
let mut written = 0;
3627-
3628-
while written < data.len() {
3629-
let chunk_end = core::cmp::min(written + CHUNK_SIZE, data.len());
3630-
let chunk = &data[written..chunk_end];
3631-
3632-
// Write chunk to TLS layer
3633-
{
3634-
let mut writer = conn.writer();
3635-
use std::io::Write;
3636-
writer
3637-
.write_all(chunk)
3638-
.map_err(|e| vm.new_os_error(format!("Write failed: {e}")))?;
3639-
// Flush to ensure data is converted to TLS records
3640-
writer
3641-
.flush()
3642-
.map_err(|e| vm.new_os_error(format!("Flush failed: {e}")))?;
3643-
}
3644-
3645-
written = chunk_end;
3646-
3647-
// Flush TLS data to socket after each chunk
3648-
if conn.wants_write() {
3649-
if is_bio {
3650-
self.write_pending_tls(conn, vm)?;
3651-
} else {
3652-
// Socket mode: flush all pending TLS data
3653-
// First, try to send any previously pending data
3654-
self.flush_pending_tls_output(vm, None)?;
3655-
3656-
while conn.wants_write() {
3657-
let mut buf = Vec::new();
3658-
conn.write_tls(&mut buf).map_err(|e| {
3659-
vm.new_os_error(format!("TLS write failed: {e}"))
3660-
})?;
3661-
3662-
if !buf.is_empty() {
3663-
// Try to send TLS data, saving unsent bytes to pending buffer
3664-
self.send_tls_output(buf, vm)?;
3665-
}
3666-
}
3667-
}
3668-
}
3669+
Err(crate::ssl::compat::SslError::WantRead) => {
3670+
Err(create_ssl_want_read_error(vm).upcast())
36693671
}
3672+
Err(crate::ssl::compat::SslError::WantWrite) => {
3673+
Err(create_ssl_want_write_error(vm).upcast())
3674+
}
3675+
Err(crate::ssl::compat::SslError::Timeout(msg)) => {
3676+
Err(timeout_error_msg(vm, msg).upcast())
3677+
}
3678+
Err(e) => Err(e.into_py_err(vm)),
36703679
}
3671-
3672-
// Check for deferred certificate verification errors (TLS 1.3)
3673-
// Must be checked AFTER write completes, as the error may be set during I/O
3674-
self.check_deferred_cert_error(vm)?;
3675-
3676-
Ok(data_len)
36773680
}
36783681

36793682
#[pymethod]
@@ -4013,6 +4016,10 @@ mod _ssl {
40134016

40144017
// Write close_notify to outgoing buffer/BIO
40154018
self.write_pending_tls(conn, vm)?;
4019+
// Ensure close_notify and any pending TLS data are flushed
4020+
if !is_bio {
4021+
self.flush_pending_tls_output(vm, None)?;
4022+
}
40164023

40174024
// Update state
40184025
*self.shutdown_state.lock() = ShutdownState::SentCloseNotify;

0 commit comments

Comments
 (0)