Skip to content

Commit b7dab09

Browse files
authored
Merge pull request #3347 from AP2008/relocate-rich_compare
Relocate `vm.rich_compare` to `obj.rich_compare`
2 parents 497eead + c427462 commit b7dab09

8 files changed

Lines changed: 86 additions & 89 deletions

File tree

stdlib/src/bisect.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ mod _bisect {
7777
while lo < hi {
7878
// Handles issue 13496.
7979
let mid = (lo + hi) / 2;
80-
if vm.bool_cmp(&a.get_item(mid, vm)?, &x, Lt)? {
80+
if a.get_item(mid, vm)?.rich_compare_bool(&x, Lt, vm)? {
8181
lo = mid + 1;
8282
} else {
8383
hi = mid;
@@ -105,7 +105,7 @@ mod _bisect {
105105
while lo < hi {
106106
// Handles issue 13496.
107107
let mid = (lo + hi) / 2;
108-
if vm.bool_cmp(&x, &a.get_item(mid, vm)?, Lt)? {
108+
if x.rich_compare_bool(&a.get_item(mid, vm)?, Lt, vm)? {
109109
hi = mid;
110110
} else {
111111
lo = mid + 1;

vm/src/builtins/list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,7 @@ fn do_sort(
613613
} else {
614614
PyComparisonOp::Gt
615615
};
616-
let cmp = |a: &PyObjectRef, b: &PyObjectRef| vm.bool_cmp(a, b, op);
616+
let cmp = |a: &PyObjectRef, b: &PyObjectRef| a.rich_compare_bool(b, op, vm);
617617

618618
if let Some(ref key_func) = key_func {
619619
let mut items = values

vm/src/frame.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,12 +1772,16 @@ impl ExecutingFrame<'_> {
17721772
let b = self.pop_value();
17731773
let a = self.pop_value();
17741774
let value = match *op {
1775-
bytecode::ComparisonOperator::Equal => vm.obj_cmp(a, b, PyComparisonOp::Eq)?,
1776-
bytecode::ComparisonOperator::NotEqual => vm.obj_cmp(a, b, PyComparisonOp::Ne)?,
1777-
bytecode::ComparisonOperator::Less => vm.obj_cmp(a, b, PyComparisonOp::Lt)?,
1778-
bytecode::ComparisonOperator::LessOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Le)?,
1779-
bytecode::ComparisonOperator::Greater => vm.obj_cmp(a, b, PyComparisonOp::Gt)?,
1780-
bytecode::ComparisonOperator::GreaterOrEqual => vm.obj_cmp(a, b, PyComparisonOp::Ge)?,
1775+
bytecode::ComparisonOperator::Equal => a.rich_compare(b, PyComparisonOp::Eq, vm)?,
1776+
bytecode::ComparisonOperator::NotEqual => a.rich_compare(b, PyComparisonOp::Ne, vm)?,
1777+
bytecode::ComparisonOperator::Less => a.rich_compare(b, PyComparisonOp::Lt, vm)?,
1778+
bytecode::ComparisonOperator::LessOrEqual => {
1779+
a.rich_compare(b, PyComparisonOp::Le, vm)?
1780+
}
1781+
bytecode::ComparisonOperator::Greater => a.rich_compare(b, PyComparisonOp::Gt, vm)?,
1782+
bytecode::ComparisonOperator::GreaterOrEqual => {
1783+
a.rich_compare(b, PyComparisonOp::Ge, vm)?
1784+
}
17811785
bytecode::ComparisonOperator::Is => vm.ctx.new_bool(self._is(a, b)).into(),
17821786
bytecode::ComparisonOperator::IsNot => vm.ctx.new_bool(self._is_not(a, b)).into(),
17831787
bytecode::ComparisonOperator::In => vm.ctx.new_bool(self._in(vm, a, b)?).into(),

vm/src/protocol/object.rs

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@ use crate::{
55
builtins::{pystr::IntoPyStrRef, PyBytes, PyInt, PyStrRef, PyTupleRef},
66
bytesinner::ByteInnerNewOptions,
77
common::{hash::PyHash, str::to_ascii},
8-
function::OptionalArg,
8+
function::{IntoPyObject, OptionalArg},
99
protocol::PyIter,
1010
pyobject::IdProtocol,
1111
pyref_type_error,
1212
types::{Constructor, PyComparisonOp},
13-
PyObjectRef, PyResult, TryFromObject, TypeProtocol, VirtualMachine,
13+
utils::Either,
14+
IdProtocol, PyArithmeticValue, PyObjectRef, PyResult, TryFromObject, TypeProtocol,
15+
VirtualMachine,
1416
};
1517

1618
// RustPython doesn't need these items
@@ -79,11 +81,63 @@ impl PyObjectRef {
7981
self.call_set_attr(vm, attr_name, None)
8082
}
8183

84+
// Perform a comparison, raising TypeError when the requested comparison
85+
// operator is not supported.
86+
// see: CPython PyObject_RichCompare
87+
fn _cmp(
88+
&self,
89+
other: &Self,
90+
op: PyComparisonOp,
91+
vm: &VirtualMachine,
92+
) -> PyResult<Either<PyObjectRef, bool>> {
93+
let swapped = op.swapped();
94+
let call_cmp = |obj: &PyObjectRef, other, op| {
95+
let cmp = obj
96+
.class()
97+
.mro_find_map(|cls| cls.slots.richcompare.load())
98+
.unwrap();
99+
Ok(match cmp(obj, other, op, vm)? {
100+
Either::A(obj) => PyArithmeticValue::from_object(vm, obj).map(Either::A),
101+
Either::B(arithmetic) => arithmetic.map(Either::B),
102+
})
103+
};
104+
105+
let mut checked_reverse_op = false;
106+
let is_strict_subclass = {
107+
let self_class = self.class();
108+
let other_class = other.class();
109+
!self_class.is(&other_class) && other_class.issubclass(&self_class)
110+
};
111+
if is_strict_subclass {
112+
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
113+
checked_reverse_op = true;
114+
if let PyArithmeticValue::Implemented(x) = res {
115+
return Ok(x);
116+
}
117+
}
118+
if let PyArithmeticValue::Implemented(x) =
119+
vm.with_recursion("in comparison", || call_cmp(self, other, op))?
120+
{
121+
return Ok(x);
122+
}
123+
if !checked_reverse_op {
124+
let res = vm.with_recursion("in comparison", || call_cmp(other, self, swapped))?;
125+
if let PyArithmeticValue::Implemented(x) = res {
126+
return Ok(x);
127+
}
128+
}
129+
match op {
130+
PyComparisonOp::Eq => Ok(Either::B(self.is(&other))),
131+
PyComparisonOp::Ne => Ok(Either::B(!self.is(&other))),
132+
_ => Err(vm.new_unsupported_binop_error(self, other, op.operator_token())),
133+
}
134+
}
135+
82136
// PyObject *PyObject_GenericGetDict(PyObject *o, void *context)
83137
// int PyObject_GenericSetDict(PyObject *o, PyObject *value, void *context)
84138

85139
pub fn rich_compare(self, other: Self, opid: PyComparisonOp, vm: &VirtualMachine) -> PyResult {
86-
vm.obj_cmp(self, other, opid)
140+
self._cmp(&other, opid, vm).map(|res| res.into_pyobject(vm))
87141
}
88142

89143
pub fn rich_compare_bool(
@@ -92,7 +146,10 @@ impl PyObjectRef {
92146
opid: PyComparisonOp,
93147
vm: &VirtualMachine,
94148
) -> PyResult<bool> {
95-
vm.bool_cmp(self, other, opid)
149+
match self._cmp(other, opid, vm)? {
150+
Either::A(obj) => obj.try_to_bool(vm),
151+
Either::B(other) => Ok(other),
152+
}
96153
}
97154

98155
pub fn repr(&self, vm: &VirtualMachine) -> PyResult<PyStrRef> {

vm/src/stdlib/builtins.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,14 +477,14 @@ mod builtins {
477477
let mut x_key = vm.invoke(key_func, (x.clone(),))?;
478478
for y in candidates_iter {
479479
let y_key = vm.invoke(key_func, (y.clone(),))?;
480-
if vm.bool_cmp(&y_key, &x_key, op)? {
480+
if y_key.rich_compare_bool(&x_key, op, vm)? {
481481
x = y;
482482
x_key = y_key;
483483
}
484484
}
485485
} else {
486486
for y in candidates_iter {
487-
if vm.bool_cmp(&y, &x, op)? {
487+
if y.rich_compare_bool(&x, op, vm)? {
488488
x = y;
489489
}
490490
}

vm/src/stdlib/io.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2379,7 +2379,7 @@ mod _io {
23792379
}
23802380
};
23812381
use crate::types::PyComparisonOp;
2382-
if vm.bool_cmp(&cookie, &vm.ctx.new_int(0).into(), PyComparisonOp::Lt)? {
2382+
if cookie.rich_compare_bool(&vm.ctx.new_int(0).into(), PyComparisonOp::Lt, vm)? {
23832383
return Err(
23842384
vm.new_value_error(format!("negative seek position {}", vm.to_repr(&cookie)?))
23852385
);

vm/src/stdlib/operator.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,37 +27,37 @@ mod _operator {
2727
/// Same as a < b.
2828
#[pyfunction]
2929
fn lt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
30-
vm.obj_cmp(a, b, Lt)
30+
a.rich_compare(b, Lt, vm)
3131
}
3232

3333
/// Same as a <= b.
3434
#[pyfunction]
3535
fn le(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
36-
vm.obj_cmp(a, b, Le)
36+
a.rich_compare(b, Le, vm)
3737
}
3838

3939
/// Same as a > b.
4040
#[pyfunction]
4141
fn gt(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
42-
vm.obj_cmp(a, b, Gt)
42+
a.rich_compare(b, Gt, vm)
4343
}
4444

4545
/// Same as a >= b.
4646
#[pyfunction]
4747
fn ge(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
48-
vm.obj_cmp(a, b, Ge)
48+
a.rich_compare(b, Ge, vm)
4949
}
5050

5151
/// Same as a == b.
5252
#[pyfunction]
5353
fn eq(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
54-
vm.obj_cmp(a, b, Eq)
54+
a.rich_compare(b, Eq, vm)
5555
}
5656

5757
/// Same as a != b.
5858
#[pyfunction]
5959
fn ne(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> PyResult {
60-
vm.obj_cmp(a, b, Ne)
60+
a.rich_compare(b, Ne, vm)
6161
}
6262

6363
/// Same as not a.

vm/src/vm.rs

Lines changed: 3 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ use crate::{
2727
signal::NSIG,
2828
stdlib,
2929
types::PyComparisonOp,
30-
utils::Either,
3130
IdProtocol, ItemProtocol, PyArithmeticValue, PyContext, PyLease, PyMethod, PyObject,
3231
PyObjectRef, PyObjectWrap, PyRef, PyRefExact, PyResult, PyValue, TryFromObject, TypeProtocol,
3332
};
@@ -1786,69 +1785,6 @@ impl VirtualMachine {
17861785
.invoke((), self)
17871786
}
17881787

1789-
// Perform a comparison, raising TypeError when the requested comparison
1790-
// operator is not supported.
1791-
// see: CPython PyObject_RichCompare
1792-
fn _cmp(
1793-
&self,
1794-
v: &PyObjectRef,
1795-
w: &PyObjectRef,
1796-
op: PyComparisonOp,
1797-
) -> PyResult<Either<PyObjectRef, bool>> {
1798-
let swapped = op.swapped();
1799-
let call_cmp = |obj: &PyObjectRef, other, op| {
1800-
let cmp = obj
1801-
.class()
1802-
.mro_find_map(|cls| cls.slots.richcompare.load())
1803-
.unwrap();
1804-
Ok(match cmp(obj, other, op, self)? {
1805-
Either::A(obj) => PyArithmeticValue::from_object(self, obj).map(Either::A),
1806-
Either::B(arithmetic) => arithmetic.map(Either::B),
1807-
})
1808-
};
1809-
1810-
let mut checked_reverse_op = false;
1811-
let is_strict_subclass = {
1812-
let v_class = v.class();
1813-
let w_class = w.class();
1814-
!v_class.is(&w_class) && w_class.issubclass(&v_class)
1815-
};
1816-
if is_strict_subclass {
1817-
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
1818-
checked_reverse_op = true;
1819-
if let PyArithmeticValue::Implemented(x) = res {
1820-
return Ok(x);
1821-
}
1822-
}
1823-
if let PyArithmeticValue::Implemented(x) =
1824-
self.with_recursion("in comparison", || call_cmp(v, w, op))?
1825-
{
1826-
return Ok(x);
1827-
}
1828-
if !checked_reverse_op {
1829-
let res = self.with_recursion("in comparison", || call_cmp(w, v, swapped))?;
1830-
if let PyArithmeticValue::Implemented(x) = res {
1831-
return Ok(x);
1832-
}
1833-
}
1834-
match op {
1835-
PyComparisonOp::Eq => Ok(Either::B(v.is(&w))),
1836-
PyComparisonOp::Ne => Ok(Either::B(!v.is(&w))),
1837-
_ => Err(self.new_unsupported_binop_error(v, w, op.operator_token())),
1838-
}
1839-
}
1840-
1841-
pub fn bool_cmp(&self, a: &PyObjectRef, b: &PyObjectRef, op: PyComparisonOp) -> PyResult<bool> {
1842-
match self._cmp(a, b, op)? {
1843-
Either::A(obj) => obj.try_to_bool(self),
1844-
Either::B(b) => Ok(b),
1845-
}
1846-
}
1847-
1848-
pub fn obj_cmp(&self, a: PyObjectRef, b: PyObjectRef, op: PyComparisonOp) -> PyResult {
1849-
self._cmp(&a, &b, op).map(|res| res.into_pyobject(self))
1850-
}
1851-
18521788
pub fn obj_len_opt(&self, obj: &PyObjectRef) -> Option<PyResult<usize>> {
18531789
self.get_special_method(obj.clone(), "__len__")
18541790
.map(Result::ok)
@@ -2019,7 +1955,7 @@ impl VirtualMachine {
20191955
}
20201956

20211957
pub fn bool_eq(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
2022-
self.bool_cmp(a, b, PyComparisonOp::Eq)
1958+
a.rich_compare_bool(b, PyComparisonOp::Eq, self)
20231959
}
20241960

20251961
pub fn identical_or_equal(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<bool> {
@@ -2031,7 +1967,7 @@ impl VirtualMachine {
20311967
}
20321968

20331969
pub fn bool_seq_lt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
2034-
let value = if self.bool_cmp(a, b, PyComparisonOp::Lt)? {
1970+
let value = if a.rich_compare_bool(b, PyComparisonOp::Lt, self)? {
20351971
Some(true)
20361972
} else if !self.bool_eq(a, b)? {
20371973
Some(false)
@@ -2042,7 +1978,7 @@ impl VirtualMachine {
20421978
}
20431979

20441980
pub fn bool_seq_gt(&self, a: &PyObjectRef, b: &PyObjectRef) -> PyResult<Option<bool>> {
2045-
let value = if self.bool_cmp(a, b, PyComparisonOp::Gt)? {
1981+
let value = if a.rich_compare_bool(b, PyComparisonOp::Gt, self)? {
20461982
Some(true)
20471983
} else if !self.bool_eq(a, b)? {
20481984
Some(false)

0 commit comments

Comments
 (0)