Skip to content

Commit b3d0d0d

Browse files
committed
fix ssl MSG_PEEK
1 parent e2fda95 commit b3d0d0d

File tree

2 files changed

+105
-14
lines changed

2 files changed

+105
-14
lines changed

crates/stdlib/src/ssl.rs

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// spell-checker: ignore ssleof aesccm aesgcm getblocking setblocking ENDTLS TLSEXT
1+
// spell-checker: ignore ssleof aesccm aesgcm capath getblocking setblocking ENDTLS TLSEXT
22

33
//! Pure Rust SSL/TLS implementation using rustls
44
//!
@@ -2786,6 +2786,16 @@ mod _ssl {
27862786
recv_method.call((self.sock.clone(), vm.ctx.new_int(size)), vm)
27872787
}
27882788

2789+
/// Peek at socket data without consuming it (MSG_PEEK).
2790+
/// Used during TLS shutdown to avoid consuming post-TLS cleartext data.
2791+
pub(crate) fn sock_peek(&self, size: usize, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
2792+
let socket_mod = vm.import("socket", 0)?;
2793+
let socket_class = socket_mod.get_attr("socket", vm)?;
2794+
let recv_method = socket_class.get_attr("recv", vm)?;
2795+
let msg_peek = socket_mod.get_attr("MSG_PEEK", vm)?;
2796+
recv_method.call((self.sock.clone(), vm.ctx.new_int(size), msg_peek), vm)
2797+
}
2798+
27892799
/// Socket send - just sends data, caller must handle pending flush
27902800
/// Use flush_pending_tls_output before this if ordering is important
27912801
pub(crate) fn sock_send(&self, data: &[u8], vm: &VirtualMachine) -> PyResult<PyObjectRef> {
@@ -4287,45 +4297,118 @@ mod _ssl {
42874297
conn: &mut TlsConnection,
42884298
vm: &VirtualMachine,
42894299
) -> PyResult<bool> {
4290-
// Try to read incoming data
4300+
// In socket mode, peek first to avoid consuming post-TLS cleartext
4301+
// data. During STARTTLS, after close_notify exchange, the socket
4302+
// transitions to cleartext. Without peeking, sock_recv may consume
4303+
// cleartext data meant for the application after unwrap().
4304+
if self.incoming_bio.is_none() {
4305+
return self.try_read_close_notify_socket(conn, vm);
4306+
}
4307+
4308+
// BIO mode: read from incoming BIO
42914309
match self.sock_recv(SSL3_RT_MAX_PLAIN_LENGTH, vm) {
42924310
Ok(bytes_obj) => {
42934311
let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?;
42944312
let data = bytes.borrow_buf();
42954313

42964314
if data.is_empty() {
4297-
// Empty read could mean EOF or just "no data yet" in BIO mode
42984315
if let Some(ref bio) = self.incoming_bio {
42994316
// BIO mode: check if EOF was signaled via write_eof()
43004317
let bio_obj: PyObjectRef = bio.clone().into();
43014318
let eof_attr = bio_obj.get_attr("eof", vm)?;
43024319
let is_eof = eof_attr.try_to_bool(vm)?;
43034320
if !is_eof {
4304-
// No EOF signaled, just no data available yet
43054321
return Ok(false);
43064322
}
43074323
}
4308-
// Socket mode or BIO with EOF: peer closed connection
4309-
// This is "ragged EOF" - peer closed without close_notify
43104324
return Ok(true);
43114325
}
43124326

4313-
// Feed data to TLS connection
43144327
let data_slice: &[u8] = data.as_ref();
43154328
let mut cursor = std::io::Cursor::new(data_slice);
43164329
let _ = conn.read_tls(&mut cursor);
4330+
let _ = conn.process_new_packets();
4331+
Ok(false)
4332+
}
4333+
Err(e) => {
4334+
if is_blocking_io_error(&e, vm) {
4335+
return Ok(false);
4336+
}
4337+
Ok(true)
4338+
}
4339+
}
4340+
}
43174341

4318-
// Process packets
4342+
/// Socket-mode close_notify reader that respects TLS record boundaries.
4343+
/// Uses MSG_PEEK to inspect data before consuming, preventing accidental
4344+
/// consumption of post-TLS cleartext data during STARTTLS transitions.
4345+
///
4346+
/// Equivalent to OpenSSL's `SSL_set_read_ahead(ssl, 0)` — rustls has no
4347+
/// such knob, so we enforce record-level reads manually via peek.
4348+
fn try_read_close_notify_socket(
4349+
&self,
4350+
conn: &mut TlsConnection,
4351+
vm: &VirtualMachine,
4352+
) -> PyResult<bool> {
4353+
// Peek at the first 5 bytes (TLS record header size)
4354+
let peeked_obj = match self.sock_peek(5, vm) {
4355+
Ok(obj) => obj,
4356+
Err(e) => {
4357+
if is_blocking_io_error(&e, vm) {
4358+
return Ok(false);
4359+
}
4360+
return Ok(true);
4361+
}
4362+
};
4363+
4364+
let peeked = ArgBytesLike::try_from_object(vm, peeked_obj)?;
4365+
let peek_data = peeked.borrow_buf();
4366+
4367+
if peek_data.is_empty() {
4368+
return Ok(true); // EOF
4369+
}
4370+
4371+
// TLS record content types: ChangeCipherSpec(20), Alert(21),
4372+
// Handshake(22), ApplicationData(23)
4373+
let content_type = peek_data[0];
4374+
if !(20..=23).contains(&content_type) {
4375+
// Not a TLS record - post-TLS cleartext data.
4376+
// Peer has completed TLS shutdown; don't consume this data.
4377+
return Ok(true);
4378+
}
4379+
4380+
// Determine how many bytes to read for exactly one TLS record
4381+
let recv_size = if peek_data.len() >= 5 {
4382+
let record_length = u16::from_be_bytes([peek_data[3], peek_data[4]]) as usize;
4383+
5 + record_length
4384+
} else {
4385+
// Partial header available - read just these bytes for now
4386+
peek_data.len()
4387+
};
4388+
4389+
drop(peek_data);
4390+
drop(peeked);
4391+
4392+
// Now consume exactly one TLS record from the socket
4393+
match self.sock_recv(recv_size, vm) {
4394+
Ok(bytes_obj) => {
4395+
let bytes = ArgBytesLike::try_from_object(vm, bytes_obj)?;
4396+
let data = bytes.borrow_buf();
4397+
4398+
if data.is_empty() {
4399+
return Ok(true);
4400+
}
4401+
4402+
let data_slice: &[u8] = data.as_ref();
4403+
let mut cursor = std::io::Cursor::new(data_slice);
4404+
let _ = conn.read_tls(&mut cursor);
43194405
let _ = conn.process_new_packets();
43204406
Ok(false)
43214407
}
43224408
Err(e) => {
4323-
// BlockingIOError means no data yet
43244409
if is_blocking_io_error(&e, vm) {
43254410
return Ok(false);
43264411
}
4327-
// Connection reset, EOF, or other error means peer closed
4328-
// ECONNRESET, EPIPE, broken pipe, etc.
43294412
Ok(true)
43304413
}
43314414
}

extra_tests/snippets/builtin_list.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def __gt__(self, other):
270270
lst.sort(key=C)
271271
assert lst == [1, 2, 3, 4, 5]
272272

273+
273274
# Test that sorted() uses __lt__ (not __gt__) for comparisons.
274275
# Track which comparison method is actually called during sort.
275276
class TrackComparison:
@@ -287,29 +288,36 @@ def __gt__(self, other):
287288
TrackComparison.gt_calls += 1
288289
return self.value > other.value
289290

291+
290292
# Reset and test sorted()
291293
TrackComparison.lt_calls = 0
292294
TrackComparison.gt_calls = 0
293295
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
294296
sorted(items)
295297
assert TrackComparison.lt_calls > 0, "sorted() should call __lt__"
296-
assert TrackComparison.gt_calls == 0, f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
298+
assert TrackComparison.gt_calls == 0, (
299+
f"sorted() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
300+
)
297301

298302
# Reset and test list.sort()
299303
TrackComparison.lt_calls = 0
300304
TrackComparison.gt_calls = 0
301305
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
302306
items.sort()
303307
assert TrackComparison.lt_calls > 0, "list.sort() should call __lt__"
304-
assert TrackComparison.gt_calls == 0, f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
308+
assert TrackComparison.gt_calls == 0, (
309+
f"list.sort() should not call __gt__, but it was called {TrackComparison.gt_calls} times"
310+
)
305311

306312
# Reset and test sorted(reverse=True) - should still use __lt__, not __gt__
307313
TrackComparison.lt_calls = 0
308314
TrackComparison.gt_calls = 0
309315
items = [TrackComparison(3), TrackComparison(1), TrackComparison(2)]
310316
sorted(items, reverse=True)
311317
assert TrackComparison.lt_calls > 0, "sorted(reverse=True) should call __lt__"
312-
assert TrackComparison.gt_calls == 0, f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times"
318+
assert TrackComparison.gt_calls == 0, (
319+
f"sorted(reverse=True) should not call __gt__, but it was called {TrackComparison.gt_calls} times"
320+
)
313321

314322
lst = [5, 1, 2, 3, 4]
315323

0 commit comments

Comments
 (0)