Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/codegen/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use num_traits::ToPrimitive;
use rustpython_ast as ast;
use rustpython_compiler_core::{
self as bytecode, Arg as OpArgMarker, CodeObject, ConstantData, Instruction, Location, NameIdx,
OpArg, OpArgType,
OpArg, OpArgType, UnsafeCodeObject,
};
use std::borrow::Cow;

Expand Down
1 change: 1 addition & 0 deletions compiler/codegen/src/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ impl CodeInfo {
locations.clear()
}

// SAFETY: we assume that the compiler produces correct output
CodeObject {
flags,
posonlyarg_count,
Expand Down
97 changes: 88 additions & 9 deletions compiler/core/src/bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ impl ConstantBag for BasicBag {

/// Primary container of a single code object. Each python function has
/// a codeobject. Also a module has a codeobject.
#[derive(Clone, Serialize, Deserialize)]
/// Once UnsafeCodeObject is turned into CodeObject, it guarantees correctness of all
/// aspects of the contained bytecode, in order for the VM to make optimizations
/// that skip safety checks.
#[non_exhaustive] // to prevent manual construction by-passing UnsafeCodeObject
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct CodeObject<C: Constant = ConstantData> {
pub instructions: Box<[CodeUnit]>,
pub locations: Box<[Location]>,
Expand Down Expand Up @@ -864,6 +868,17 @@ impl<N: AsRef<str>> fmt::Debug for Arguments<'_, N> {
}

impl<C: Constant> CodeObject<C> {
/// Create a new CodeObject
///
/// # Safety
///
/// Caller is responsible for ensuring that `code` is correct python bytecode.
/// The easiest way to do so is get a CodeObject from the
/// `rustpython-compiler` crate.
pub unsafe fn new(code: UnsafeCodeObject<C>) -> Self {
code.0
}

/// Get all arguments of the code object
/// like inspect.getargs
pub fn arg_names(&self) -> Arguments<C::Name> {
Expand Down Expand Up @@ -1031,6 +1046,16 @@ impl<C: Constant> CodeObject<C> {
cell2arg: self.cell2arg.clone(),
}
}

/// Serialize this bytecode to bytes.
pub fn to_bytes(&self) -> Vec<u8>
where
C: serde::Serialize,
C::Name: serde::Serialize,
{
let data = bincode::serialize(&self).expect("CodeObject is not serializable");
lz4_flex::compress_prepend_size(&data)
}
}

/// Error that occurs during code deserialization
Expand All @@ -1054,7 +1079,7 @@ impl fmt::Display for CodeDeserializeError {

impl std::error::Error for CodeDeserializeError {}

impl CodeObject<ConstantData> {
impl UnsafeCodeObject<ConstantData> {
/// Load a code object from bytes
pub fn from_bytes(data: &[u8]) -> Result<Self, CodeDeserializeError> {
use lz4_flex::block::DecompressError;
Expand All @@ -1072,12 +1097,6 @@ impl CodeObject<ConstantData> {
})?;
Ok(data)
}

/// Serialize this bytecode to bytes.
pub fn to_bytes(&self) -> Vec<u8> {
let data = bincode::serialize(&self).expect("CodeObject is not serializable");
lz4_flex::compress_prepend_size(&data)
}
}

impl<C: Constant> fmt::Display for CodeObject<C> {
Expand Down Expand Up @@ -1446,6 +1465,21 @@ impl<C: Constant> InstrDisplayContext for CodeObject<C> {
.as_ref()
}
}
impl<C: Constant> InstrDisplayContext for UnsafeCodeObject<C> {
type Constant = C;
fn get_constant(&self, i: usize) -> &C {
(**self).get_constant(i)
}
fn get_name(&self, i: usize) -> &str {
(**self).get_name(i)
}
fn get_varname(&self, i: usize) -> &str {
(**self).get_varname(i)
}
fn get_cellname(&self, i: usize) -> &str {
(**self).get_cellname(i)
}
}

impl fmt::Display for ConstantData {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Expand Down Expand Up @@ -1478,7 +1512,10 @@ pub mod frozen_lib {
use std::io;

/// Decode a library to a iterable of frozen modules
pub fn decode_lib(bytes: &[u8]) -> FrozenModulesIter {
///
/// # Safety
/// `bytes` must be the output from [`encode_lib`].
pub unsafe fn decode_lib(bytes: &[u8]) -> FrozenModulesIter {
let data = lz4_flex::decompress_size_prepended(bytes).unwrap();
let r = VecReader { data, pos: 0 };
let mut de = bincode::Deserializer::with_bincode_read(r, options());
Expand Down Expand Up @@ -1601,3 +1638,45 @@ pub mod frozen_lib {
}
}
}

/// A wrapper around [`CodeObject`] that doesn't guarantee correctness of all
/// aspects of the contained bytecode.
#[derive(serde::Serialize, serde::Deserialize)]
#[repr(transparent)]
pub struct UnsafeCodeObject<C: Constant = ConstantData>(
#[serde(bound(
deserialize = "CodeObject<C>: serde::Deserialize<'de>",
serialize = "CodeObject<C>: serde::Serialize"
))]
CodeObject<C>,
);

impl<C: Constant> Clone for UnsafeCodeObject<C>
where
CodeObject<C>: Clone,
{
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl<C: Constant> fmt::Display for UnsafeCodeObject<C> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Display::fmt(&**self, f)
}
}

impl<C: Constant> fmt::Debug for UnsafeCodeObject<C> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
fmt::Debug::fmt(&**self, f)
}
}

