Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stdlib/src/bisect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ mod _bisect {
while lo < hi {
// Handles issue 13496.
let mid = (lo + hi) / 2;
if vm.bool_cmp(&a.get_item(mid, vm)?, &x, Lt)? {
if a.get_item(mid, vm)?.rich_compare_bool(&x, Lt, vm)? {
lo = mid + 1;
} else {
hi = mid;
Expand Down Expand Up @@ -105,7 +105,7 @@ mod _bisect {
while lo < hi {
// Handles issue 13496.
let mid = (lo + hi) / 2;
if vm.bool_cmp(&x, &a.get_item(mid, vm)?, Lt)? {
if x.rich_compare_bool(&a.get_item(mid, vm)?, Lt, vm)? {
hi = mid;
} else {
lo = mid + 1;
Expand Down
2 changes: 1 addition & 1 deletion vm/src/builtins/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ fn do_sort(
} else {
PyComparisonOp::Gt
};
let cmp = |a: &PyObjectRef, b: &PyObjectRef| vm.bool_cmp(a, b, op);
let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm);

if let Some(ref key_func) = key_func {
let mut items = values
Expand Down
16 changes: 10 additions & 6 deletions vm/src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1772,12 +1772,16 @@ impl ExecutingFrame<'_> {
let b = self.pop_value();
let a = self.pop_value();
let value = match *op {
bytecode::ComparisonOperator::Equal => vm.obj_cmp(a, b, PyComparisonOp::Eq)?,
bytecode::ComparisonOperator::NotEqual => vm.obj_cmp(a, b, PyComparisonOp::Ne)?,
bytecode::ComparisonOperator::Less => vm.obj_cmp(a, b, PyComparisonOp::Lt)?,
bytecode::ComparisonOperator::LessOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Le)?,
bytecode::ComparisonOperator::Greater => vm.obj_cmp(a, b, PyComparisonOp::Gt)?,
bytecode::ComparisonOperator::GreaterOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Ge)?,
bytecode::ComparisonOperator::Equal => a.rich_compare(b, PyComparisonOp::Eq, vm)?,
bytecode::ComparisonOperator::NotEqual => a.rich_compare(b, PyComparisonOp::Ne, vm)?,
bytecode::ComparisonOperator::Less => a.rich_compare(b, PyComparisonOp::Lt, vm)?,
bytecode::ComparisonOperator::LessOrEqual => {
a.rich_compare(b, PyComparisonOp::Le, vm)?
}
bytecode::ComparisonOperator::Greater => a.rich_compare(b, PyComparisonOp::Gt, vm)?,
bytecode::ComparisonOperator::GreaterOrEqual => {
a.rich_compare(b, PyComparisonOp::Ge, vm)?
}
bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)).into(),
bytecode::ComparisonOperator::IsNot => vm.ctx.new_bool(self._is_not(a, b)).into(),
bytecode::ComparisonOperator::In => vm.ctx.new_bool(self._in(vm, a, b)?).into(),
Expand Down
65 changes: 61 additions & 4 deletions vm/src/protocol/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ use crate::{
builtins::{pystr::IntoPyStrRef, PyBytes, PyInt, PyStrRef},
bytesinner::ByteInnerNewOptions,
common::{hash::PyHash, str::to_ascii},
function::OptionalArg,
function::{IntoPyObject, OptionalArg},
protocol::PyIter,
pyref_type_error,
types::{Constructor, PyComparisonOp},
PyObjectRef, PyResult, TryFromObject, TypeProtocol, VirtualMachine,
utils::Either,
IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, TryFromObject, TypeProtocol,
VirtualMachine,
};

// RustPython doesn't need these items
Expand Down Expand Up @@ -78,11 +80,63 @@ impl PyObjectRef {
self.call_set_attr(vm, attr_name, None)
}

// Perform a comparison, raising TypeError when the requested comparison
// operator is not supported.
// see: CPython PyObject_RichCompare
fn _cmp(
&self,
other: &Self,
op: PyComparisonOp,
vm: &VirtualMachine,
) -> PyResult<Either<PyObjectRef, bool>> {
let swapped = op.swapped();
let call_cmp = |obj: &PyObjectRef, other, op| {
let cmp = obj
.class()
.mro_find_map(|cls| cls.slots.richcompare.load())
.unwrap();
Ok(match cmp(obj, other, op, vm)? {
Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A),
Either::B(arithmetic) => arithmetic.map(Either::B),
})
};

let mut checked_reverse_op = false;
let is_strict_subclass = {
let self_class = self.class();
let other_class = other.class();
!self_class.is(&other_class) && other_class.issubclass(&self_class)
};
if is_strict_subclass {
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
checked_reverse_op = true;
if let PyArithmeticValue::Implemented(x) = res {
return Ok(x);
}
}
if let PyArithmeticValue::Implemented(x) =
vm.with_recursion("in comparison", || call_cmp(self, other, op))?
{
return Ok(x);
}
if !checked_reverse_op {
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
if let PyArithmeticValue::Implemented(x) = res {
return Ok(x);
}
}
match op {
PyComparisonOp::Eq => Ok(Either::B(self.is(&other))),
PyComparisonOp::Ne => Ok(Either::B(!self.is(&other))),
_ => Err(vm.new_unsupported_binop_error(self, other, op.operator_token())),
}
}

// PyObject *PyObject_GenericGetDict(PyObject *o, void *context)
// int PyObject_GenericSetDict(PyObject *o, PyObject *value, void *context)

pub fn rich_compare(self, other: Self, opid: PyComparisonOp, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(self, other, opid)
self._cmp(&other, opid, vm).map(|res| res.into_pyobject(vm))
}

pub fn rich_compare_bool(
Expand All @@ -91,7 +145,10 @@ impl PyObjectRef {
opid: PyComparisonOp,
vm: &VirtualMachine,
) -> PyResult<bool> {
vm.bool_cmp(self, other, opid)
match self._cmp(other, opid, vm)? {
Either::A(obj) => obj.try_to_bool(vm),
Either::B(other) => Ok(other),
}
}

pub fn repr(&self, vm: &VirtualMachine) -> PyResult<PyStrRef> {
Expand Down
4 changes: 2 additions & 2 deletions vm/src/stdlib/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,14 +477,14 @@ mod builtins {
let mut x_key = vm.invoke(key_func, (x.clone(),))?;
for y in candidates_iter {
let y_key = vm.invoke(key_func, (y.clone(),))?;
if vm.bool_cmp(&y_key, &x_key, op)? {
if y_key.rich_compare_bool(&x_key, op, vm)? {
x = y;
x_key = y_key;
}
}
} else {
for y in candidates_iter {
if vm.bool_cmp(&y, &x, op)? {
if y.rich_compare_bool(&x, op, vm)? {
x = y;
}
}
Expand Down
2 changes: 1 addition & 1 deletion vm/src/stdlib/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2379,7 +2379,7 @@ mod _io {
}
};
use crate::types::PyComparisonOp;
if vm.bool_cmp(&cookie, &vm.ctx.new_int(0).into(), PyComparisonOp::Lt)? {
if cookie.rich_compare_bool(&vm.ctx.new_int(0).into(), PyComparisonOp::Lt, vm)? {
return Err(
vm.new_value_error(format!("negative seek position {}", vm.to_repr(&cookie)?))
);
Expand Down
12 changes: 6 additions & 6 deletions vm/src/stdlib/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,37 @@ mod _operator {
/// Same as a < b.
#[pyfunction]
fn lt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Lt)
a.rich_compare(b, Lt, vm)
}

/// Same as a <= b.
#[pyfunction]
fn le(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Le)
a.rich_compare(b, Le, vm)
}

/// Same as a > b.
#[pyfunction]
fn gt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Gt)
a.rich_compare(b, Gt, vm)
}

/// Same as a >= b.
#[pyfunction]
fn ge(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Ge)
a.rich_compare(b, Ge, vm)
}

/// Same as a == b.
#[pyfunction]
fn eq(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Eq)
a.rich_compare(b, Eq, vm)
}

/// Same as a != b.
#[pyfunction]
fn ne(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
vm.obj_cmp(a, b, Ne)
a.rich_compare(b, Ne, vm)
}

/// Same as not a.
Expand Down
70 changes: 3 additions & 67 deletions vm/src/vm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ use crate::{
signal::NSIG,
stdlib,
types::PyComparisonOp,
utils::Either,
IdProtocol, ItemProtocol, PyArithmeticValue, PyContext, PyLease, PyMethod, PyObject,
PyObjectRef, PyObjectWrap, PyRef, PyRefExact, PyResult, PyValue, TryFromObject, TypeProtocol,
};
Expand Down Expand Up @@ -1817,69 +1816,6 @@ impl VirtualMachine {
.invoke((), self)
}

// Perform a comparison, raising TypeError when the requested comparison
// operator is not supported.
// see: CPython PyObject_RichCompare
fn _cmp(
&self,
v: &PyObjectRef,
w: &PyObjectRef,
op: PyComparisonOp,
) -> PyResult<Either<PyObjectRef, bool>> {
let swapped = op.swapped();
let call_cmp = |obj: &PyObjectRef, other, op| {
let cmp = obj
.class()
.mro_find_map(|cls| cls.slots.richcompare.load())
.unwrap();
Ok(match cmp(obj, other, op, self)? {
Either::A(obj) => PyArithmeticValue::from_object(self, obj).map(Either::A),
Either::B(arithmetic) => arithmetic.map(Either::B),
})
};

let mut checked_reverse_op = false;
let is_strict_subclass = {
let v_class = v.class();
let w_class = w.class();
!v_class.is(&w_class) && w_class.issubclass(&v_class)
};
if is_strict_subclass {
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
checked_reverse_op = true;
if let PyArithmeticValue::Implemented(x) = res {
return Ok(x);
}
}
if let PyArithmeticValue::Implemented(x) =
self.with_recursion("in comparison", || call_cmp(v, w, op))?
{
return Ok(x);
}
if !checked_reverse_op {
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
if let PyArithmeticValue::Implemented(x) = res {
return Ok(x);
}
}
match op {
PyComparisonOp::Eq => Ok(Either::B(v.is(&w))),
PyComparisonOp::Ne => Ok(Either::B(!v.is(&w))),
_ => Err(self.new_unsupported_binop_error(v, w, op.operator_token())),
}
}

pub fn bool_cmp(&self, a: &PyObjectRef, b: &PyObjectRef, op: PyComparisonOp) -> PyResult<bool> {
match self._cmp(a, b, op)? {
Either::A(obj) => obj.try_to_bool(self),
Either::B(b) => Ok(b),
}
}

pub fn obj_cmp(&self, a: PyObjectRef, b: PyObjectRef, op: PyComparisonOp) -> PyResult {
self._cmp(&a, &b, op).map(|res| res.into_pyobject(self))
}

pub fn obj_len_opt(&self, obj: &PyObjectRef) -> Option<PyResult<usize>> {
self.get_special_method(obj.clone(), "__len__")
.map(Result::ok)
Expand Down Expand Up @@ -2050,7 +1986,7 @@ impl VirtualMachine {
}

pub fn bool_eq(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
self.bool_cmp(a, b, PyComparisonOp::Eq)
a.rich_compare_bool(b, PyComparisonOp::Eq, self)
}

pub fn identical_or_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
Expand All @@ -2062,7 +1998,7 @@ impl VirtualMachine {
}

pub fn bool_seq_lt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
let value = if self.bool_cmp(a, b, PyComparisonOp::Lt)? {
let value = if a.rich_compare_bool(b, PyComparisonOp::Lt, self)? {
Some(true)
} else if !self.bool_eq(a, b)? {
Some(false)
Expand All @@ -2073,7 +2009,7 @@ impl VirtualMachine {
}

pub fn bool_seq_gt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
let value = if self.bool_cmp(a, b, PyComparisonOp::Gt)? {
let value = if a.rich_compare_bool(b, PyComparisonOp::Gt, self)? {
Some(true)
} else if !self.bool_eq(a, b)? {
Some(false)
Expand Down