Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cspell.dict/python-more.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ aenter
aexit
aiter
anext
anextawaitable
appendleft
argcount
arrayiterator
Expand Down
16 changes: 2 additions & 14 deletions Lib/test/test_asyncgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,6 @@ async def async_gen_wrapper():

self.compare_generators(sync_gen_wrapper(), async_gen_wrapper())

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_async_gen_api_01(self):
async def gen():
yield 123
Expand Down Expand Up @@ -467,16 +465,12 @@ async def test_throw():
result = self.loop.run_until_complete(test_throw())
self.assertEqual(result, "completed")

# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
@unittest.expectedFailure
def test_async_generator_anext(self):
async def agen():
yield 1
yield 2
self.check_async_iterator_anext(agen)

# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
@unittest.expectedFailure
def test_python_async_iterator_anext(self):
class MyAsyncIter:
"""Asynchronously yield 1, then 2."""
Expand All @@ -492,8 +486,6 @@ async def __anext__(self):
return self.yielded
self.check_async_iterator_anext(MyAsyncIter)

# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
@unittest.expectedFailure
def test_python_async_iterator_types_coroutine_anext(self):
import types
class MyAsyncIterWithTypesCoro:
Expand Down Expand Up @@ -523,8 +515,6 @@ async def consume():
res = self.loop.run_until_complete(consume())
self.assertEqual(res, [1, 2])

# TODO: RUSTPYTHON, NameError: name 'aiter' is not defined
@unittest.expectedFailure
def test_async_gen_aiter_class(self):
results = []
class Gen:
Expand All @@ -549,8 +539,6 @@ async def gen():
applied_twice = aiter(applied_once)
self.assertIs(applied_once, applied_twice)

# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
@unittest.expectedFailure
def test_anext_bad_args(self):
async def gen():
yield 1
Expand All @@ -571,7 +559,7 @@ async def call_with_kwarg():
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_kwarg())

# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
# TODO: RUSTPYTHON, error message mismatch
@unittest.expectedFailure
def test_anext_bad_await(self):
async def bad_awaitable():
Expand Down Expand Up @@ -642,7 +630,7 @@ async def do_test():
result = self.loop.run_until_complete(do_test())
self.assertEqual(result, "completed")

# TODO: RUSTPYTHON, NameError: name 'anext' is not defined
# TODO: RUSTPYTHON, anext coroutine iteration issue
@unittest.expectedFailure
def test_anext_iter(self):
@types.coroutine
Expand Down
1 change: 0 additions & 1 deletion Lib/test/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def g3(): return (yield from f())

class GeneratorTest(unittest.TestCase):

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_name(self):
def func():
yield 1
Expand Down
2 changes: 0 additions & 2 deletions Lib/test/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2052,8 +2052,6 @@ async def corofunc():
else:
self.fail('StopIteration was expected')

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_gen(self):
def gen_func():
yield 1
Expand Down
157 changes: 155 additions & 2 deletions crates/vm/src/builtins/asyncgenerator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ impl PyAsyncGen {
&self.inner
}

pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
Self {
inner: Coro::new(frame, name),
inner: Coro::new(frame, name, qualname),
running_async: AtomicCell::new(false),
}
}
Expand All @@ -50,6 +50,16 @@ impl PyAsyncGen {
self.inner.set_name(name)
}

#[pygetset]
fn __qualname__(&self) -> PyStrRef {
self.inner.qualname()
}

#[pygetset(setter)]
fn set___qualname__(&self, qualname: PyStrRef) {
self.inner.set_qualname(qualname)
}

#[pygetset]
fn ag_await(&self, _vm: &VirtualMachine) -> Option<PyObjectRef> {
self.inner.frame().yield_from_target()
Expand Down Expand Up @@ -424,8 +434,151 @@ impl IterNext for PyAsyncGenAThrow {
}
}

/// Awaitable wrapper for anext() builtin with default value.
/// When StopAsyncIteration is raised, it converts it to StopIteration(default).
#[pyclass(module = false, name = "anext_awaitable")]
#[derive(Debug)]
pub struct PyAnextAwaitable {
wrapped: PyObjectRef,
default_value: PyObjectRef,
}

impl PyPayload for PyAnextAwaitable {
#[inline]
fn class(ctx: &Context) -> &'static Py<PyType> {
ctx.types.anext_awaitable
}
}

#[pyclass(with(IterNext, Iterable))]
impl PyAnextAwaitable {
pub fn new(wrapped: PyObjectRef, default_value: PyObjectRef) -> Self {
Self {
wrapped,
default_value,
}
}

#[pymethod(name = "__await__")]
fn r#await(zelf: PyRef<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
zelf
}

