Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Lib/test/test_sqlite3/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,6 @@ def test_null_character(self):
self.assertRaisesRegex(sqlite.ProgrammingError, "null char",
cur.execute, query)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_surrogates(self):
con = sqlite.connect(":memory:")
self.assertRaises(UnicodeEncodeError, con, "select '\ud8ff'")
Expand Down
9 changes: 5 additions & 4 deletions stdlib/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ mod _sqlite {
type Args = (PyStrRef,);

fn call(zelf: &Py<Self>, args: Self::Args, vm: &VirtualMachine) -> PyResult {
if let Some(stmt) = Statement::new(zelf, &args.0, vm)? {
if let Some(stmt) = Statement::new(zelf, args.0, vm)? {
Ok(stmt.into_ref(&vm.ctx).into())
} else {
Ok(vm.ctx.none())
Expand Down Expand Up @@ -1480,7 +1480,7 @@ mod _sqlite {
stmt.lock().reset();
}

let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else {
let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else {
drop(inner);
return Ok(zelf);
};
Expand Down Expand Up @@ -1552,7 +1552,7 @@ mod _sqlite {
stmt.lock().reset();
}

let Some(stmt) = Statement::new(&zelf.connection, &sql, vm)? else {
let Some(stmt) = Statement::new(&zelf.connection, sql, vm)? else {
drop(inner);
return Ok(zelf);
};
Expand Down Expand Up @@ -2291,9 +2291,10 @@ mod _sqlite {
impl Statement {
fn new(
connection: &Connection,
sql: &PyStr,
sql: PyStrRef,
vm: &VirtualMachine,
) -> PyResult<Option<Self>> {
let sql = sql.try_into_utf8(vm)?;
let sql_cstr = sql.to_cstring(vm)?;
let sql_len = sql.byte_len() + 1;

Expand Down
49 changes: 43 additions & 6 deletions vm/src/builtins/str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ use rustpython_common::{
str::DeduceStrKind,
wtf8::{CodePoint, Wtf8, Wtf8Buf, Wtf8Chunk},
};
use std::sync::LazyLock;
use std::{borrow::Cow, char, fmt, ops::Range};
use std::{mem, sync::LazyLock};
use unic_ucd_bidi::BidiClass;
use unic_ucd_category::GeneralCategory;
use unic_ucd_ident::{is_xid_continue, is_xid_start};
Expand Down Expand Up @@ -80,6 +80,30 @@ impl fmt::Debug for PyStr {
}
}

#[repr(transparent)]
#[derive(Debug)]
pub struct PyUtf8Str(PyStr);

// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str
impl std::ops::Deref for PyUtf8Str {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
impl std::ops::Deref for PyUtf8Str {
// TODO: Remove this Deref which may hide missing optimized methods of PyUtf8Str
impl std::ops::Deref for PyUtf8Str {

type Target = PyStr;
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl PyUtf8Str {
/// Returns the underlying string slice.
pub fn as_str(&self) -> &str {
debug_assert!(
self.0.is_utf8(),
"PyUtf8Str invariant violated: inner string is not valid UTF-8"
);
// Safety: This is safe because the type invariant guarantees UTF-8 validity.
unsafe { self.0.to_str().unwrap_unchecked() }
}
}

impl AsRef<str> for PyStr {
#[track_caller] // <- can remove this once it doesn't panic
fn as_ref(&self) -> &str {
Expand Down Expand Up @@ -433,21 +457,29 @@ impl PyStr {
self.data.as_str()
}

pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> {
self.to_str().ok_or_else(|| {
fn ensure_valid_utf8(&self, vm: &VirtualMachine) -> PyResult<()> {
if self.is_utf8() {
Ok(())
} else {
let start = self
.as_wtf8()
.code_points()
.position(|c| c.to_char().is_none())
.unwrap();
vm.new_unicode_encode_error_real(
Err(vm.new_unicode_encode_error_real(
identifier!(vm, utf_8).to_owned(),
vm.ctx.new_str(self.data.clone()),
start,
start + 1,
vm.ctx.new_str("surrogates not allowed"),
)
})
))
}
}

pub fn try_to_str(&self, vm: &VirtualMachine) -> PyResult<&str> {
self.ensure_valid_utf8(vm)?;
// SAFETY: ensure_valid_utf8 passed, so unwrap is safe.
Ok(unsafe { self.to_str().unwrap_unchecked() })
}

pub fn to_string_lossy(&self) -> Cow<'_, str> {
Expand Down Expand Up @@ -1486,6 +1518,11 @@ impl PyStrRef {
s.push_wtf8(other);
*self = PyStr::from(s).into_ref(&vm.ctx);
}

pub fn try_into_utf8(self, vm: &VirtualMachine) -> PyResult<PyRef<PyUtf8Str>> {
self.ensure_valid_utf8(vm)?;
Ok(unsafe { mem::transmute::<PyRef<PyStr>, PyRef<PyUtf8Str>>(self) })
}
}

impl Representable for PyStr {
Expand Down
Loading