Skip to content

Commit baeca0a

Browse files
committed
fix PySequenceIterator to use sequence protocol
1 parent 0fd014f commit baeca0a

4 files changed

Lines changed: 47 additions & 15 deletions

File tree

vm/src/builtins/iter.rs

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
use super::{PyInt, PyTupleRef, PyTypeRef};
66
use crate::{
77
function::ArgCallable,
8-
protocol::PyIterReturn,
8+
protocol::{PyIterReturn, PySequence},
99
types::{IterNext, IterNextIterable},
10-
ItemProtocol, PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, VirtualMachine,
10+
PyClassImpl, PyContext, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, VirtualMachine,
1111
};
1212
use rustpython_common::{
1313
lock::{PyMutex, PyRwLock, PyRwLockUpgradableReadGuard},
@@ -160,7 +160,7 @@ pub fn builtins_reversed(vm: &VirtualMachine) -> &PyObjectRef {
160160
#[pyclass(module = false, name = "iterator")]
161161
#[derive(Debug)]
162162
pub struct PySequenceIterator {
163-
internal: PyMutex<PositionIterInternal<PyObjectRef>>,
163+
internal: PyMutex<PositionIterInternal<PySequence>>,
164164
}
165165

166166
impl PyValue for PySequenceIterator {
@@ -171,18 +171,21 @@ impl PyValue for PySequenceIterator {
171171

172172
#[pyimpl(with(IterNext))]
173173
impl PySequenceIterator {
174-
pub fn new(obj: PyObjectRef) -> Self {
175-
Self {
176-
internal: PyMutex::new(PositionIterInternal::new(obj, 0)),
177-
}
174+
pub fn new(obj: PyObjectRef, vm: &VirtualMachine) -> PyResult<Self> {
175+
Ok(Self {
176+
internal: PyMutex::new(PositionIterInternal::new(
177+
PySequence::try_from_object(vm, obj)?,
178+
0,
179+
)),
180+
})
178181
}
179182

180183
#[pymethod(magic)]
181184
fn length_hint(&self, vm: &VirtualMachine) -> PyObjectRef {
182185
let internal = self.internal.lock();
183186
if let IterStatus::Active(obj) = &internal.status {
184-
vm.obj_len(obj)
185-
.map(|x| PyInt::from(x).into_object(vm))
187+
obj.length(vm)
188+
.map(|x| vm.ctx.new_int(x).into())
186189
.unwrap_or_else(|_| vm.ctx.not_implemented())
187190
} else {
188191
PyInt::from(0).into_object(vm)
@@ -191,7 +194,9 @@ impl PySequenceIterator {
191194

192195
#[pymethod(magic)]
193196
fn reduce(&self, vm: &VirtualMachine) -> PyTupleRef {
194-
self.internal.lock().builtins_iter_reduce(|x| x.clone(), vm)
197+
self.internal
198+
.lock()
199+
.builtins_iter_reduce(|x| x.obj.clone(), vm)
195200
}
196201

197202
#[pymethod(magic)]
@@ -205,7 +210,7 @@ impl IterNext for PySequenceIterator {
205210
fn next(zelf: &PyRef<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
206211
zelf.internal
207212
.lock()
208-
.next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(pos, vm), vm))
213+
.next(|obj, pos| PyIterReturn::from_getitem_result(obj.get_item(pos as isize, vm), vm))
209214
}
210215
}
211216

vm/src/function/argument.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl<T> ArgIterable<T> {
6060
pub fn iter<'a>(&self, vm: &'a VirtualMachine) -> PyResult<PyIterIter<'a, T>> {
6161
let iter = PyIter::new(match self.iterfn {
6262
Some(f) => f(self.iterable.clone(), vm)?,
63-
None => PySequenceIterator::new(self.iterable.clone()).into_object(vm),
63+
None => PySequenceIterator::new(self.iterable.clone(), vm)?.into_object(vm),
6464
});
6565
iter.into_iter(vm)
6666
}

vm/src/protocol/iter.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ impl TryFromObject for PyIter<PyObjectRef> {
121121
)))
122122
}
123123
} else if PySequence::check(&iter_target, vm) {
124-
Ok(Self(PySequenceIterator::new(iter_target).into_object(vm)))
124+
Ok(Self(
125+
PySequenceIterator::new(iter_target, vm)?.into_object(vm),
126+
))
125127
} else {
126128
Err(vm.new_type_error(format!(
127129
"'{}' object is not iterable",

vm/src/protocol/sequence.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
use std::borrow::{Borrow, Cow};
2+
use std::fmt::Debug;
23

34
use itertools::Itertools;
45

56
use crate::{
67
builtins::{PyList, PySlice},
78
common::static_cell,
89
function::IntoPyObject,
9-
IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, PyValue, TypeProtocol, VirtualMachine,
10+
IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, PyValue, TryFromObject, TypeProtocol,
11+
VirtualMachine,
1012
};
1113

1214
// Sequence Protocol
@@ -35,8 +37,24 @@ impl PySequenceMethods {
3537
}
3638
}
3739

40+
impl Debug for PySequenceMethods {
41+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42+
f.debug_struct("PySequenceMethods")
43+
.field("length", &self.length.map(|x| x as usize))
44+
.field("concat", &self.concat.map(|x| x as usize))
45+
.field("repeat", &self.repeat.map(|x| x as usize))
46+
.field("item", &self.item.map(|x| x as usize))
47+
.field("ass_item", &self.ass_item.map(|x| x as usize))
48+
.field("contains", &self.contains.map(|x| x as usize))
49+
.field("inplace_concat", &self.inplace_concat.map(|x| x as usize))
50+
.field("inplace_repeat", &self.inplace_repeat.map(|x| x as usize))
51+
.finish()
52+
}
53+
}
54+
55+
#[derive(Debug)]
3856
pub struct PySequence {
39-
obj: PyObjectRef,
57+
pub obj: PyObjectRef,
4058
methods: Cow<'static, PySequenceMethods>,
4159
}
4260

@@ -291,6 +309,13 @@ impl PySequence {
291309
}
292310
}
293311

312+
impl TryFromObject for PySequence {
313+
fn try_from_object(vm: &VirtualMachine, obj: PyObjectRef) -> PyResult<Self> {
314+
PySequence::from_object(vm, obj)
315+
.ok_or_else(|| vm.new_type_error("'{}' is not a sequence".to_string()))
316+
}
317+
}
318+
294319
pub fn try_add_for_concat(a: &PyObjectRef, b: &PyObjectRef, vm: &VirtualMachine) -> PyResult {
295320
if PySequence::check(b, vm) {
296321
let ret = vm._add(a, b)?;

0 commit comments

Comments
 (0)