impl<C: Constant> std::ops::Deref for UnsafeCodeObject<C> {
type Target = CodeObject<C>;
fn deref(&self) -> &Self::Target {
&self.0
}
}

impl<C: Constant> UnsafeCodeObject<C> {}
12 changes: 9 additions & 3 deletions derive-impl/src/compile_bytecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,13 @@ pub fn impl_py_compile(
let bytes = LitByteStr::new(&bytes, Span::call_site());

let output = quote! {
#crate_name::CodeObject::from_bytes(#bytes)
.expect("Deserializing CodeObject failed")
unsafe {
// SAFETY: we just deserialized from a serialized CodeObject
#crate_name::CodeObject::new(
#crate_name::UnsafeCodeObject::from_bytes(#bytes)
.expect("Deserializing UnsafeCodeObject failed")
)
}
};

Ok(output)
Expand All @@ -395,7 +400,8 @@ pub fn impl_py_freeze(
let bytes = LitByteStr::new(&data, Span::call_site());

let output = quote! {
#crate_name::frozen_lib::decode_lib(#bytes)
// SAFETY: we just deserialized from a serialized CodeObject
unsafe { #crate_name::frozen_lib::decode_lib(#bytes) }
};

Ok(output)
Expand Down
54 changes: 27 additions & 27 deletions vm/src/builtins/code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,33 +344,33 @@ impl PyRef<PyCode> {
OptionalArg::Missing => self.code.varnames.iter().map(|s| s.to_object()).collect(),
};

Ok(PyCode {
code: CodeObject {
flags: CodeFlags::from_bits_truncate(flags),
posonlyarg_count,
arg_count,
kwonlyarg_count,
source_path: source_path.as_object().as_interned_str(vm).unwrap(),
first_line_number,
obj_name: obj_name.as_object().as_interned_str(vm).unwrap(),

max_stackdepth: self.code.max_stackdepth,
instructions: self.code.instructions.clone(),
locations: self.code.locations.clone(),
constants: constants.into_iter().map(Literal).collect(),
names: names
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
.collect(),
varnames: varnames
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
.collect(),
cellvars: self.code.cellvars.clone(),
freevars: self.code.freevars.clone(),
cell2arg: self.code.cell2arg.clone(),
},
})
// SAFETY: none, really, but this is something cpython lets people do, so ¯\_(ツ)_/¯
let code = CodeObject {
flags: CodeFlags::from_bits_truncate(flags),
posonlyarg_count,
arg_count,
kwonlyarg_count,
source_path: source_path.as_object().as_interned_str(vm).unwrap(),
first_line_number,
obj_name: obj_name.as_object().as_interned_str(vm).unwrap(),

max_stackdepth: self.code.max_stackdepth,
instructions: self.code.instructions.clone(),
locations: self.code.locations.clone(),
constants: constants.into_iter().map(Literal).collect(),
names: names
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
.collect(),
varnames: varnames
.into_iter()
.map(|o| o.as_interned_str(vm).unwrap())
.collect(),
cellvars: self.code.cellvars.clone(),
freevars: self.code.freevars.clone(),
cell2arg: self.code.cell2arg.clone(),
};
Ok(PyCode { code })
}
}

Expand Down
6 changes: 3 additions & 3 deletions vm/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,7 @@ impl ExecutingFrame<'_> {
let item = self.pop_value();
let obj = self.nth_value(i.get(arg));
let list: &Py<PyList> = unsafe {
// SAFETY: trust compiler
// SAFETY: invariants of CodeObject
obj.downcast_unchecked_ref()
};
list.append(item);
Expand All @@ -749,7 +749,7 @@ impl ExecutingFrame<'_> {
let item = self.pop_value();
let obj = self.nth_value(i.get(arg));
let set: &Py<PySet> = unsafe {
// SAFETY: trust compiler
// SAFETY: invariants of CodeObject
obj.downcast_unchecked_ref()
};
set.add(item, vm)?;
Expand All @@ -760,7 +760,7 @@ impl ExecutingFrame<'_> {
let key = self.pop_value();
let obj = self.nth_value(i.get(arg));
let dict: &Py<PyDict> = unsafe {
// SAFETY: trust compiler
// SAFETY: invariants of CodeObject
obj.downcast_unchecked_ref()
};
dict.set_item(&*key, value, vm)?;
Expand Down
5 changes: 4 additions & 1 deletion vm/src/stdlib/marshal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,13 +378,16 @@ mod decl {
return Err(too_short_error(vm));
}
let (bytes, buf) = buf.split_at(len);
let code = bytecode::CodeObject::from_bytes(bytes).map_err(|e| match e {
let code = bytecode::UnsafeCodeObject::from_bytes(bytes).map_err(|e| match e {
bytecode::CodeDeserializeError::Eof => vm.new_exception_msg(
vm.ctx.exceptions.eof_error.to_owned(),
"End of file while deserializing bytecode".to_owned(),
),
_ => vm.new_value_error("Couldn't deserialize python bytecode".to_owned()),
})?;
// SAFETY: none, really 😬 but CPython trusts Python to give it invalid bytecode and do all
// kinds of other unsafe actions, so we'll just sorta follow suit
let code = unsafe { bytecode::CodeObject::new(code) };
(vm.ctx.new_code(code).into(), buf)
}
};
Expand Down