Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 94 additions & 11 deletions crates/stdlib/src/ssl.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// spell-checker: ignore ssleof aesccm aesgcm getblocking setblocking ENDTLS TLSEXT
// spell-checker: ignore ssleof aesccm aesgcm capath getblocking setblocking ENDTLS TLSEXT

//! Pure Rust SSL/TLS implementation using rustls
//!
Expand Down Expand Up @@ -2786,6 +2786,16 @@ mod _ssl {
recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm)
}

/// Receive up to `size` bytes from the wrapped socket with the `MSG_PEEK` flag,
/// so the data is looked at but not consumed.
///
/// The call goes through Python's `socket` module: `socket.socket.recv` is
/// invoked with `self.sock`, `size`, and the module's `MSG_PEEK` constant.
/// Any import, attribute lookup, or call failure is propagated to the caller.
/// Intended for the TLS shutdown path, where post-TLS cleartext data must not
/// be consumed.
pub(crate) fn sock_peek(&self, size: usize, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
    let module = vm.import("socket", 0)?;
    // Resolve `socket.socket.recv` first, then the flag, so lookups happen in
    // a fixed order and the first failure is the one reported.
    let unbound_recv = module.get_attr("socket", vm)?.get_attr("recv", vm)?;
    let peek_flag = module.get_attr("MSG_PEEK", vm)?;
    // Equivalent to `socket.socket.recv(sock, size, MSG_PEEK)`.
    let call_args = (self.sock.clone(), vm.ctx.new_int(size), peek_flag);
    unbound_recv.call(call_args, vm)
}

/// Socket send - just sends data, caller must handle pending flush
/// Use flush_pending_tls_output before this if ordering is important
pub(crate) fn sock_send(&self, data: &[u8], vm: &VirtualMachine) -> PyResult<PyObjectRef> {
Expand Down Expand Up @@ -4287,45 +4297,118 @@ mod _ssl {
conn: &mut TlsConnection,
vm: &VirtualMachine,
) -> PyResult<bool> {
// Try to read incoming data
// In socket mode, peek first to avoid consuming post-TLS cleartext
// data. During STARTTLS, after close_notify exchange, the socket
// transitions to cleartext. Without peeking, sock_recv may consume
// cleartext data meant for the application after unwrap().
if self.incoming_bio.is_none() {
return self.try_read_close_notify_socket(conn, vm);
}

// BIO mode: read from incoming BIO
match self.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) {
Ok(bytes_obj) => {
let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?;
let data = bytes.borrow_buf();

if data.is_empty() {
// Empty read could mean EOF or just "no data yet" in BIO mode
if let Some(ref bio) = self.incoming_bio {
// BIO mode: check if EOF was signaled via write_eof()
let bio_obj: PyObjectRef = bio.clone().into();
let eof_attr = bio_obj.get_attr("eof", vm)?;
let is_eof = eof_attr.try_to_bool(vm)?;
if !is_eof {
// No EOF signaled, just no data available yet
return Ok(false);
}
}
// Socket mode or BIO with EOF: peer closed connection
// This is "ragged EOF" - peer closed without close_notify
return Ok(true);
}

// Feed data to TLS connection
let data_slice: &[u8] = data.as_ref();
let mut cursor = std::io::Cursor::new(data_slice);
let _ = conn.read_tls(&mut cursor);
let _ = conn.process_new_packets();
Ok(false)
}
Err(e) => {
if is_blocking_io_error(&e, vm) {
return Ok(false);
}
Ok(true)
}
}
}

