Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions Lib/test/test_asyncgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,6 @@ async def call_with_kwarg():
with self.assertRaises(TypeError):
self.loop.run_until_complete(call_with_kwarg())

# TODO: RUSTPYTHON, error message mismatch
@unittest.expectedFailure
def test_anext_bad_await(self):
async def bad_awaitable():
class BadAwaitable:
Expand Down Expand Up @@ -630,8 +628,6 @@ async def do_test():
result = self.loop.run_until_complete(do_test())
self.assertEqual(result, "completed")

# TODO: RUSTPYTHON, anext coroutine iteration issue
@unittest.expectedFailure
def test_anext_iter(self):
@types.coroutine
def _async_yield(v):
Expand Down Expand Up @@ -1489,8 +1485,6 @@ async def main():

self.assertEqual(messages, [])

# TODO: RUSTPYTHON, ValueError: not enough values to unpack (expected 1, got 0)
@unittest.expectedFailure
def test_async_gen_asyncio_shutdown_exception_01(self):
messages = []

Expand Down
91 changes: 80 additions & 11 deletions crates/vm/src/builtins/asyncgenerator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
AsObject, Context, Py, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::PyBaseExceptionRef,
class::PyClassImpl,
common::lock::PyMutex,
coroutine::Coro,
frame::FrameRef,
function::OptionalArg,
Expand All @@ -17,6 +18,10 @@ use crossbeam_utils::atomic::AtomicCell;
pub struct PyAsyncGen {
inner: Coro,
running_async: AtomicCell<bool>,
// whether hooks have been initialized
ag_hooks_inited: AtomicCell<bool>,
// ag_origin_or_finalizer - stores the finalizer callback
ag_finalizer: PyMutex<Option<PyObjectRef>>,
}
type PyAsyncGenRef = PyRef<PyAsyncGen>;

Expand All @@ -37,6 +42,48 @@ impl PyAsyncGen {
Self {
inner: Coro::new(frame, name, qualname),
running_async: AtomicCell::new(false),
ag_hooks_inited: AtomicCell::new(false),
ag_finalizer: PyMutex::new(None),
}
}

