Skip to content

Commit 79abbf0

Browse files
committed
fix ctypes
1 parent 4bf0bac commit 79abbf0

8 files changed

Lines changed: 973 additions & 473 deletions

File tree

crates/vm/src/stdlib/ctypes.rs

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ pub(crate) fn make_module(vm: &VirtualMachine) -> PyRef<PyModule> {
9595
pointer::PyCPointerType::make_class(ctx);
9696
structure::PyCStructType::make_class(ctx);
9797
union::PyCUnionType::make_class(ctx);
98+
function::PyCFuncPtrType::make_class(ctx);
9899
extend_module!(vm, &module, {
99100
"_CData" => PyCData::make_class(ctx),
100101
"_SimpleCData" => PyCSimple::make_class(ctx),
@@ -385,12 +386,8 @@ pub(crate) mod _ctypes {
385386
#[pyattr]
386387
const RTLD_GLOBAL: i32 = 0;
387388

388-
#[cfg(target_os = "windows")]
389-
#[pyattr]
390-
const SIZEOF_TIME_T: usize = 8;
391-
#[cfg(not(target_os = "windows"))]
392389
#[pyattr]
393-
const SIZEOF_TIME_T: usize = 4;
390+
const SIZEOF_TIME_T: usize = std::mem::size_of::<libc::time_t>();
394391

395392
#[pyattr]
396393
const CTYPES_MAX_ARGCOUNT: usize = 1024;
@@ -578,30 +575,42 @@ pub(crate) mod _ctypes {
578575
#[pyfunction(name = "dlopen")]
579576
fn load_library_unix(
580577
name: Option<crate::function::FsPath>,
581-
_load_flags: OptionalArg<i32>,
578+
load_flags: OptionalArg<i32>,
582579
vm: &VirtualMachine,
583580
) -> PyResult<usize> {
584-
// TODO: audit functions first
585-
// TODO: load_flags
581+
// Default mode: RTLD_NOW | RTLD_LOCAL, always force RTLD_NOW
582+
let mode = load_flags.unwrap_or(libc::RTLD_NOW | libc::RTLD_LOCAL) | libc::RTLD_NOW;
583+
586584
match name {
587585
Some(name) => {
588586
let cache = library::libcache();
589587
let mut cache_write = cache.write();
590588
let os_str = name.as_os_str(vm)?;
591-
let (id, _) = cache_write.get_or_insert_lib(&*os_str, vm).map_err(|e| {
592-
// Include filename in error message for better diagnostics
593-
let name_str = os_str.to_string_lossy();
594-
vm.new_os_error(format!("{}: {}", name_str, e))
595-
})?;
589+
let (id, _) = cache_write
590+
.get_or_insert_lib_with_mode(&*os_str, mode, vm)
591+
.map_err(|e| {
592+
let name_str = os_str.to_string_lossy();
593+
vm.new_os_error(format!("{}: {}", name_str, e))
594+
})?;
596595
Ok(id)
597596
}
598597
None => {
599-
// If None, call libc::dlopen(null, mode) to get the current process handle
600-
let handle = unsafe { libc::dlopen(std::ptr::null(), libc::RTLD_NOW) };
598+
// dlopen(NULL, mode) to get the current process handle (for pythonapi)
599+
let handle = unsafe { libc::dlopen(std::ptr::null(), mode) };
601600
if handle.is_null() {
602-
return Err(vm.new_os_error("dlopen() error"));
601+
let err = unsafe { libc::dlerror() };
602+
let msg = if err.is_null() {
603+
"dlopen() error".to_string()
604+
} else {
605+
unsafe { std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() }
606+
};
607+
return Err(vm.new_os_error(msg));
603608
}
604-
Ok(handle as usize)
609+
// Add to library cache so symbol lookup works
610+
let cache = library::libcache();
611+
let mut cache_write = cache.write();
612+
let id = cache_write.insert_raw_handle(handle);
613+
Ok(id)
605614
}
606615
}
607616
}
@@ -614,6 +623,48 @@ pub(crate) mod _ctypes {
614623
Ok(())
615624
}
616625

626+
#[cfg(not(windows))]
627+
#[pyfunction]
628+
fn dlclose(handle: usize, _vm: &VirtualMachine) -> PyResult<()> {
629+
// Remove from cache, which triggers SharedLibrary drop.
630+
// libloading::Library calls dlclose automatically on Drop.
631+
let cache = library::libcache();
632+
let mut cache_write = cache.write();
633+
cache_write.drop_lib(handle);
634+
Ok(())
635+
}
636+
637+
#[cfg(not(windows))]
638+
#[pyfunction]
639+
fn dlsym(
640+
handle: usize,
641+
name: crate::builtins::PyStrRef,
642+
vm: &VirtualMachine,
643+
) -> PyResult<usize> {
644+
let symbol_name = std::ffi::CString::new(name.as_str())
645+
.map_err(|_| vm.new_value_error("symbol name contains null byte"))?;
646+
647+
// Clear previous error
648+
unsafe { libc::dlerror() };
649+
650+
let ptr = unsafe { libc::dlsym(handle as *mut libc::c_void, symbol_name.as_ptr()) };
651+
652+
// Check for error via dlerror first
653+
let err = unsafe { libc::dlerror() };
654+
if !err.is_null() {
655+
let msg = unsafe { std::ffi::CStr::from_ptr(err).to_string_lossy().into_owned() };
656+
return Err(vm.new_os_error(msg));
657+
}
658+
659+
// Treat NULL symbol address as error
660+
// This handles cases like GNU IFUNCs that resolve to NULL
661+
if ptr.is_null() {
662+
return Err(vm.new_os_error(format!("symbol '{}' not found", name.as_str())));
663+
}
664+
665+
Ok(ptr as usize)
666+
}
667+
617668
#[pyfunction(name = "POINTER")]
618669
fn create_pointer_type(cls: PyObjectRef, vm: &VirtualMachine) -> PyResult {
619670
use crate::builtins::PyStr;
@@ -905,25 +956,24 @@ pub(crate) mod _ctypes {
905956

906957
#[pyfunction]
907958
fn get_errno() -> i32 {
908-
errno::errno().0
959+
super::function::get_errno_value()
909960
}
910961

911962
#[pyfunction]
912-
fn set_errno(value: i32) {
913-
errno::set_errno(errno::Errno(value));
963+
fn set_errno(value: i32) -> i32 {
964+
super::function::set_errno_value(value)
914965
}
915966

916967
#[cfg(windows)]
917968
#[pyfunction]
918969
fn get_last_error() -> PyResult<u32> {
919-
Ok(unsafe { windows_sys::Win32::Foundation::GetLastError() })
970+
Ok(super::function::get_last_error_value())
920971
}
921972

922973
#[cfg(windows)]
923974
#[pyfunction]
924-
fn set_last_error(value: u32) -> PyResult<()> {
925-
unsafe { windows_sys::Win32::Foundation::SetLastError(value) };
926-
Ok(())
975+
fn set_last_error(value: u32) -> u32 {
976+
super::function::set_last_error_value(value)
927977
}
928978

929979
#[pyattr]
@@ -1084,9 +1134,9 @@ pub(crate) mod _ctypes {
10841134
ffi_args.push(Arg::new(val));
10851135
}
10861136

1087-
let cif = Cif::new(arg_types, Type::isize());
1137+
let cif = Cif::new(arg_types, Type::c_int());
10881138
let code_ptr = CodePtr::from_ptr(func_addr as *const _);
1089-
let result: isize = unsafe { cif.call(code_ptr, &ffi_args) };
1139+
let result: libc::c_int = unsafe { cif.call(code_ptr, &ffi_args) };
10901140
Ok(vm.ctx.new_int(result).into())
10911141
}
10921142

crates/vm/src/stdlib/ctypes/array.rs

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
11
use super::StgInfo;
22
use super::base::{CDATA_BUFFER_METHODS, PyCData};
3+
use super::type_info;
34
use crate::{
45
AsObject, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromObject, VirtualMachine,
56
atomic_func,
6-
builtins::{PyBytes, PyInt, PyList, PySlice, PyStr, PyType, PyTypeRef},
7+
builtins::{
8+
PyBytes, PyInt, PyList, PySlice, PyStr, PyType, PyTypeRef, genericalias::PyGenericAlias,
9+
},
710
class::StaticType,
811
function::{ArgBytesLike, FuncArgs, PySetterValue},
912
protocol::{BufferDescriptor, PyBuffer, PyNumberMethods, PySequenceMethods},
1013
types::{AsBuffer, AsNumber, AsSequence, Constructor, Initializer},
1114
};
1215
use num_traits::{Signed, ToPrimitive};
1316

17+
/// Get itemsize from a PEP 3118 format string
18+
/// Extracts the type code (last char after endianness prefix) and returns its size
19+
fn get_size_from_format(fmt: &str) -> usize {
20+
// Format is like "<f", ">q", etc. - strip endianness prefix and get type code
21+
let code = fmt
22+
.trim_start_matches(['<', '>', '@', '=', '!', '&'])
23+
.chars()
24+
.next()
25+
.map(|c| c.to_string());
26+
code.map(|c| type_info(&c).map(|t| t.size).unwrap_or(1))
27+
.unwrap_or(1)
28+
}
29+
1430
/// Creates array type for (element_type, length)
1531
/// Uses _array_type_cache to ensure identical calls return the same type object
1632
pub(super) fn array_type_from_ctype(
@@ -444,6 +460,11 @@ impl AsSequence for PyCArray {
444460
with(Constructor, AsSequence, AsBuffer)
445461
)]
446462
impl PyCArray {
463+
#[pyclassmethod]
464+
fn __class_getitem__(cls: PyTypeRef, args: PyObjectRef, vm: &VirtualMachine) -> PyGenericAlias {
465+
PyGenericAlias::from_args(cls, args, vm)
466+
}
467+
447468
fn int_to_bytes(i: &malachite_bigint::BigInt, size: usize) -> Vec<u8> {
448469
// Try unsigned first (handles values like 0xFFFFFFFF that overflow signed)
449470
// then fall back to signed (handles negative values)
@@ -1056,27 +1077,38 @@ impl AsBuffer for PyCArray {
10561077
.expect("PyCArray type must have StgInfo");
10571078
let format = stg_info.format.clone();
10581079
let shape = stg_info.shape.clone();
1059-
let element_size = stg_info.element_size;
10601080

10611081
let desc = if let Some(fmt) = format
10621082
&& !shape.is_empty()
10631083
{
1084+
// itemsize is the size of the base element type (item_info->size)
1085+
// For empty arrays, we still need the element size, not 0
1086+
let total_elements: usize = shape.iter().product();
1087+
let has_zero_dim = shape.contains(&0);
1088+
let itemsize = if total_elements > 0 && buffer_len > 0 {
1089+
buffer_len / total_elements
1090+
} else {
1091+
// For empty arrays, get itemsize from format type code
1092+
get_size_from_format(&fmt)
1093+
};
1094+
10641095
// Build dim_desc from shape (C-contiguous: row-major order)
10651096
// stride[i] = product(shape[i+1:]) * itemsize
1097+
// For empty arrays (any dimension is 0), all strides are 0
10661098
let mut dim_desc = Vec::with_capacity(shape.len());
1067-
let mut stride = element_size as isize;
1099+
let mut stride = itemsize as isize;
10681100

1069-
// Calculate strides from innermost to outermost dimension
10701101
for &dim_size in shape.iter().rev() {
1071-
dim_desc.push((dim_size, stride, 0));
1102+
let current_stride = if has_zero_dim { 0 } else { stride };
1103+
dim_desc.push((dim_size, current_stride, 0));
10721104
stride *= dim_size as isize;
10731105
}
10741106
dim_desc.reverse();
10751107

10761108
BufferDescriptor {
10771109
len: buffer_len,
10781110
readonly: false,
1079-
itemsize: element_size,
1111+
itemsize,
10801112
format: std::borrow::Cow::Owned(fmt),
10811113
dim_desc,
10821114
}

crates/vm/src/stdlib/ctypes/base.rs

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -300,15 +300,27 @@ pub(super) fn get_field_format(
300300
big_endian: bool,
301301
vm: &VirtualMachine,
302302
) -> String {
303+
let endian_prefix = if big_endian { ">" } else { "<" };
304+
303305
// 1. Check StgInfo for format
304306
if let Some(type_obj) = field_type.downcast_ref::<PyType>()
305307
&& let Some(stg_info) = type_obj.stg_info_opt()
306308
&& let Some(fmt) = &stg_info.format
307309
{
308-
// Handle endian prefix for simple types
309-
if fmt.len() == 1 {
310-
let endian_prefix = if big_endian { ">" } else { "<" };
311-
return format!("{}{}", endian_prefix, fmt);
310+
// For structures (T{...}), arrays ((n)...), and pointers (&...), return as-is
311+
// These complex types have their own endianness markers inside
312+
if fmt.starts_with('T')
313+
|| fmt.starts_with('(')
314+
|| fmt.starts_with('&')
315+
|| fmt.starts_with("X{")
316+
{
317+
return fmt.clone();
318+
}
319+
320+
// For simple types, replace existing endian prefix with the correct one
321+
let base_fmt = fmt.trim_start_matches(['<', '>', '@', '=', '!']);
322+
if !base_fmt.is_empty() {
323+
return format!("{}{}", endian_prefix, base_fmt);
312324
}
313325
return fmt.clone();
314326
}
@@ -318,8 +330,7 @@ pub(super) fn get_field_format(
318330
&& let Some(type_str) = type_attr.downcast_ref::<PyStr>()
319331
{
320332
let s = type_str.as_str();
321-
if s.len() == 1 {
322-
let endian_prefix = if big_endian { ">" } else { "<" };
333+
if !s.is_empty() {
323334
return format!("{}{}", endian_prefix, s);
324335
}
325336
return s.to_string();
@@ -1168,29 +1179,30 @@ impl PyCData {
11681179
.ok_or_else(|| vm.new_value_error("Invalid library handle"))?
11691180
};
11701181

1171-
// Get symbol address using platform-specific API
1172-
let symbol_name = std::ffi::CString::new(name.as_str())
1173-
.map_err(|_| vm.new_value_error("Invalid symbol name"))?;
1174-
1175-
#[cfg(windows)]
1176-
let ptr: *const u8 = unsafe {
1177-
match windows_sys::Win32::System::LibraryLoader::GetProcAddress(
1178-
handle as windows_sys::Win32::Foundation::HMODULE,
1179-
symbol_name.as_ptr() as *const u8,
1180-
) {
1181-
Some(p) => p as *const u8,
1182-
None => std::ptr::null(),
1182+
// Look up the library in the cache and use lib.get() for symbol lookup
1183+
let library_cache = super::library::libcache().read();
1184+
let library = library_cache
1185+
.get_lib(handle)
1186+
.ok_or_else(|| vm.new_value_error("Library not found"))?;
1187+
let inner_lib = library.lib.lock();
1188+
1189+
let symbol_name_with_nul = format!("{}\0", name.as_str());
1190+
let ptr: *const u8 = if let Some(lib) = &*inner_lib {
1191+
unsafe {
1192+
lib.get::<*const u8>(symbol_name_with_nul.as_bytes())
1193+
.map(|sym| *sym)
1194+
.map_err(|_| {
1195+
vm.new_value_error(format!("symbol '{}' not found", name.as_str()))
1196+
})?
11831197
}
1198+
} else {
1199+
return Err(vm.new_value_error("Library closed"));
11841200
};
11851201

1186-
#[cfg(not(windows))]
1187-
let ptr: *const u8 =
1188-
unsafe { libc::dlsym(handle as *mut libc::c_void, symbol_name.as_ptr()) as *const u8 };
1189-
1202+
// dlsym can return NULL for symbols that resolve to NULL (e.g., GNU IFUNC)
1203+
// Treat NULL addresses as errors
11901204
if ptr.is_null() {
1191-
return Err(
1192-
vm.new_value_error(format!("symbol '{}' not found in library", name.as_str()))
1193-
);
1205+
return Err(vm.new_value_error(format!("symbol '{}' not found", name.as_str())));
11941206
}
11951207

11961208
// PyCData_AtAddress
@@ -1593,7 +1605,7 @@ impl PyCField {
15931605
/// PyCField_set
15941606
#[pyslot]
15951607
fn descr_set(
1596-
zelf: &crate::PyObject,
1608+
zelf: &PyObject,
15971609
obj: PyObjectRef,
15981610
value: PySetterValue<PyObjectRef>,
15991611
vm: &VirtualMachine,
@@ -1804,7 +1816,7 @@ pub enum FfiArgValue {
18041816
F64(f64),
18051817
Pointer(usize),
18061818
/// Pointer with owned data. The PyObjectRef keeps the pointed data alive.
1807-
OwnedPointer(usize, #[allow(dead_code)] crate::PyObjectRef),
1819+
OwnedPointer(usize, #[allow(dead_code)] PyObjectRef),
18081820
}
18091821

18101822
impl FfiArgValue {
@@ -2145,6 +2157,16 @@ pub(super) fn read_ptr_from_buffer(buffer: &[u8]) -> usize {
21452157
}
21462158
}
21472159

2160+
/// Check if a type is a "simple instance" (direct subclass of a simple type)
2161+
/// Returns TRUE for c_int, c_void_p, etc. (simple types with _type_ attribute)
2162+
/// Returns FALSE for Structure, Array, POINTER(T), etc.
2163+
pub(super) fn is_simple_instance(typ: &Py<PyType>) -> bool {
2164+
// _ctypes_simple_instance
2165+
// Check if the type's metaclass is PyCSimpleType
2166+
let metaclass = typ.class();
2167+
metaclass.fast_issubclass(super::simple::PyCSimpleType::static_type())
2168+
}
2169+
21482170
/// Set or initialize StgInfo on a type
21492171
pub(super) fn set_or_init_stginfo(type_ref: &PyType, stg_info: StgInfo) {
21502172
if type_ref.init_type_data(stg_info.clone()).is_err()

0 commit comments

Comments
 (0)