Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 170 additions & 1 deletion crates/capi/src/dictobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::pystate::with_vm;
use core::ffi::c_int;
use core::ptr::NonNull;
use rustpython_vm::AsObject;
use rustpython_vm::PyPayload;
use rustpython_vm::builtins::PyDict;

define_py_check!(fn PyDict_Check, types.dict_type);
Expand Down Expand Up @@ -64,6 +65,100 @@ pub unsafe extern "C" fn PyDict_Size(dict: *mut PyObject) -> isize {
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_Contains(dict: *mut PyObject, key: *mut PyObject) -> c_int {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
let key = unsafe { &*key };
Ok(dict.inner_getitem_opt(key, vm)?.is_some())
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_Copy(dict: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
Ok(dict.copy().into_ref(&vm.ctx))
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_DelItem(dict: *mut PyObject, key: *mut PyObject) -> c_int {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
let key = unsafe { &*key };
dict.del_item(key, vm)
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_Items(dict: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
let items = dict
.items_vec()
.into_iter()
.map(|(k, v)| vm.ctx.new_tuple(vec![k, v]).into())
.collect();
Ok(vm.ctx.new_list(items))
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_Keys(dict: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
Ok(vm.ctx.new_list(dict.keys_vec()))
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_Values(dict: *mut PyObject) -> *mut PyObject {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
Ok(vm.ctx.new_list(dict.values_vec()))
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_Merge(
dict: *mut PyObject,
other: *mut PyObject,
override_: c_int,
) -> c_int {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
let other = unsafe { &*other }.to_owned();
if override_ != 0 {
dict.merge_object(other, vm)
} else {
dict.merge_object_if_missing(other, vm)
}
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_Update(dict: *mut PyObject, other: *mut PyObject) -> c_int {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
let other = unsafe { &*other }.to_owned();
dict.merge_object(other, vm)
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_MergeFromSeq2(
dict: *mut PyObject,
seq2: *mut PyObject,
override_: c_int,
) -> c_int {
with_vm(|vm| {
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
let seq2 = unsafe { &*seq2 }.to_owned();
dict.merge_from_seq2(seq2, override_ != 0, vm)
})
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn PyDict_Next(
dict: *mut PyObject,
Expand Down Expand Up @@ -95,7 +190,7 @@ pub unsafe extern "C" fn PyDict_Next(
#[cfg(false)]
mod tests {
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyInt};
use pyo3::types::{IntoPyDict, PyDict, PyDictMethods, PyInt, PyList};

#[test]
fn test_create_empty_dict() {
Expand Down Expand Up @@ -129,4 +224,78 @@ mod tests {
assert_eq!(values, vec![1, 2, 3, 4]);
})
}

#[test]
fn dict_contains() {
Python::attach(|py| {
let dict = [(1, 2)].into_py_dict(py).unwrap();
assert!(dict.contains(1).unwrap());
assert!(!dict.contains(3).unwrap());
})
}

#[test]
fn dict_copy_and_del_item() {
Python::attach(|py| {
let dict = [(1, 2), (3, 4)].into_py_dict(py).unwrap();
let copied = dict.copy().unwrap();
assert_eq!(copied.len(), 2);
copied.del_item(1).unwrap();
assert!(!copied.contains(1).unwrap());
})
}

#[test]
fn dict_keys_values_items() {
Python::attach(|py| {
let dict = [(1, 2), (3, 4)].into_py_dict(py).unwrap();
assert_eq!(dict.keys().len(), 2);
assert_eq!(dict.values().len(), 2);
assert_eq!(dict.items().len(), 2);
})
}

#[test]
fn dict_update_and_merge() {
Python::attach(|py| {
let dict = [(1, 10)].into_py_dict(py).unwrap();
let replacement = [(1, 20), (2, 30)].into_py_dict(py).unwrap();
dict.update(replacement.as_mapping()).unwrap();
assert_eq!(
dict.get_item(1).unwrap().unwrap().extract::<i32>().unwrap(),
20
);
assert_eq!(
dict.get_item(2).unwrap().unwrap().extract::<i32>().unwrap(),
30
);

let merged_missing = [(1, 99), (3, 40)].into_py_dict(py).unwrap();
dict.update_if_missing(merged_missing.as_mapping()).unwrap();
assert_eq!(
dict.get_item(1).unwrap().unwrap().extract::<i32>().unwrap(),
20
);
assert_eq!(
dict.get_item(3).unwrap().unwrap().extract::<i32>().unwrap(),
40
);
})
}

#[test]
fn dict_merge_from_seq2() {
Python::attach(|py| {
let seq = PyList::new(py, [(1, 10), (1, 20), (2, 30)]).unwrap();
let dict = PyDict::from_sequence(seq.as_any()).unwrap();
assert_eq!(
dict.get_item(1).unwrap().unwrap().extract::<i32>().unwrap(),
20
);
assert_eq!(
dict.get_item(2).unwrap().unwrap().extract::<i32>().unwrap(),
30
);
})
}
}
84 changes: 62 additions & 22 deletions crates/vm/src/builtins/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,17 @@ impl PyDict {
self.entries.items()
}

// Used in update and ior.
pub(crate) fn merge_object(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
fn merge_object_with_override(
&self,
other: PyObjectRef,
override_existing: bool,
vm: &VirtualMachine,
) -> PyResult<()> {
let casted: Result<PyRefExact<Self>, _> = other.downcast_exact(vm);
let other = match casted {
Ok(dict_other) => return self.merge_dict(dict_other.into_pyref(), vm),
Ok(dict_other) => {
return self.merge_dict(dict_other.into_pyref(), override_existing, vm);
}
Err(other) => other,
};
let dict = &self.entries;
Expand All @@ -157,6 +163,9 @@ impl PyDict {
Ok(keys_method) => {
let keys = keys_method.call((), vm)?.get_iter(vm)?;
while let PyIterReturn::Return(key) = keys.next(vm)? {
if !override_existing && dict.contains(vm, &*key)? {
continue;
}
let val = other.get_item(&*key, vm)?;
dict.insert(vm, &*key, val)?;
}
Expand All @@ -166,31 +175,62 @@ impl PyDict {
Err(e) => return Err(e),
};
if !has_keys {
let iter = other.get_iter(vm)?;
loop {
fn err(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_value_error("Iterator must have exactly two elements")
}
let element = match iter.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(_) => break,
};
let elem_iter = element.get_iter(vm)?;
let key = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
let value = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
if matches!(elem_iter.next(vm)?, PyIterReturn::Return(_)) {
return Err(err(vm));
}
dict.insert(vm, &*key, value)?;
return self.merge_from_seq2(other, override_existing, vm);
}
Ok(())
}

// Used in update and ior.
pub fn merge_object(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.merge_object_with_override(other, true, vm)
}

pub fn merge_object_if_missing(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
self.merge_object_with_override(other, false, vm)
}

pub fn merge_from_seq2(
&self,
seq2: PyObjectRef,
override_existing: bool,
vm: &VirtualMachine,
) -> PyResult<()> {
let iter = seq2.get_iter(vm)?;
let dict = &self.entries;
loop {
fn err(vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_value_error("Iterator must have exactly two elements")
}
let element = match iter.next(vm)? {
PyIterReturn::Return(obj) => obj,
PyIterReturn::StopIteration(_) => break,
};
let elem_iter = element.get_iter(vm)?;
let key = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
let value = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
if matches!(elem_iter.next(vm)?, PyIterReturn::Return(_)) {
return Err(err(vm));
}
if !override_existing && dict.contains(vm, &*key)? {
continue;
}
dict.insert(vm, &*key, value)?;
}
Ok(())
}

fn merge_dict(&self, dict_other: PyDictRef, vm: &VirtualMachine) -> PyResult<()> {
fn merge_dict(
&self,
dict_other: PyDictRef,
override_existing: bool,
vm: &VirtualMachine,
) -> PyResult<()> {
let dict = &self.entries;
let dict_size = &dict_other.size();
for (key, value) in &dict_other {
if !override_existing && dict.contains(vm, &*key)? {
continue;
}
dict.insert(vm, &*key, value)?;
}
if dict_other.entries.has_changed_size(dict_size) {
Expand Down Expand Up @@ -386,7 +426,7 @@ impl PyDict {
let other_dict: Result<PyDictRef, _> = other.downcast();
if let Ok(other) = other_dict {
let self_cp = self.copy();
self_cp.merge_dict(other, vm)?;
self_cp.merge_dict(other, true, vm)?;
return Ok(self_cp.into_pyobject(vm));
}
Ok(vm.ctx.not_implemented())
Expand Down Expand Up @@ -499,7 +539,7 @@ impl PyRef<PyDict> {
let other_dict: Result<Self, _> = other.downcast();
if let Ok(other) = other_dict {
let other_cp = other.copy();
other_cp.merge_dict(self, vm)?;
other_cp.merge_dict(self, true, vm)?;
return Ok(other_cp.into_pyobject(vm));
}
Ok(vm.ctx.not_implemented())
Expand Down
Loading