Skip to content

Commit dec9942

Browse files
Add iter support to c-api (#8035)
1 parent d82250b commit dec9942

2 files changed

Lines changed: 139 additions & 0 deletions

File tree

crates/capi/src/abstract_.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
use crate::{PyObject, pystate::with_vm};
22
use alloc::slice;
33
use core::ffi::c_int;
4+
pub use iter::*;
45
pub use mapping::*;
56
pub use number::*;
67
use rustpython_vm::builtins::{PyDict, PyStr, PyTuple};
78
use rustpython_vm::function::{FuncArgs, KwArgs, PosArgs};
89
use rustpython_vm::{AsObject, Py, PyObjectRef, PyResult, VirtualMachine};
910
pub use sequence::*;
1011

12+
mod iter;
1113
mod mapping;
1214
mod number;
1315
mod sequence;
@@ -170,3 +172,11 @@ pub unsafe extern "C" fn PyObject_IsInstance(inst: *mut PyObject, cls: *mut PyOb
170172
inst.is_instance(cls, vm)
171173
})
172174
}
175+
176+
#[unsafe(no_mangle)]
177+
pub unsafe extern "C" fn PyObject_Size(obj: *mut PyObject) -> isize {
178+
with_vm(|vm| {
179+
let obj = unsafe { &*obj };
180+
obj.length(vm)
181+
})
182+
}

crates/capi/src/abstract_/iter.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
use crate::{PyObject, pystate::with_vm};
2+
use core::ffi::c_int;
3+
use rustpython_vm::PyObjectRef;
4+
use rustpython_vm::builtins::PyGenerator;
5+
use rustpython_vm::protocol::{PyIter, PyIterReturn};
6+
7+
#[unsafe(no_mangle)]
8+
pub unsafe extern "C" fn PyIter_Check(obj: *mut PyObject) -> c_int {
9+
with_vm(|_vm| Ok(PyIter::check(unsafe { &*obj })))
10+
}
11+
12+
#[unsafe(no_mangle)]
13+
pub unsafe extern "C" fn PyObject_GetIter(obj: *mut PyObject) -> *mut PyObject {
14+
with_vm(|vm| {
15+
let obj = unsafe { &*obj };
16+
obj.get_iter(vm).map(PyObjectRef::from)
17+
})
18+
}
19+
20+
#[unsafe(no_mangle)]
21+
pub unsafe extern "C" fn PyIter_NextItem(iter: *mut PyObject, item: *mut *mut PyObject) -> c_int {
22+
with_vm(|vm| {
23+
unsafe {
24+
*item = core::ptr::null_mut();
25+
}
26+
27+
let iter = PyIter::new(unsafe { &*iter });
28+
match iter.next(vm)? {
29+
PyIterReturn::Return(next_item) => {
30+
unsafe {
31+
*item = next_item.into_raw().as_ptr();
32+
};
33+
Ok(true)
34+
}
35+
PyIterReturn::StopIteration(_) => Ok(false),
36+
}
37+
})
38+
}
39+
40+
#[unsafe(no_mangle)]
41+
pub unsafe extern "C" fn PyIter_Next(iter: *mut PyObject) -> *mut PyObject {
42+
with_vm(|vm| {
43+
let iter = PyIter::new(unsafe { &*iter });
44+
match iter.next(vm)? {
45+
PyIterReturn::Return(next_item) => Ok(next_item.into_raw().as_ptr()),
46+
PyIterReturn::StopIteration(_) => Ok(core::ptr::null_mut()),
47+
}
48+
})
49+
}
50+
51+
#[unsafe(no_mangle)]
52+
pub unsafe extern "C" fn PyIter_Send(
53+
iter: *mut PyObject,
54+
arg: *mut PyObject,
55+
presult: *mut *mut PyObject,
56+
) -> c_int {
57+
with_vm(|vm| {
58+
unsafe {
59+
*presult = core::ptr::null_mut();
60+
}
61+
62+
let iter_obj = unsafe { &*iter };
63+
let arg_obj = unsafe { &*arg };
64+
65+
let ret = if vm.is_none(arg_obj) {
66+
PyIter::new(iter_obj).next(vm)?
67+
} else {
68+
iter_obj
69+
.try_downcast_ref::<PyGenerator>(vm)?
70+
.as_coro()
71+
.send(iter_obj, arg_obj.to_owned(), vm)?
72+
};
73+
74+
match ret {
75+
PyIterReturn::Return(next_item) => {
76+
unsafe {
77+
*presult = next_item.into_raw().as_ptr();
78+
};
79+
Ok(true)
80+
}
81+
PyIterReturn::StopIteration(ret_val) => {
82+
let ret_val = ret_val.unwrap_or_else(|| vm.ctx.none());
83+
unsafe {
84+
*presult = ret_val.into_raw().as_ptr();
85+
};
86+
Ok(false)
87+
}
88+
}
89+
})
90+
}
91+
92+
#[cfg(false)]
93+
mod tests {
94+
use pyo3::prelude::*;
95+
use pyo3::types::{PyAnyMethods, PyIterator, PyList, PySendResult};
96+
97+
#[test]
98+
fn next_item() {
99+
Python::attach(|py| {
100+
let list = PyList::new(py, [1, 2, 3]).unwrap();
101+
let iter = list.try_iter().unwrap();
102+
let items: Vec<i32> = iter.map(|x| x.unwrap().extract::<i32>().unwrap()).collect();
103+
assert_eq!(items, vec![1, 2, 3]);
104+
})
105+
}
106+
107+
#[test]
108+
fn send_generator() {
109+
Python::attach(|py| {
110+
let generator = py
111+
.eval(c"(x for x in (1, 2))", None, None)
112+
.unwrap()
113+
.cast_into::<PyIterator>()
114+
.unwrap();
115+
116+
let first = generator.send(py.None().bind(py)).unwrap();
117+
assert!(matches!(
118+
first,
119+
PySendResult::Next(value) if value.extract::<i32>().unwrap() == 1
120+
));
121+
122+
let second = generator.send(py.None().bind(py)).unwrap();
123+
assert!(matches!(
124+
second,
125+
PySendResult::Next(value) if value.extract::<i32>().unwrap() == 2
126+
));
127+
})
128+
}
129+
}

0 commit comments

Comments
 (0)