Skip to content

Commit f1be6e9

Browse files
committed
fix
1 parent 5953600 commit f1be6e9

2 files changed

Lines changed: 39 additions & 13 deletions

File tree

crates/vm/src/builtins/iter.rs

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
use super::{PyInt, PyTupleRef, PyType};
66
use crate::{
7-
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
7+
Context, Py, PyObject, PyObjectRef, PyPayload, PyResult, VirtualMachine,
88
class::PyClassImpl,
99
function::ArgCallable,
1010
object::{Traverse, TraverseFn},
@@ -199,24 +199,20 @@ impl PySequenceIterator {
199199
#[pymethod]
200200
fn __length_hint__(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
201201
vm.with_recursion("in __length_hint__", || {
202-
let obj = {
202+
let (obj, position) = {
203203
let internal = self.internal.lock();
204204
match &internal.status {
205-
IterStatus::Active(obj) => Some(obj.clone()),
206-
IterStatus::Exhausted => None,
205+
IterStatus::Active(obj) => (Some(obj.clone()), internal.position),
206+
IterStatus::Exhausted => (None, 0),
207207
}
208208
};
209209
if let Some(obj) = obj {
210210
let seq = obj.sequence_unchecked();
211-
match seq.length(vm) {
212-
Ok(x) => Ok(PyInt::from(x).into_pyobject(vm)),
213-
Err(err) => {
214-
if err.fast_isinstance(vm.ctx.exceptions.recursion_error) {
215-
Err(err)
216-
} else {
217-
Ok(vm.ctx.not_implemented())
218-
}
219-
}
211+
match seq.length_opt(vm) {
212+
Some(len) => len.map(|len| {
213+
PyInt::from(len.saturating_sub(position)).into_pyobject(vm)
214+
}),
215+
None => Ok(vm.ctx.not_implemented()),
220216
}
221217
} else {
222218
Ok(PyInt::from(0).into_pyobject(vm))

extra_tests/snippets/builtin_iter.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,33 @@ def run():
3939
assert not t.is_alive(), "iterator.__length_hint__ deadlocked"
4040
err = q.get_nowait()
4141
assert isinstance(err, RecursionError)
42+
43+
44+
class NoLen:
45+
def __getitem__(self, index):
46+
if index < 3:
47+
return index
48+
raise IndexError
49+
50+
51+
no_len_it = iter(NoLen())
52+
assert no_len_it.__length_hint__() is NotImplemented
53+
next(no_len_it)
54+
assert no_len_it.__length_hint__() is NotImplemented
55+
56+
57+
class Seq:
58+
def __init__(self):
59+
self.items = [1, 2, 3]
60+
61+
def __getitem__(self, index):
62+
return self.items[index]
63+
64+
def __len__(self):
65+
return len(self.items)
66+
67+
68+
seq_it = iter(Seq())
69+
assert seq_it.__length_hint__() == 3
70+
next(seq_it)
71+
assert seq_it.__length_hint__() == 2

0 commit comments

Comments
 (0)