Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions Lib/test/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,6 @@ def __hash__(self):
with self.assertRaises(Exc):
d1 == d2

@unittest.skip("TODO: RUSTPYTHON")
def test_keys_contained(self):
self.helper_keys_contained(lambda x: x.keys())
self.helper_keys_contained(lambda x: x.items())
Expand Down Expand Up @@ -631,8 +630,6 @@ def helper_keys_contained(self, fn):
self.assertTrue(larger != larger3)
self.assertFalse(larger == larger3)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_errors_in_view_containment_check(self):
class C:
def __eq__(self, other):
Expand Down
173 changes: 150 additions & 23 deletions vm/src/obj/objdict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,21 @@ use std::fmt;
use crossbeam_utils::atomic::AtomicCell;

use super::objiter;
use super::objset::PySet;
use super::objstr;
use super::objtype::{self, PyClassRef};
use crate::dictdatatype::{self, DictKey};
use crate::exceptions::PyBaseExceptionRef;
use crate::function::{KwArgs, OptionalArg, PyFuncArgs};
use crate::pyobject::{
BorrowValue, IdProtocol, IntoPyObject, ItemProtocol, PyAttributes, PyClassImpl, PyContext,
PyIterable, PyObjectRef, PyRef, PyResult, PyValue, TryFromObject, TypeProtocol,
BorrowValue, IdProtocol, IntoPyObject, ItemProtocol, PyArithmaticValue, PyAttributes,
PyClassImpl, PyComparisonValue, PyContext, PyIterable, PyObjectRef, PyRef, PyResult, PyValue,
TryFromObject, TypeProtocol,
};
use crate::vm::{ReprGuard, VirtualMachine};

use std::mem::size_of;
use PyArithmaticValue::{Implemented, NotImplemented};

pub type DictContentType = dictdatatype::Dict;

Expand Down Expand Up @@ -146,46 +149,65 @@ impl PyDict {
!self.entries.is_empty()
}

fn inner_eq(zelf: PyRef<Self>, other: &PyDict, vm: &VirtualMachine) -> PyResult<bool> {
if other.entries.len() != zelf.entries.len() {
return Ok(false);
fn inner_cmp(
zelf: PyRef<Self>,
other: PyDictRef,
size_func: fn(usize, usize) -> bool,
item: bool,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
if size_func(zelf.len(), other.len()) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For eq and ne, this test leads the loop comparison.
But for other comparisons like lt, gt, the size comparison doesn't guarantee subset.

Assume a = {'a': 1, 'b' 2}; b = {'a': 1, 'b': 3, 'c': 4},
a.items() > b.items() is false but also b.items() > a.items() is false.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your advice! but don't the cpython also return false and false? I put it at first because it seems cpython check a length at first. link Should I change it as you said?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I am sorry. I thought something totally wrong. Yes, it seems current behaviour is exactly as expected and also same as also what I described. Sorry for making confusion!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No problem! I really want to thank you for your help.

return Ok(Implemented(false));
}
for (k, v1) in zelf {
match other.entries.get(vm, &k)? {
let (zelf, other) = if zelf.len() < other.len() {
(other, zelf)
} else {
(zelf, other)
};
for (k, v1) in other {
match zelf.get_item_option(k, vm)? {
Some(v2) => {
if v1.is(&v2) {
continue;
}
if !vm.bool_eq(v1, v2)? {
return Ok(false);
if item && !vm.bool_eq(v1, v2)? {
return Ok(Implemented(false));
}
}
None => {
return Ok(false);
return Ok(Implemented(false));
}
}
}
Ok(true)
Ok(Implemented(true))
}

#[pymethod(magic)]
fn eq(zelf: PyRef<Self>, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if let Some(other) = other.payload::<PyDict>() {
let eq = Self::inner_eq(zelf, other, vm)?;
Ok(vm.ctx.new_bool(eq))
fn eq(
zelf: PyRef<Self>,
other: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
if let Ok(other) = other.downcast::<PyDict>() {
Self::inner_cmp(
zelf,
other,
|zelf: usize, other: usize| -> bool { zelf != other },
true,
vm,
)
} else {
Ok(vm.ctx.not_implemented())
Ok(NotImplemented)
}
}

#[pymethod(magic)]
fn ne(zelf: PyRef<Self>, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
if let Some(other) = other.payload::<PyDict>() {
let neq = !Self::inner_eq(zelf, other, vm)?;
Ok(vm.ctx.new_bool(neq))
} else {
Ok(vm.ctx.not_implemented())
}
fn ne(
zelf: PyRef<Self>,
other: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
Ok(Self::eq(zelf, other, vm)?.map(|v| !v))
}

#[pymethod(magic)]
Expand Down Expand Up @@ -611,6 +633,111 @@ macro_rules! dict_iterator {
fn reversed(&self) -> $reverse_iter_name {
$reverse_iter_name::new(self.dict.clone())
}

fn cmp(
zelf: PyRef<Self>,
other: PyObjectRef,
size_func: fn(usize, usize) -> bool,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
match_class!(match other {
dictview @ Self => {
PyDict::inner_cmp(
zelf.dict.clone(),
dictview.dict.clone(),
size_func,
!zelf.class().is(&vm.ctx.types.dict_keys_type),
vm,
)
}
_set @ PySet => {
// TODO: Implement comparison for set
Ok(NotImplemented)
}
_ => {
Ok(NotImplemented)
}
})
}

#[pymethod(name = "__eq__")]
fn eq(
zelf: PyRef<Self>,
other: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
Self::cmp(
zelf,
other,
|zelf: usize, other: usize| -> bool { zelf != other },
vm,
)
}

#[pymethod(name = "__ne__")]
fn ne(
zelf: PyRef<Self>,
other: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
Ok(Self::eq(zelf, other, vm)?.map(|v| !v))
}

#[pymethod(name = "__lt__")]
fn lt(
zelf: PyRef<Self>,
other: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
Self::cmp(
zelf,
other,
|zelf: usize, other: usize| -> bool { zelf >= other },
vm,
)
}

#[pymethod(name = "__le__")]
fn le(
zelf: PyRef<Self>,
other: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
Self::cmp(
zelf,
other,
|zelf: usize, other: usize| -> bool { zelf > other },
vm,
)
}

#[pymethod(name = "__gt__")]
fn gt(
zelf: PyRef<Self>,
other: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
Self::cmp(
zelf,
other,
|zelf: usize, other: usize| -> bool { zelf <= other },
vm,
)
}

#[pymethod(name = "__ge__")]
fn ge(
zelf: PyRef<Self>,
other: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyComparisonValue> {
Self::cmp(
zelf,
other,
|zelf: usize, other: usize| -> bool { zelf < other },
vm,
)
}
}

impl PyValue for $name {
Expand Down