/// Get the awaitable iterator from wrapped object.
// = anextawaitable_getiter.
fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult {
use crate::builtins::PyCoroutine;
use crate::protocol::PyIter;

let wrapped = &self.wrapped;

// If wrapped is already an async_generator_asend, it's an iterator
if wrapped.class().is(vm.ctx.types.async_generator_asend)
|| wrapped.class().is(vm.ctx.types.async_generator_athrow)
{
return Ok(wrapped.clone());
}

// _PyCoro_GetAwaitableIter equivalent
let awaitable = if wrapped.class().is(vm.ctx.types.coroutine_type) {
// Coroutine - get __await__ later
wrapped.clone()
} else {
// Try to get __await__ method
if let Some(await_method) = vm.get_method(wrapped.clone(), identifier!(vm, __await__)) {
await_method?.call((), vm)?
} else {
return Err(vm.new_type_error(format!(
"object {} can't be used in 'await' expression",
wrapped.class().name()
)));
}
};

// If awaitable is a coroutine, get its __await__
if awaitable.class().is(vm.ctx.types.coroutine_type) {
let coro_await = vm.call_method(&awaitable, "__await__", ())?;
// Check that __await__ returned an iterator
if !PyIter::check(&coro_await) {
return Err(vm.new_type_error("__await__ returned a non-iterable"));
}
return Ok(coro_await);
}

// Check the result is an iterator, not a coroutine
if awaitable.downcast_ref::<PyCoroutine>().is_some() {
return Err(vm.new_type_error("__await__() returned a coroutine"));
}

// Check that the result is an iterator
if !PyIter::check(&awaitable) {
return Err(vm.new_type_error(format!(
"__await__() returned non-iterator of type '{}'",
awaitable.class().name()
)));
}

Ok(awaitable)
}

#[pymethod]
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let awaitable = self.get_awaitable_iter(vm)?;
let result = vm.call_method(&awaitable, "send", (val,));
self.handle_result(result, vm)
}

#[pymethod]
fn throw(
&self,
exc_type: PyObjectRef,
exc_val: OptionalArg,
exc_tb: OptionalArg,
vm: &VirtualMachine,
) -> PyResult {
let awaitable = self.get_awaitable_iter(vm)?;
let result = vm.call_method(
&awaitable,
"throw",
(
exc_type,
exc_val.unwrap_or_none(vm),
exc_tb.unwrap_or_none(vm),
),
);
self.handle_result(result, vm)
}
Comment on lines +524 to +550
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

PyAnextAwaitable assumes the await-iterator has send/throw; CPython allows plain iterators from __await__.

If RustPython’s await-driving machinery can handle plain iterators, PyAnextAwaitable.send() should fall back to advancing via __next__ (when val is None) if there’s no send, and throw() should propagate if there’s no throw—otherwise this wrapper will reject valid __await__ implementations.


#[pymethod]
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
if let Ok(awaitable) = self.get_awaitable_iter(vm) {
let _ = vm.call_method(&awaitable, "close", ());
}
Ok(())
}

/// Convert StopAsyncIteration to StopIteration(default_value)
fn handle_result(&self, result: PyResult, vm: &VirtualMachine) -> PyResult {
match result {
Ok(value) => Ok(value),
Err(exc) if exc.fast_isinstance(vm.ctx.exceptions.stop_async_iteration) => {
Err(vm.new_stop_iteration(Some(self.default_value.clone())))
}
Err(exc) => Err(exc),
}
}
}

impl SelfIter for PyAnextAwaitable {}
impl IterNext for PyAnextAwaitable {
fn next(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
PyIterReturn::from_pyresult(zelf.send(vm.ctx.none(), vm), vm)
}
}
Comment on lines +437 to +577
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

PyAnextAwaitable is likely functionally wrong: it recreates the underlying __await__ iterator each .send() / .throw() call.

That breaks iterator state (and can repeat side effects) because __await__() is expected to be called once, and the returned iterator is then driven until completion. Cache the iterator inside PyAnextAwaitable (and mark “closed” after completion) like other awaitables in this file.