// Process packets
/// Socket-mode close_notify reader that respects TLS record boundaries.
/// Uses MSG_PEEK to inspect data before consuming, preventing accidental
/// consumption of post-TLS cleartext data during STARTTLS transitions.
///
/// Equivalent to OpenSSL's `SSL_set_read_ahead(ssl, 0)` — rustls has no
/// such knob, so we enforce record-level reads manually via peek.
fn try_read_close_notify_socket(
&self,
conn: &mut TlsConnection,
vm: &VirtualMachine,
) -> PyResult<bool> {
// Peek at the first 5 bytes (TLS record header size)
let peeked_obj = match self.sock_peek(5, vm) {
Ok(obj) => obj,
Err(e) => {
if is_blocking_io_error(&e, vm) {
return Ok(false);
}
return Ok(true);
}
};

let peeked = ArgBytesLike::try_from_object(vm, peeked_obj)?;
let peek_data = peeked.borrow_buf();

if peek_data.is_empty() {
return Ok(true); // EOF
}

// TLS record content types: ChangeCipherSpec(20), Alert(21),
// Handshake(22), ApplicationData(23)
let content_type = peek_data[0];
if !(20..=23).contains(&content_type) {
// Not a TLS record - post-TLS cleartext data.
// Peer has completed TLS shutdown; don't consume this data.
return Ok(true);
}

// Determine how many bytes to read for exactly one TLS record
let recv_size = if peek_data.len() >= 5 {
let record_length = u16::from_be_bytes([peek_data[3], peek_data[4]]) as usize;
5 + record_length
} else {
// Partial header available - read just these bytes for now
peek_data.len()
};

drop(peek_data);
drop(peeked);

// Now consume exactly one TLS record from the socket
match self.sock_recv(recv_size, vm) {
Ok(bytes_obj) => {
let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?;
let data = bytes.borrow_buf();

if data.is_empty() {
return Ok(true);
}

let data_slice: &[u8] = data.as_ref();
let mut cursor = std::io::Cursor::new(data_slice);
let _ = conn.read_tls(&mut cursor);
let _ = conn.process_new_packets();
Ok(false)
}
Err(e) => {
// BlockingIOError means no data yet
if is_blocking_io_error(&e, vm) {
return Ok(false);
}
// Connection reset, EOF, or other error means peer closed
// ECONNRESET, EPIPE, broken pipe, etc.
Ok(true)
}
}
Expand Down
14 changes: 11 additions & 3 deletions extra_tests/snippets/builtin_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ def __gt__(self, other):
lst.sort(key=C)
assert lst == [1, 2, 3, 4, 5]


# Test that sorted() uses __lt__ (not __gt__) for comparisons.
# Track which comparison method is actually called during sort.
class TrackComparison:
Expand All @@ -287,29 +288,36 @@ def __gt__(self, other):
TrackComparison.gt_calls += 1
return self.value > other.value


# Reset and test sorted()
# The counters live on the TrackComparison class itself, so they are zeroed
# before each scenario to count only the comparisons made by that one call.
TrackComparison.lt_calls = 0
TrackComparison.gt_calls = 0
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
sorted(items)
assert TrackComparison.lt_calls > 0, "sorted() should call __lt__"
assert TrackComparison.gt_calls == 0, f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
assert TrackComparison.gt_calls == 0, (
f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
)

# Reset and test list.sort()
# Same check as above, but for the in-place list.sort() method.
TrackComparison.lt_calls = 0
TrackComparison.gt_calls = 0
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
items.sort()
assert TrackComparison.lt_calls > 0, "list.sort() should call __lt__"
assert TrackComparison.gt_calls == 0, f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
assert TrackComparison.gt_calls == 0, (
f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
)

# Reset and test sorted(reverse=True) - should still use __lt__, not __gt__
# Reversing the output order must not switch the comparison method used.
TrackComparison.lt_calls = 0
TrackComparison.gt_calls = 0
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
sorted(items, reverse=True)
assert TrackComparison.lt_calls > 0, "sorted(reverse=True) should call __lt__"
assert TrackComparison.gt_calls == 0, f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times"
assert TrackComparison.gt_calls == 0, (
f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times"
)

lst = [5, 1, 2, 3, 4]

Expand Down
Loading