Skip to content

Commit 7eb600b

Browse files
committed
Relocate vm.rich_compare to obj.rich_compare
1 parent 72c3a70 commit 7eb600b

8 files changed

Lines changed: 86 additions & 97 deletions

File tree

stdlib/src/bisect.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ mod _bisect {
7777
while lo < hi {
7878
// Handles issue 13496.
7979
let mid = (lo + hi) / 2;
80-
if vm.bool_cmp(&a.get_item(mid, vm)?, &x, Lt)? {
80+
if a.get_item(mid, vm)?.rich_compare_bool(&x, Lt, vm)? {
8181
lo = mid + 1;
8282
} else {
8383
hi = mid;
@@ -105,7 +105,7 @@ mod _bisect {
105105
while lo < hi {
106106
// Handles issue 13496.
107107
let mid = (lo + hi) / 2;
108-
if vm.bool_cmp(&x, &a.get_item(mid, vm)?, Lt)? {
108+
if x.rich_compare_bool(&a.get_item(mid, vm)?, Lt, vm)? {
109109
hi = mid;
110110
} else {
111111
lo = mid + 1;

vm/src/builtins/list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ fn do_sort(
613613
} else {
614614
PyComparisonOp::Gt
615615
};
616-
let cmp = |a: &PyObjectRef, b: &PyObjectRef| vm.bool_cmp(a, b, op);
616+
let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm);
617617

618618
if let Some(ref key_func) = key_func {
619619
let mut items = values

vm/src/frame.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,12 +1772,16 @@ impl ExecutingFrame<'_> {
17721772
let b = self.pop_value();
17731773
let a = self.pop_value();
17741774
let value = match *op {
1775-
bytecode::ComparisonOperator::Equal => vm.obj_cmp(a, b, PyComparisonOp::Eq)?,
1776-
bytecode::ComparisonOperator::NotEqual => vm.obj_cmp(a, b, PyComparisonOp::Ne)?,
1777-
bytecode::ComparisonOperator::Less => vm.obj_cmp(a, b, PyComparisonOp::Lt)?,
1778-
bytecode::ComparisonOperator::LessOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Le)?,
1779-
bytecode::ComparisonOperator::Greater => vm.obj_cmp(a, b, PyComparisonOp::Gt)?,
1780-
bytecode::ComparisonOperator::GreaterOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Ge)?,
1775+
bytecode::ComparisonOperator::Equal => a.rich_compare(b, PyComparisonOp::Eq, vm)?,
1776+
bytecode::ComparisonOperator::NotEqual => a.rich_compare(b, PyComparisonOp::Ne, vm)?,
1777+
bytecode::ComparisonOperator::Less => a.rich_compare(b, PyComparisonOp::Lt, vm)?,
1778+
bytecode::ComparisonOperator::LessOrEqual => {
1779+
a.rich_compare(b, PyComparisonOp::Le, vm)?
1780+
}
1781+
bytecode::ComparisonOperator::Greater => a.rich_compare(b, PyComparisonOp::Gt, vm)?,
1782+
bytecode::ComparisonOperator::GreaterOrEqual => {
1783+
a.rich_compare(b, PyComparisonOp::Ge, vm)?
1784+
}
17811785
bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)).into(),
17821786
bytecode::ComparisonOperator::IsNot => vm.ctx.new_bool(self._is_not(a, b)).into(),
17831787
bytecode::ComparisonOperator::In => vm.ctx.new_bool(self._in(vm, a, b)?).into(),

vm/src/protocol/object.rs

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@ use crate::{
55
builtins::{pystr::IntoPyStrRef, PyBytes, PyInt, PyStrRef},
66
bytesinner::ByteInnerNewOptions,
77
common::{hash::PyHash, str::to_ascii},
8-
function::OptionalArg,
8+
function::{IntoPyObject, OptionalArg},
99
protocol::PyIter,
1010
pyref_type_error,
1111
types::{Constructor, PyComparisonOp},
12-
PyObjectRef, PyResult, TryFromObject, TypeProtocol, VirtualMachine,
12+
utils::Either,
13+
IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, TryFromObject, TypeProtocol,
14+
VirtualMachine,
1315
};
1416

1517
// RustPython doesn't need these items
@@ -78,11 +80,63 @@ impl PyObjectRef {
7880
self.call_set_attr(vm, attr_name, None)
7981
}
8082

83+
// Perform a comparison, raising TypeError when the requested comparison
84+
// operator is not supported.
85+
// see: CPython PyObject_RichCompare
86+
fn _cmp(
87+
&self,
88+
other: &Self,
89+
op: PyComparisonOp,
90+
vm: &VirtualMachine,
91+
) -> PyResult<Either<PyObjectRef, bool>> {
92+
let swapped = op.swapped();
93+
let call_cmp = |obj: &PyObjectRef, other, op| {
94+
let cmp = obj
95+
.class()
96+
.mro_find_map(|cls| cls.slots.richcompare.load())
97+
.unwrap();
98+
Ok(match cmp(obj, other, op, vm)? {
99+
Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A),
100+
Either::B(arithmetic) => arithmetic.map(Either::B),
101+
})
102+
};
103+
104+
let mut checked_reverse_op = false;
105+
let is_strict_subclass = {
106+
let self_class = self.class();
107+
let other_class = other.class();
108+
!self_class.is(&other_class) && other_class.issubclass(&self_class)
109+
};
110+
if is_strict_subclass {
111+
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
112+
checked_reverse_op = true;
113+
if let PyArithmeticValue::Implemented(x) = res {
114+
return Ok(x);
115+
}
116+
}
117+
if let PyArithmeticValue::Implemented(x) =
118+
vm.with_recursion("in comparison", || call_cmp(self, other, op))?
119+
{
120+
return Ok(x);
121+
}
122+
if !checked_reverse_op {
123+
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
124+
if let PyArithmeticValue::Implemented(x) = res {
125+
return Ok(x);
126+
}
127+
}
128+
match op {
129+
PyComparisonOp::Eq => Ok(Either::B(self.is(&other))),
130+
PyComparisonOp::Ne => Ok(Either::B(!self.is(&other))),
131+
_ => Err(vm.new_unsupported_binop_error(self, other, op.operator_token())),
132+
}
133+
}
134+
81135
// PyObject *PyObject_GenericGetDict(PyObject *o, void *context)
82136
// int PyObject_GenericSetDict(PyObject *o, PyObject *value, void *context)
83137

84138
pub fn rich_compare(self, other: Self, opid: PyComparisonOp, vm: &VirtualMachine) -> PyResult {
85-
vm.obj_cmp(self, other, opid)
139+
self._cmp(&other, opid, vm).map(|res| res.into_pyobject(vm))
86140
}
87141

88142
pub fn rich_compare_bool(
@@ -91,7 +145,10 @@ impl PyObjectRef {
91145
opid: PyComparisonOp,
92146
vm: &VirtualMachine,
93147
) -> PyResult<bool> {
94-
vm.bool_cmp(self, other, opid)
148+
match self._cmp(other, opid, vm)? {
149+
Either::A(obj) => obj.try_to_bool(vm),
150+
Either::B(other) => Ok(other),
151+
}
95152
}
96153

97154
pub fn repr(&self, vm: &VirtualMachine) -> PyResult<PyStrRef> {

vm/src/stdlib/builtins.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,14 +477,14 @@ mod builtins {
477477
let mut x_key = vm.invoke(key_func, (x.clone(),))?;
478478
for y in candidates_iter {
479479
let y_key = vm.invoke(key_func, (y.clone(),))?;
480-
if vm.bool_cmp(&y_key, &x_key, op)? {
480+
if y_key.rich_compare_bool(&x_key, op, vm)? {
481481
x = y;
482482
x_key = y_key;
483483
}
484484
}
485485
} else {
486486
for y in candidates_iter {
487-
if vm.bool_cmp(&y, &x, op)? {
487+
if y.rich_compare_bool(&x, op, vm)? {
488488
x = y;
489489
}
490490
}

vm/src/stdlib/io.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2379,7 +2379,7 @@ mod _io {
23792379
}
23802380
};
23812381
use crate::types::PyComparisonOp;
2382-
if vm.bool_cmp(&cookie, &vm.ctx.new_int(0).into(), PyComparisonOp::Lt)? {
2382+
if cookie.rich_compare_bool(&vm.ctx.new_int(0).into(), PyComparisonOp::Lt, vm)? {
23832383
return Err(
23842384
vm.new_value_error(format!("negative seek position {}", vm.to_repr(&cookie)?))
23852385
);

vm/src/stdlib/operator.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,37 +27,37 @@ mod _operator {
2727
/// Same as a < b.
2828
#[pyfunction]
2929
fn lt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
30-
vm.obj_cmp(a, b, Lt)
30+
a.rich_compare(b, Lt, vm)
3131
}
3232

3333
/// Same as a <= b.
3434
#[pyfunction]
3535
fn le(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
36-
vm.obj_cmp(a, b, Le)
36+
a.rich_compare(b, Le, vm)
3737
}
3838

3939
/// Same as a > b.
4040
#[pyfunction]
4141
fn gt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
42-
vm.obj_cmp(a, b, Gt)
42+
a.rich_compare(b, Gt, vm)
4343
}
4444

4545
/// Same as a >= b.
4646
#[pyfunction]
4747
fn ge(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
48-
vm.obj_cmp(a, b, Ge)
48+
a.rich_compare(b, Ge, vm)
4949
}
5050

5151
/// Same as a == b.
5252
#[pyfunction]
5353
fn eq(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
54-
vm.obj_cmp(a, b, Eq)
54+
a.rich_compare(b, Eq, vm)
5555
}
5656

5757
/// Same as a != b.
5858
#[pyfunction]
5959
fn ne(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
60-
vm.obj_cmp(a, b, Ne)
60+
a.rich_compare(b, Ne, vm)
6161
}
6262

6363
/// Same as not a.

vm/src/vm.rs

Lines changed: 3 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ use crate::{
2727
signal::NSIG,
2828
stdlib,
2929
types::PyComparisonOp,
30-
utils::Either,
3130
IdProtocol, ItemProtocol, PyArithmeticValue, PyContext, PyLease, PyMethod, PyObject,
3231
PyObjectRef, PyObjectWrap, PyRef, PyRefExact, PyResult, PyValue, TryFromObject, TypeProtocol,
3332
};
@@ -1817,77 +1816,6 @@ impl VirtualMachine {
18171816
.invoke((), self)
18181817
}
18191818

1820-
// Perform a comparison, raising TypeError when the requested comparison
1821-
// operator is not supported.
1822-
// see: CPython PyObject_RichCompare
1823-
fn _cmp(
1824-
&self,
1825-
v: &PyObjectRef,
1826-
w: &PyObjectRef,
1827-
op: PyComparisonOp,
1828-
) -> PyResult<Either<PyObjectRef, bool>> {
1829-
let swapped = op.swapped();
1830-
let call_cmp = |obj: &PyObjectRef, other, op| {
1831-
let cmp = obj
1832-
.class()
1833-
.mro_find_map(|cls| cls.slots.richcompare.load())
1834-
.unwrap();
1835-
Ok(match cmp(obj, other, op, self)? {
1836-
Either::A(obj) => PyArithmeticValue::from_object(self, obj).map(Either::A),
1837-
Either::B(arithmetic) => arithmetic.map(Either::B),
1838-
})
1839-
};
1840-
1841-
let mut checked_reverse_op = false;
1842-
let is_strict_subclass = {
1843-
let v_class = v.class();
1844-
let w_class = w.class();
1845-
!v_class.is(&w_class) && w_class.issubclass(&v_class)
1846-
};
1847-
if is_strict_subclass {
1848-
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
1849-
checked_reverse_op = true;
1850-
if let PyArithmeticValue::Implemented(x) = res {
1851-
return Ok(x);
1852-
}
1853-
}
1854-
if let PyArithmeticValue::Implemented(x) =
1855-
self.with_recursion("in comparison", || call_cmp(v, w, op))?
1856-
{
1857-
return Ok(x);
1858-
}
1859-
if !checked_reverse_op {
1860-
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
1861-
if let PyArithmeticValue::Implemented(x) = res {
1862-
return Ok(x);
1863-
}
1864-
}
1865-
match op {
1866-
PyComparisonOp::Eq => Ok(Either::B(v.is(&w))),
1867-
PyComparisonOp::Ne => Ok(Either::B(!v.is(&w))),
1868-
_ => Err(self.new_unsupported_binop_error(v, w, op.operator_token())),
1869-
}
1870-
}
1871-
1872-
pub fn bool_cmp(&self, a: &PyObjectRef, b: &PyObjectRef, op: PyComparisonOp) -> PyResult<bool> {
1873-
match self._cmp(a, b, op)? {
1874-
Either::A(obj) => obj.try_to_bool(self),
1875-
Either::B(b) => Ok(b),
1876-
}
1877-
}
1878-
1879-
pub fn obj_cmp(&self, a: PyObjectRef, b: PyObjectRef, op: PyComparisonOp) -> PyResult {
1880-
self._cmp(&a, &b, op).map(|res| res.into_pyobject(self))
1881-
}
1882-
1883-
pub fn _hash(&self, obj: &PyObjectRef) -> PyResult<rustpython_common::hash::PyHash> {
1884-
let hash = obj
1885-
.class()
1886-
.mro_find_map(|cls| cls.slots.hash.load())
1887-
.unwrap(); // hash always exist
1888-
hash(obj, self)
1889-
}
1890-
18911819
pub fn obj_len_opt(&self, obj: &PyObjectRef) -> Option<PyResult<usize>> {
18921820
self.get_special_method(obj.clone(), "__len__")
18931821
.map(Result::ok)
@@ -2067,7 +1995,7 @@ impl VirtualMachine {
20671995
}
20681996

20691997
pub fn bool_eq(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
2070-
self.bool_cmp(a, b, PyComparisonOp::Eq)
1998+
a.rich_compare_bool(b, PyComparisonOp::Eq, self)
20711999
}
20722000

20732001
pub fn identical_or_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
@@ -2079,7 +2007,7 @@ impl VirtualMachine {
20792007
}
20802008

20812009
pub fn bool_seq_lt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
2082-
let value = if self.bool_cmp(a, b, PyComparisonOp::Lt)? {
2010+
let value = if a.rich_compare_bool(b, PyComparisonOp::Lt, self)? {
20832011
Some(true)
20842012
} else if !self.bool_eq(a, b)? {
20852013
Some(false)
@@ -2090,7 +2018,7 @@ impl VirtualMachine {
20902018
}
20912019

20922020
pub fn bool_seq_gt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
2093-
let value = if self.bool_cmp(a, b, PyComparisonOp::Gt)? {
2021+
let value = if a.rich_compare_bool(b, PyComparisonOp::Gt, self)? {
20942022
Some(true)
20952023
} else if !self.bool_eq(a, b)? {
20962024
Some(false)

0 commit comments

Comments
 (0)