Minimal direction (illustrative diff sketch; adjust imports/types as needed):

 pub struct PyAnextAwaitable {
     wrapped: PyObjectRef,
     default_value: PyObjectRef,
+    await_iter: crate::common::lock::PyMutex<Option<PyObjectRef>>,
+    state: AtomicCell<AwaitableState>,
 }

 impl PyAnextAwaitable {
     pub fn new(wrapped: PyObjectRef, default_value: PyObjectRef) -> Self {
         Self {
             wrapped,
             default_value,
+            await_iter: crate::common::lock::PyMutex::new(None),
+            state: AtomicCell::new(AwaitableState::Init),
         }
     }

-    fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult {
+    fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
+        if let AwaitableState::Closed = self.state.load() {
+            return Err(vm.new_runtime_error("cannot reuse already awaited anext()"));
+        }
+        if let Some(it) = self.await_iter.lock().clone() {
+            return Ok(it);
+        }
         ...
-        Ok(awaitable)
+        *self.await_iter.lock() = Some(awaitable.clone());
+        self.state.store(AwaitableState::Iter);
+        Ok(awaitable)
     }

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
crates/vm/src/builtins/asyncgenerator.rs around lines 437 to 577: the current
implementation calls __await__ (or constructs the awaitable iterator) on every
send/throw which recreates the iterator and breaks iterator state; change
PyAnextAwaitable to hold an Option<PyObjectRef> field (e.g., awaitable:
Option<PyObjectRef>) to cache the iterator, initialize that field the first time
get_awaitable_iter is called, return the cached iterator on subsequent calls,
and set it to None/mark closed when the iterator completes (StopAsyncIteration)
or when close() is called; update get_awaitable_iter, send, throw and close to
use and mutate this cached field appropriately, following the same ownership and
VM error handling patterns used by other awaitable types in the file.


pub fn init(ctx: &Context) {
PyAsyncGen::extend_class(ctx, ctx.types.async_generator);
PyAsyncGenASend::extend_class(ctx, ctx.types.async_generator_asend);
PyAsyncGenAThrow::extend_class(ctx, ctx.types.async_generator_athrow);
PyAnextAwaitable::extend_class(ctx, ctx.types.anext_awaitable);
}
19 changes: 17 additions & 2 deletions crates/vm/src/builtins/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ impl PyCoroutine {
&self.inner
}

pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
Self {
inner: Coro::new(frame, name),
inner: Coro::new(frame, name, qualname),
}
}

Expand All @@ -45,6 +45,16 @@ impl PyCoroutine {
self.inner.set_name(name)
}

#[pygetset]
fn __qualname__(&self) -> PyStrRef {
self.inner.qualname()
}

#[pygetset(setter)]
fn set___qualname__(&self, qualname: PyStrRef) {
self.inner.set_qualname(qualname)
}

#[pymethod(name = "__await__")]
const fn r#await(zelf: PyRef<Self>) -> PyCoroutineWrapper {
PyCoroutineWrapper { coro: zelf }
Expand Down Expand Up @@ -156,6 +166,11 @@ impl PyCoroutineWrapper {
) -> PyResult<PyIterReturn> {
self.coro.throw(exc_type, exc_val, exc_tb, vm)
}

#[pymethod]
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
self.coro.close(vm)
}
}

impl SelfIter for PyCoroutineWrapper {}
Expand Down
12 changes: 9 additions & 3 deletions crates/vm/src/builtins/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,9 +425,15 @@ impl Py<PyFunction> {
let is_gen = code.flags.contains(bytecode::CodeFlags::IS_GENERATOR);
let is_coro = code.flags.contains(bytecode::CodeFlags::IS_COROUTINE);
match (is_gen, is_coro) {
(true, false) => Ok(PyGenerator::new(frame, self.__name__()).into_pyobject(vm)),
(false, true) => Ok(PyCoroutine::new(frame, self.__name__()).into_pyobject(vm)),
(true, true) => Ok(PyAsyncGen::new(frame, self.__name__()).into_pyobject(vm)),
(true, false) => {
Ok(PyGenerator::new(frame, self.__name__(), self.__qualname__()).into_pyobject(vm))
}
(false, true) => {
Ok(PyCoroutine::new(frame, self.__name__(), self.__qualname__()).into_pyobject(vm))
}
(true, true) => {
Ok(PyAsyncGen::new(frame, self.__name__(), self.__qualname__()).into_pyobject(vm))
}
(false, false) => vm.run_frame(frame),
}
}
Expand Down
14 changes: 12 additions & 2 deletions crates/vm/src/builtins/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ impl PyGenerator {
&self.inner
}

pub fn new(frame: FrameRef, name: PyStrRef) -> Self {
pub fn new(frame: FrameRef, name: PyStrRef, qualname: PyStrRef) -> Self {
Self {
inner: Coro::new(frame, name),
inner: Coro::new(frame, name, qualname),
}
}

Expand All @@ -48,6 +48,16 @@ impl PyGenerator {
self.inner.set_name(name)
}

#[pygetset]
fn __qualname__(&self) -> PyStrRef {
self.inner.qualname()
}

#[pygetset(setter)]
fn set___qualname__(&self, qualname: PyStrRef) {
self.inner.set_qualname(qualname)
}

#[pygetset]
fn gi_frame(&self, _vm: &VirtualMachine) -> FrameRef {
self.inner.frame()
Expand Down
Loading
Loading