Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 163 additions & 33 deletions crates/derive-impl/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,85 @@ pub(crate) fn impl_pyclass_impl(attr: PunctuatedNestedMeta, item: Item) -> Resul
Ok(tokens)
}

/// Validates that when a base class is specified, the struct has the base type as its first field.
/// This ensures proper memory layout for subclassing (required for #[repr(transparent)] to work correctly).
fn validate_base_field(item: &Item, base_path: &syn::Path) -> Result<()> {
let Item::Struct(item_struct) = item else {
// Only validate structs - enums with base are already an error elsewhere
return Ok(());
};

// Get the base type name for error messages
let base_name = base_path
.segments
.last()
.map(|s| s.ident.to_string())
.unwrap_or_else(|| quote!(#base_path).to_string());

match &item_struct.fields {
syn::Fields::Named(fields) => {
let Some(first_field) = fields.named.first() else {
bail_span!(
item_struct,
"#[pyclass] with base = {base_name} requires the first field to be of type {base_name}, but the struct has no fields"
);
};
if !type_matches_path(&first_field.ty, base_path) {
bail_span!(
first_field,
"#[pyclass] with base = {base_name} requires the first field to be of type {base_name}"
);
}
}
syn::Fields::Unnamed(fields) => {
let Some(first_field) = fields.unnamed.first() else {
bail_span!(
item_struct,
"#[pyclass] with base = {base_name} requires the first field to be of type {base_name}, but the struct has no fields"
);
};
if !type_matches_path(&first_field.ty, base_path) {
bail_span!(
first_field,
"#[pyclass] with base = {base_name} requires the first field to be of type {base_name}"
);
}
}
syn::Fields::Unit => {
bail_span!(
item_struct,
"#[pyclass] with base = {base_name} requires the first field to be of type {base_name}, but the struct is a unit struct"
);
}
}

Ok(())
}

/// Check if a type matches a given path (handles simple cases like `Foo` or `path::to::Foo`)
fn type_matches_path(ty: &syn::Type, path: &syn::Path) -> bool {
// Compare by converting both to string representation for macro hygiene
let ty_str = quote!(#ty).to_string().replace(' ', "");
let path_str = quote!(#path).to_string().replace(' ', "");

// Check if both are the same or if the type ends with the path's last segment
if ty_str == path_str {
return true;
}

// Also match if just the last segment matches (e.g., foo::Bar matches Bar)
let syn::Type::Path(type_path) = ty else {
return false;
};
let Some(type_last) = type_path.path.segments.last() else {
return false;
};
let Some(path_last) = path.segments.last() else {
return false;
};
type_last.ident == path_last.ident
}

fn generate_class_def(
ident: &Ident,
name: &str,
Expand Down Expand Up @@ -339,7 +418,6 @@ fn generate_class_def(
} else {
quote!(false)
};
let basicsize = quote!(std::mem::size_of::<#ident>());
let is_pystruct = attrs.iter().any(|attr| {
attr.path().is_ident("derive")
&& if let Ok(Meta::List(l)) = attr.parse_meta() {
Expand All @@ -350,6 +428,25 @@ fn generate_class_def(
false
}
});
// Check if the type has #[repr(transparent)] - only then we can safely
// generate PySubclass impl (requires same memory layout as base type)
let is_repr_transparent = attrs.iter().any(|attr| {
attr.path().is_ident("repr")
&& if let Ok(Meta::List(l)) = attr.parse_meta() {
l.nested
.into_iter()
.any(|n| n.get_ident().is_some_and(|p| p == "transparent"))
} else {
false
}
});
// If repr(transparent) with a base, the type has the same memory layout as base,
// so basicsize should be 0 (no additional space beyond the base type)
let basicsize = if is_repr_transparent && base.is_some() {
quote!(0)
} else {
quote!(std::mem::size_of::<#ident>())
};
if base.is_some() && is_pystruct {
bail_span!(ident, "PyStructSequence cannot have `base` class attr",);
}
Expand Down Expand Up @@ -379,12 +476,31 @@ fn generate_class_def(
}
});

let base_or_object = if let Some(base) = base {
let base_or_object = if let Some(ref base) = base {
quote! { #base }
} else {
quote! { ::rustpython_vm::builtins::PyBaseObject }
};

// Generate PySubclass impl for #[repr(transparent)] types with base class
// (tuple struct assumed, so &self.0 works)
let subclass_impl = if !is_pystruct && is_repr_transparent {
base.as_ref().map(|typ| {
quote! {
impl ::rustpython_vm::class::PySubclass for #ident {
type Base = #typ;

#[inline]
fn as_base(&self) -> &Self::Base {
&self.0
}
}
}
})
} else {
None
};

Comment on lines +485 to +503
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

PySubclass::as_base() assumes tuple structs (self.0) but validate_base_field also allows named-field structs.
For #[repr(transparent)] struct S { base: Base } this will generate invalid code. Either (a) restrict repr(transparent)+base to tuple structs, or (b) generate as_base() using the actual first field (named vs unnamed) discovered during validation.

🤖 Prompt for AI Agents
In crates/derive-impl/src/pyclass.rs around lines 485-503, the generated
PySubclass::as_base() unconditionally uses tuple struct access (&self.0) which
is invalid for named-field repr(transparent) structs; update the codegen to use
the actual first field determined during validation: have the validator record
whether the single field is named or unnamed and its identifier or index, then
emit as_base() using &self.<field_ident> for named fields and &self.0 for
unnamed tuple fields (alternatively, you may tighten validation to only allow
tuple structs and emit a clear compile error for named fields).

let tokens = quote! {
impl ::rustpython_vm::class::PyClassDef for #ident {
const NAME: &'static str = #name;
Expand All @@ -409,6 +525,8 @@ fn generate_class_def(

#base_class
}

#subclass_impl
};
Ok(tokens)
}
Expand All @@ -426,11 +544,16 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result<Tok
let metaclass = class_meta.metaclass()?;
let unhashable = class_meta.unhashable()?;

// Validate that if base is specified, the first field must be of the base type
if let Some(ref base_path) = base {
validate_base_field(&item, base_path)?;
}

let class_def = generate_class_def(
ident,
&class_name,
module_name.as_deref(),
base,
base.clone(),
metaclass,
unhashable,
attrs,
Expand Down Expand Up @@ -485,19 +608,47 @@ pub(crate) fn impl_pyclass(attr: PunctuatedNestedMeta, item: Item) -> Result<Tok
}
};

let impl_payload = if let Some(ctx_type_name) = class_meta.ctx_name()? {
let ctx_type_ident = Ident::new(&ctx_type_name, ident.span()); // FIXME span
// Generate PyPayload impl based on whether base exists
#[allow(clippy::collapsible_else_if)]
let impl_payload = if let Some(base_type) = &base {
let class_fn = if let Some(ctx_type_name) = class_meta.ctx_name()? {
let ctx_type_ident = Ident::new(&ctx_type_name, ident.span());
quote! { ctx.types.#ctx_type_ident }
} else {
quote! { <Self as ::rustpython_vm::class::StaticType>::static_type() }
};

// We need this to make extend mechanism work:
quote! {
// static_assertions::const_assert!(std::mem::size_of::<#base_type>() <= std::mem::size_of::<#ident>());
impl ::rustpython_vm::PyPayload for #ident {
#[inline]
fn payload_type_id() -> ::std::any::TypeId {
<#base_type as ::rustpython_vm::PyPayload>::payload_type_id()
}

#[inline]
fn validate_downcastable_from(obj: &::rustpython_vm::PyObject) -> bool {
<Self as ::rustpython_vm::class::PyClassDef>::BASICSIZE <= obj.class().slots.basicsize && obj.class().fast_issubclass(<Self as ::rustpython_vm::class::StaticType>::static_type())
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I found another way to do with more tricks. Let's check it with OSError

}

fn class(ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> {
ctx.types.#ctx_type_ident
#class_fn
}
}
}
} else {
quote! {}
if let Some(ctx_type_name) = class_meta.ctx_name()? {
let ctx_type_ident = Ident::new(&ctx_type_name, ident.span());
quote! {
impl ::rustpython_vm::PyPayload for #ident {
fn class(ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> {
ctx.types.#ctx_type_ident
}
}
}
} else {
quote! {}
}
};

let empty_impl = if let Some(attrs) = class_meta.impl_attrs()? {
Expand Down Expand Up @@ -536,26 +687,6 @@ pub(crate) fn impl_pyexception(attr: PunctuatedNestedMeta, item: Item) -> Result
let class_name = class_meta.class_name()?;

let base_class_name = class_meta.base()?;
let impl_payload = if let Some(ctx_type_name) = class_meta.ctx_name()? {
let ctx_type_ident = Ident::new(&ctx_type_name, ident.span()); // FIXME span

// We need this to make extend mechanism work:
quote! {
impl ::rustpython_vm::PyPayload for #ident {
fn class(ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> {
ctx.exceptions.#ctx_type_ident
}
}
}
} else {
quote! {
impl ::rustpython_vm::PyPayload for #ident {
fn class(_ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> {
<Self as ::rustpython_vm::class::StaticType>::static_type()
}
}
}
};
let impl_pyclass = if class_meta.has_impl()? {
quote! {
#[pyexception]
Expand All @@ -568,7 +699,6 @@ pub(crate) fn impl_pyexception(attr: PunctuatedNestedMeta, item: Item) -> Result
let ret = quote! {
#[pyclass(module = false, name = #class_name, base = #base_class_name)]
#item
#impl_payload
#impl_pyclass
};
Ok(ret)
Expand All @@ -585,7 +715,8 @@ pub(crate) fn impl_pyexception_impl(attr: PunctuatedNestedMeta, item: Item) -> R
let mut extra_attrs = Vec::new();
for nested in &attr {
if let NestedMeta::Meta(Meta::List(MetaList { path, nested, .. })) = nested {
if path.is_ident("with") {
// If we already found the constructor trait, no need to keep looking for it
if !has_slot_new && path.is_ident("with") {
// Check if Constructor is in the list
for meta in nested {
if let NestedMeta::Meta(Meta::Path(p)) = meta
Expand Down Expand Up @@ -1078,9 +1209,8 @@ impl GetSetNursery {
item_ident: Ident,
) -> Result<()> {
assert!(!self.validated, "new item is not allowed after validation");
if !matches!(kind, GetSetItemKind::Get) && !cfgs.is_empty() {
bail_span!(item_ident, "Only the getter can have #[cfg]",);
}
// Note: Both getter and setter can have #[cfg], but they must have matching cfgs
// since the map key is (name, cfgs). This ensures getter and setter are paired correctly.
let entry = self.map.entry((name.clone(), cfgs)).or_default();
let func = match kind {
GetSetItemKind::Get => &mut entry.0,
Expand Down
39 changes: 34 additions & 5 deletions crates/derive-impl/src/pystructseq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,8 +446,10 @@ pub(crate) fn impl_pystruct_sequence(
};

let output = quote! {
// The Python type struct (user-defined, possibly empty)
#pytype_vis struct #pytype_ident;
// The Python type struct - newtype wrapping PyTuple
#[derive(Debug)]
#[repr(transparent)]
#pytype_vis struct #pytype_ident(pub ::rustpython_vm::builtins::PyTuple);

// PyClassDef for Python type
impl ::rustpython_vm::class::PyClassDef for #pytype_ident {
Expand Down Expand Up @@ -476,10 +478,37 @@ pub(crate) fn impl_pystruct_sequence(
}
}

// MaybeTraverse (empty - no GC fields in empty struct)
// Subtype uses base type's payload_type_id
impl ::rustpython_vm::PyPayload for #pytype_ident {
#[inline]
fn payload_type_id() -> ::std::any::TypeId {
<::rustpython_vm::builtins::PyTuple as ::rustpython_vm::PyPayload>::payload_type_id()
}

#[inline]
fn validate_downcastable_from(obj: &::rustpython_vm::PyObject) -> bool {
obj.class().fast_issubclass(<Self as ::rustpython_vm::class::StaticType>::static_type())
}

fn class(_ctx: &::rustpython_vm::vm::Context) -> &'static ::rustpython_vm::Py<::rustpython_vm::builtins::PyType> {
<Self as ::rustpython_vm::class::StaticType>::static_type()
}
}

// MaybeTraverse - delegate to inner PyTuple
impl ::rustpython_vm::object::MaybeTraverse for #pytype_ident {
fn try_traverse(&self, _traverse_fn: &mut ::rustpython_vm::object::TraverseFn<'_>) {
// Empty struct has no fields to traverse
fn try_traverse(&self, traverse_fn: &mut ::rustpython_vm::object::TraverseFn<'_>) {
self.0.try_traverse(traverse_fn)
}
}

// PySubclass for proper inheritance
impl ::rustpython_vm::class::PySubclass for #pytype_ident {
type Base = ::rustpython_vm::builtins::PyTuple;

#[inline]
fn as_base(&self) -> &Self::Base {
&self.0
}
}

Expand Down
22 changes: 12 additions & 10 deletions crates/stdlib/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ mod _socket {
use crate::common::lock::{PyMappedRwLockReadGuard, PyRwLock, PyRwLockReadGuard};
use crate::vm::{
AsObject, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyBaseExceptionRef, PyListRef, PyStrRef, PyTupleRef, PyTypeRef},
builtins::{PyBaseExceptionRef, PyListRef, PyOSError, PyStrRef, PyTupleRef, PyTypeRef},
common::os::ErrorExt,
convert::{IntoPyException, ToPyObject, TryFromBorrowedObject, TryFromObject},
function::{ArgBytesLike, ArgMemoryBuffer, Either, FsPath, OptionalArg, OptionalOption},
Expand Down Expand Up @@ -1826,6 +1826,11 @@ mod _socket {
Self::Py(exc)
}
}
impl From<PyRef<PyOSError>> for IoOrPyException {
fn from(exc: PyRef<PyOSError>) -> Self {
Self::Py(exc.upcast())
}
}
impl From<io::Error> for IoOrPyException {
fn from(err: io::Error) -> Self {
Self::Io(err)
Expand All @@ -1844,7 +1849,7 @@ mod _socket {
#[inline]
fn into_pyexception(self, vm: &VirtualMachine) -> PyBaseExceptionRef {
match self {
Self::Timeout => timeout_error(vm),
Self::Timeout => timeout_error(vm).upcast(),
Self::Py(exc) => exc,
Self::Io(err) => err.into_pyexception(vm),
}
Expand Down Expand Up @@ -2412,18 +2417,15 @@ mod _socket {
SocketError::GaiError => gaierror(vm),
SocketError::HError => herror(vm),
};
vm.new_exception(
exception_cls,
vec![vm.new_pyobj(err.error_num()), vm.ctx.new_str(strerr).into()],
)
.into()
vm.new_os_subtype_error(exception_cls, Some(err.error_num()), strerr)
.into()
}

fn timeout_error(vm: &VirtualMachine) -> PyBaseExceptionRef {
fn timeout_error(vm: &VirtualMachine) -> PyRef<PyOSError> {
timeout_error_msg(vm, "timed out".to_owned())
}
pub(crate) fn timeout_error_msg(vm: &VirtualMachine, msg: String) -> PyBaseExceptionRef {
vm.new_exception_msg(timeout(vm), msg)
pub(crate) fn timeout_error_msg(vm: &VirtualMachine, msg: String) -> PyRef<PyOSError> {
vm.new_os_subtype_error(timeout(vm), None, msg)
}

fn get_ipv6_addr_str(ipv6: Ipv6Addr) -> String {
Expand Down
Loading
Loading