/// Initialize async generator hooks.
/// Returns Ok(()) if successful, Err if firstiter hook raised an exception.
fn init_hooks(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<()> {
// = async_gen_init_hooks
if zelf.ag_hooks_inited.load() {
return Ok(());
}

zelf.ag_hooks_inited.store(true);

// Get and store finalizer from thread-local storage
let finalizer = crate::vm::thread::ASYNC_GEN_FINALIZER.with_borrow(|f| f.as_ref().cloned());
if let Some(finalizer) = finalizer {
*zelf.ag_finalizer.lock() = Some(finalizer);
}

// Call firstiter hook
let firstiter = crate::vm::thread::ASYNC_GEN_FIRSTITER.with_borrow(|f| f.as_ref().cloned());
if let Some(firstiter) = firstiter {
let obj: PyObjectRef = zelf.to_owned().into();
firstiter.call((obj,), vm)?;
}

Ok(())
}

/// Call finalizer hook if set
#[allow(dead_code)]
fn call_finalizer(zelf: &Py<Self>, vm: &VirtualMachine) {
// = gen_dealloc
let finalizer = zelf.ag_finalizer.lock().clone();
if let Some(finalizer) = finalizer
&& !zelf.inner.closed.load()
{
// Call finalizer, ignore any errors (PyErr_WriteUnraisable)
let obj: PyObjectRef = zelf.to_owned().into();
let _ = finalizer.call((obj,), vm);
}
}

Expand Down Expand Up @@ -91,17 +138,23 @@ impl PyRef<PyAsyncGen> {
}

#[pymethod]
fn __anext__(self, vm: &VirtualMachine) -> PyAsyncGenASend {
Self::asend(self, vm.ctx.none(), vm)
fn __anext__(self, vm: &VirtualMachine) -> PyResult<PyAsyncGenASend> {
PyAsyncGen::init_hooks(&self, vm)?;
Ok(PyAsyncGenASend {
ag: self,
state: AtomicCell::new(AwaitableState::Init),
value: vm.ctx.none(),
})
}

#[pymethod]
const fn asend(self, value: PyObjectRef, _vm: &VirtualMachine) -> PyAsyncGenASend {
PyAsyncGenASend {
fn asend(self, value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyAsyncGenASend> {
PyAsyncGen::init_hooks(&self, vm)?;
Ok(PyAsyncGenASend {
ag: self,
state: AtomicCell::new(AwaitableState::Init),
value,
}
})
}

#[pymethod]
Expand All @@ -111,8 +164,9 @@ impl PyRef<PyAsyncGen> {
exc_val: OptionalArg,
exc_tb: OptionalArg,
vm: &VirtualMachine,
) -> PyAsyncGenAThrow {
PyAsyncGenAThrow {
) -> PyResult<PyAsyncGenAThrow> {
PyAsyncGen::init_hooks(&self, vm)?;
Ok(PyAsyncGenAThrow {
ag: self,
aclose: false,
state: AtomicCell::new(AwaitableState::Init),
Expand All @@ -121,12 +175,13 @@ impl PyRef<PyAsyncGen> {
exc_val.unwrap_or_none(vm),
exc_tb.unwrap_or_none(vm),
),
}
})
}

#[pymethod]
fn aclose(self, vm: &VirtualMachine) -> PyAsyncGenAThrow {
PyAsyncGenAThrow {
fn aclose(self, vm: &VirtualMachine) -> PyResult<PyAsyncGenAThrow> {
PyAsyncGen::init_hooks(&self, vm)?;
Ok(PyAsyncGenAThrow {
ag: self,
aclose: true,
state: AtomicCell::new(AwaitableState::Init),
Expand All @@ -135,7 +190,7 @@ impl PyRef<PyAsyncGen> {
vm.ctx.none(),
vm.ctx.none(),
),
}
})
}
}

Expand Down Expand Up @@ -441,6 +496,7 @@ impl IterNext for PyAsyncGenAThrow {
pub struct PyAnextAwaitable {
wrapped: PyObjectRef,
default_value: PyObjectRef,
state: AtomicCell<AwaitableState>,
}

impl PyPayload for PyAnextAwaitable {
Expand All @@ -456,6 +512,7 @@ impl PyAnextAwaitable {
Self {
wrapped,
default_value,
state: AtomicCell::new(AwaitableState::Init),
}
}

Expand All @@ -464,6 +521,13 @@ impl PyAnextAwaitable {
zelf
}

fn check_closed(&self, vm: &VirtualMachine) -> PyResult<()> {
if let AwaitableState::Closed = self.state.load() {
return Err(vm.new_runtime_error("cannot reuse already awaited __anext__()/asend()"));
}
Ok(())
}

/// Get the awaitable iterator from wrapped object.
// = anextawaitable_getiter.
fn get_awaitable_iter(&self, vm: &VirtualMachine) -> PyResult {
Expand Down Expand Up @@ -523,6 +587,8 @@ impl PyAnextAwaitable {

#[pymethod]
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult {
self.check_closed(vm)?;
self.state.store(AwaitableState::Iter);
let awaitable = self.get_awaitable_iter(vm)?;
let result = vm.call_method(&awaitable, "send", (val,));
self.handle_result(result, vm)
Expand All @@ -536,6 +602,8 @@ impl PyAnextAwaitable {
exc_tb: OptionalArg,
vm: &VirtualMachine,
) -> PyResult {
self.check_closed(vm)?;
self.state.store(AwaitableState::Iter);
let awaitable = self.get_awaitable_iter(vm)?;
let result = vm.call_method(
&awaitable,
Expand All @@ -551,6 +619,7 @@ impl PyAnextAwaitable {

#[pymethod]
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
self.state.store(AwaitableState::Closed);
if let Ok(awaitable) = self.get_awaitable_iter(vm) {
let _ = vm.call_method(&awaitable, "close", ());
}
Expand Down
33 changes: 29 additions & 4 deletions crates/vm/src/builtins/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::{
protocol::PyIterReturn,
types::{IterNext, Iterable, Representable, SelfIter, Unconstructible},
};
use crossbeam_utils::atomic::AtomicCell;

#[pyclass(module = false, name = "coroutine")]
#[derive(Debug)]
Expand Down Expand Up @@ -56,8 +57,11 @@ impl PyCoroutine {
}

#[pymethod(name = "__await__")]
const fn r#await(zelf: PyRef<Self>) -> PyCoroutineWrapper {
PyCoroutineWrapper { coro: zelf }
fn r#await(zelf: PyRef<Self>) -> PyCoroutineWrapper {
PyCoroutineWrapper {
coro: zelf,
closed: AtomicCell::new(false),
}
}

#[pygetset]
Expand Down Expand Up @@ -140,6 +144,7 @@ impl IterNext for PyCoroutine {
// PyCoroWrapper_Type in CPython
pub struct PyCoroutineWrapper {
coro: PyRef<PyCoroutine>,
closed: AtomicCell<bool>,
}

impl PyPayload for PyCoroutineWrapper {
Expand All @@ -151,9 +156,22 @@ impl PyPayload for PyCoroutineWrapper {

#[pyclass(with(IterNext, Iterable))]
impl PyCoroutineWrapper {
fn check_closed(&self, vm: &VirtualMachine) -> PyResult<()> {
if self.closed.load() {
return Err(vm.new_runtime_error("cannot reuse already awaited coroutine"));
}
Ok(())
}

#[pymethod]
fn send(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyIterReturn> {
self.coro.send(val, vm)
self.check_closed(vm)?;
let result = self.coro.send(val, vm);
// Mark as closed if exhausted
if let Ok(PyIterReturn::StopIteration(_)) = &result {
self.closed.store(true);
}
result
}
Comment on lines 166 to 175
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check if there's a pattern in the codebase for closing on errors
rg -n "closed\.store\(true\)" --type rust -C3

Repository: RustPython/RustPython

Length of output: 4621


🏁 Script executed:

#!/bin/bash
# Get the full coroutine.rs file to understand the structure
wc -l crates/vm/src/builtins/coroutine.rs

Repository: RustPython/RustPython

Length of output: 104


🏁 Script executed:

#!/bin/bash
# Find the Coro struct and understand its send method
rg -n "impl.*Coro" --type rust -A 5 crates/vm/src/builtins/

Repository: RustPython/RustPython

Length of output: 3678


🏁 Script executed:

#!/bin/bash
# Check if there are other methods that handle closing and errors
rg -n "fn.*(&self.*vm.*PyResult" crates/vm/src/builtins/coroutine.rs -A 10

Repository: RustPython/RustPython

Length of output: 144


🏁 Script executed:

#!/bin/bash
# Get the Coro struct and send method implementation
sed -n '40,120p' crates/vm/src/coroutine.rs | head -80

Repository: RustPython/RustPython

Length of output: 2533


🏁 Script executed:

#!/bin/bash
# Find the send method in the Coro struct
rg -n "fn send" crates/vm/src/coroutine.rs -A 15

Repository: RustPython/RustPython

Length of output: 647


🏁 Script executed:

#!/bin/bash
# See the full context of the PyCoroutineWrapper send and throw methods
sed -n '166,200p' crates/vm/src/builtins/coroutine.rs

Repository: RustPython/RustPython

Length of output: 1028


🏁 Script executed:

#!/bin/bash
# Look at the rest of Coro.send to see full error handling
sed -n '92,125p' crates/vm/src/coroutine.rs

Repository: RustPython/RustPython

Length of output: 1455


🏁 Script executed:

#!/bin/bash
# Verify throw method in Coro as well to see if it uses maybe_close
rg -n "pub fn throw" crates/vm/src/coroutine.rs -A 20

Repository: RustPython/RustPython

Length of output: 866


Mark wrapper as closed on errors, not just StopIteration.

The wrapper's closed flag only closes when StopIteration is returned. When send() returns an error, the wrapper remains "open", allowing another call that will unnecessarily pass through check_closed() before the underlying Coro (which already closed itself via maybe_close) returns StopIteration. Close the wrapper on errors to match the underlying Coro's error-handling pattern and align with CPython behavior where a coroutine cannot resume after raising an exception.

🤖 Prompt for AI Agents
In crates/vm/src/builtins/coroutine.rs around lines 166 to 175, the wrapper's
closed flag is only set when send() returns Ok(PyIterReturn::StopIteration(_)),
but it must also be marked closed when send() returns an Err so the wrapper
cannot be resumed after the underlying coroutine has raised; update the logic to
set self.closed.store(true) both for Ok(PyIterReturn::StopIteration(_)) and for
Err(_) (i.e., treat any error as terminal), preserving the original return value
so the error or StopIteration still propagates to the caller.


#[pymethod]
Expand All @@ -164,11 +182,18 @@ impl PyCoroutineWrapper {
exc_tb: OptionalArg,
vm: &VirtualMachine,
) -> PyResult<PyIterReturn> {
self.coro.throw(exc_type, exc_val, exc_tb, vm)
self.check_closed(vm)?;
let result = self.coro.throw(exc_type, exc_val, exc_tb, vm);
// Mark as closed if exhausted
if let Ok(PyIterReturn::StopIteration(_)) = &result {
self.closed.store(true);
}
result
}

#[pymethod]
fn close(&self, vm: &VirtualMachine) -> PyResult<()> {
self.closed.store(true);
self.coro.close(vm)
}
}
Expand Down
12 changes: 11 additions & 1 deletion crates/vm/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,8 @@ impl ExecutingFrame<'_> {
Ok(None)
}
bytecode::Instruction::GetAwaitable => {
use crate::protocol::PyIter;

let awaited_obj = self.pop_value();
let awaitable = if awaited_obj.downcastable::<PyCoroutine>() {
awaited_obj
Expand All @@ -932,7 +934,15 @@ impl ExecutingFrame<'_> {
)
},
)?;
await_method.call((), vm)?
let result = await_method.call((), vm)?;
// Check that __await__ returned an iterator
if !PyIter::check(&result) {
return Err(vm.new_type_error(format!(
"__await__() returned non-iterator of type '{}'",
result.class().name()
)));
}
result
};
self.push_value(awaitable);
Ok(None)
Expand Down
Loading