Skip to content

Commit a575121

Browse files
Add more dict functions to c-api (RustPython#8043)
1 parent 52afc12 commit a575121

2 files changed

Lines changed: 232 additions & 23 deletions

File tree

crates/capi/src/dictobject.rs

Lines changed: 170 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use crate::pystate::with_vm;
44
use core::ffi::c_int;
55
use core::ptr::NonNull;
66
use rustpython_vm::AsObject;
7+
use rustpython_vm::PyPayload;
78
use rustpython_vm::builtins::PyDict;
89

910
define_py_check!(fn PyDict_Check, types.dict_type);
@@ -64,6 +65,100 @@ pub unsafe extern "C" fn PyDict_Size(dict: *mut PyObject) -> isize {
6465
})
6566
}
6667

68+
#[unsafe(no_mangle)]
69+
pub unsafe extern "C" fn PyDict_Contains(dict: *mut PyObject, key: *mut PyObject) -> c_int {
70+
with_vm(|vm| {
71+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
72+
let key = unsafe { &*key };
73+
Ok(dict.inner_getitem_opt(key, vm)?.is_some())
74+
})
75+
}
76+
77+
#[unsafe(no_mangle)]
78+
pub unsafe extern "C" fn PyDict_Copy(dict: *mut PyObject) -> *mut PyObject {
79+
with_vm(|vm| {
80+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
81+
Ok(dict.copy().into_ref(&vm.ctx))
82+
})
83+
}
84+
85+
#[unsafe(no_mangle)]
86+
pub unsafe extern "C" fn PyDict_DelItem(dict: *mut PyObject, key: *mut PyObject) -> c_int {
87+
with_vm(|vm| {
88+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
89+
let key = unsafe { &*key };
90+
dict.del_item(key, vm)
91+
})
92+
}
93+
94+
#[unsafe(no_mangle)]
95+
pub unsafe extern "C" fn PyDict_Items(dict: *mut PyObject) -> *mut PyObject {
96+
with_vm(|vm| {
97+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
98+
let items = dict
99+
.items_vec()
100+
.into_iter()
101+
.map(|(k, v)| vm.ctx.new_tuple(vec![k, v]).into())
102+
.collect();
103+
Ok(vm.ctx.new_list(items))
104+
})
105+
}
106+
107+
#[unsafe(no_mangle)]
108+
pub unsafe extern "C" fn PyDict_Keys(dict: *mut PyObject) -> *mut PyObject {
109+
with_vm(|vm| {
110+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
111+
Ok(vm.ctx.new_list(dict.keys_vec()))
112+
})
113+
}
114+
115+
#[unsafe(no_mangle)]
116+
pub unsafe extern "C" fn PyDict_Values(dict: *mut PyObject) -> *mut PyObject {
117+
with_vm(|vm| {
118+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
119+
Ok(vm.ctx.new_list(dict.values_vec()))
120+
})
121+
}
122+
123+
#[unsafe(no_mangle)]
124+
pub unsafe extern "C" fn PyDict_Merge(
125+
dict: *mut PyObject,
126+
other: *mut PyObject,
127+
override_: c_int,
128+
) -> c_int {
129+
with_vm(|vm| {
130+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
131+
let other = unsafe { &*other }.to_owned();
132+
if override_ != 0 {
133+
dict.merge_object(other, vm)
134+
} else {
135+
dict.merge_object_if_missing(other, vm)
136+
}
137+
})
138+
}
139+
140+
#[unsafe(no_mangle)]
141+
pub unsafe extern "C" fn PyDict_Update(dict: *mut PyObject, other: *mut PyObject) -> c_int {
142+
with_vm(|vm| {
143+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
144+
let other = unsafe { &*other }.to_owned();
145+
dict.merge_object(other, vm)
146+
})
147+
}
148+
149+
#[unsafe(no_mangle)]
150+
pub unsafe extern "C" fn PyDict_MergeFromSeq2(
151+
dict: *mut PyObject,
152+
seq2: *mut PyObject,
153+
override_: c_int,
154+
) -> c_int {
155+
with_vm(|vm| {
156+
let dict = unsafe { &*dict }.try_downcast_ref::<PyDict>(vm)?;
157+
let seq2 = unsafe { &*seq2 }.to_owned();
158+
dict.merge_from_seq2(seq2, override_ != 0, vm)
159+
})
160+
}
161+
67162
#[unsafe(no_mangle)]
68163
pub unsafe extern "C" fn PyDict_Next(
69164
dict: *mut PyObject,
@@ -95,7 +190,7 @@ pub unsafe extern "C" fn PyDict_Next(
95190
#[cfg(false)]
96191
mod tests {
97192
use pyo3::prelude::*;
98-
use pyo3::types::{IntoPyDict, PyDict, PyInt};
193+
use pyo3::types::{IntoPyDict, PyDict, PyDictMethods, PyInt, PyList};
99194

100195
#[test]
101196
fn test_create_empty_dict() {
@@ -129,4 +224,78 @@ mod tests {
129224
assert_eq!(values, vec![1, 2, 3, 4]);
130225
})
131226
}
227+
228+
#[test]
229+
fn dict_contains() {
230+
Python::attach(|py| {
231+
let dict = [(1, 2)].into_py_dict(py).unwrap();
232+
assert!(dict.contains(1).unwrap());
233+
assert!(!dict.contains(3).unwrap());
234+
})
235+
}
236+
237+
#[test]
238+
fn dict_copy_and_del_item() {
239+
Python::attach(|py| {
240+
let dict = [(1, 2), (3, 4)].into_py_dict(py).unwrap();
241+
let copied = dict.copy().unwrap();
242+
assert_eq!(copied.len(), 2);
243+
copied.del_item(1).unwrap();
244+
assert!(!copied.contains(1).unwrap());
245+
})
246+
}
247+
248+
#[test]
249+
fn dict_keys_values_items() {
250+
Python::attach(|py| {
251+
let dict = [(1, 2), (3, 4)].into_py_dict(py).unwrap();
252+
assert_eq!(dict.keys().len(), 2);
253+
assert_eq!(dict.values().len(), 2);
254+
assert_eq!(dict.items().len(), 2);
255+
})
256+
}
257+
258+
#[test]
259+
fn dict_update_and_merge() {
260+
Python::attach(|py| {
261+
let dict = [(1, 10)].into_py_dict(py).unwrap();
262+
let replacement = [(1, 20), (2, 30)].into_py_dict(py).unwrap();
263+
dict.update(replacement.as_mapping()).unwrap();
264+
assert_eq!(
265+
dict.get_item(1).unwrap().unwrap().extract::<i32>().unwrap(),
266+
20
267+
);
268+
assert_eq!(
269+
dict.get_item(2).unwrap().unwrap().extract::<i32>().unwrap(),
270+
30
271+
);
272+
273+
let merged_missing = [(1, 99), (3, 40)].into_py_dict(py).unwrap();
274+
dict.update_if_missing(merged_missing.as_mapping()).unwrap();
275+
assert_eq!(
276+
dict.get_item(1).unwrap().unwrap().extract::<i32>().unwrap(),
277+
20
278+
);
279+
assert_eq!(
280+
dict.get_item(3).unwrap().unwrap().extract::<i32>().unwrap(),
281+
40
282+
);
283+
})
284+
}
285+
286+
#[test]
287+
fn dict_merge_from_seq2() {
288+
Python::attach(|py| {
289+
let seq = PyList::new(py, [(1, 10), (1, 20), (2, 30)]).unwrap();
290+
let dict = PyDict::from_sequence(seq.as_any()).unwrap();
291+
assert_eq!(
292+
dict.get_item(1).unwrap().unwrap().extract::<i32>().unwrap(),
293+
20
294+
);
295+
assert_eq!(
296+
dict.get_item(2).unwrap().unwrap().extract::<i32>().unwrap(),
297+
30
298+
);
299+
})
300+
}
132301
}

crates/vm/src/builtins/dict.rs

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,17 @@ impl PyDict {
143143
self.entries.items()
144144
}
145145

146-
// Used in update and ior.
147-
pub(crate) fn merge_object(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
146+
fn merge_object_with_override(
147+
&self,
148+
other: PyObjectRef,
149+
override_existing: bool,
150+
vm: &VirtualMachine,
151+
) -> PyResult<()> {
148152
let casted: Result<PyRefExact<Self>, _> = other.downcast_exact(vm);
149153
let other = match casted {
150-
Ok(dict_other) => return self.merge_dict(dict_other.into_pyref(), vm),
154+
Ok(dict_other) => {
155+
return self.merge_dict(dict_other.into_pyref(), override_existing, vm);
156+
}
151157
Err(other) => other,
152158
};
153159
let dict = &self.entries;
@@ -157,6 +163,9 @@ impl PyDict {
157163
Ok(keys_method) => {
158164
let keys = keys_method.call((), vm)?.get_iter(vm)?;
159165
while let PyIterReturn::Return(key) = keys.next(vm)? {
166+
if !override_existing && dict.contains(vm, &*key)? {
167+
continue;
168+
}
160169
let val = other.get_item(&*key, vm)?;
161170
dict.insert(vm, &*key, val)?;
162171
}
@@ -166,31 +175,62 @@ impl PyDict {
166175
Err(e) => return Err(e),
167176
};
168177
if !has_keys {
169-
let iter = other.get_iter(vm)?;
170-
loop {
171-
fn err(vm: &VirtualMachine) -> PyBaseExceptionRef {
172-
vm.new_value_error("Iterator must have exactly two elements")
173-
}
174-
let element = match iter.next(vm)? {
175-
PyIterReturn::Return(obj) => obj,
176-
PyIterReturn::StopIteration(_) => break,
177-
};
178-
let elem_iter = element.get_iter(vm)?;
179-
let key = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
180-
let value = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
181-
if matches!(elem_iter.next(vm)?, PyIterReturn::Return(_)) {
182-
return Err(err(vm));
183-
}
184-
dict.insert(vm, &*key, value)?;
178+
return self.merge_from_seq2(other, override_existing, vm);
179+
}
180+
Ok(())
181+
}
182+
183+
// Used in update and ior.
184+
pub fn merge_object(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
185+
self.merge_object_with_override(other, true, vm)
186+
}
187+
188+
pub fn merge_object_if_missing(&self, other: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> {
189+
self.merge_object_with_override(other, false, vm)
190+
}
191+
192+
pub fn merge_from_seq2(
193+
&self,
194+
seq2: PyObjectRef,
195+
override_existing: bool,
196+
vm: &VirtualMachine,
197+
) -> PyResult<()> {
198+
let iter = seq2.get_iter(vm)?;
199+
let dict = &self.entries;
200+
loop {
201+
fn err(vm: &VirtualMachine) -> PyBaseExceptionRef {
202+
vm.new_value_error("Iterator must have exactly two elements")
185203
}
204+
let element = match iter.next(vm)? {
205+
PyIterReturn::Return(obj) => obj,
206+
PyIterReturn::StopIteration(_) => break,
207+
};
208+
let elem_iter = element.get_iter(vm)?;
209+
let key = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
210+
let value = elem_iter.next(vm)?.into_result().map_err(|_| err(vm))?;
211+
if matches!(elem_iter.next(vm)?, PyIterReturn::Return(_)) {
212+
return Err(err(vm));
213+
}
214+
if !override_existing && dict.contains(vm, &*key)? {
215+
continue;
216+
}
217+
dict.insert(vm, &*key, value)?;
186218
}
187219
Ok(())
188220
}
189221

190-
fn merge_dict(&self, dict_other: PyDictRef, vm: &VirtualMachine) -> PyResult<()> {
222+
fn merge_dict(
223+
&self,
224+
dict_other: PyDictRef,
225+
override_existing: bool,
226+
vm: &VirtualMachine,
227+
) -> PyResult<()> {
191228
let dict = &self.entries;
192229
let dict_size = &dict_other.size();
193230
for (key, value) in &dict_other {
231+
if !override_existing && dict.contains(vm, &*key)? {
232+
continue;
233+
}
194234
dict.insert(vm, &*key, value)?;
195235
}
196236
if dict_other.entries.has_changed_size(dict_size) {
@@ -386,7 +426,7 @@ impl PyDict {
386426
let other_dict: Result<PyDictRef, _> = other.downcast();
387427
if let Ok(other) = other_dict {
388428
let self_cp = self.copy();
389-
self_cp.merge_dict(other, vm)?;
429+
self_cp.merge_dict(other, true, vm)?;
390430
return Ok(self_cp.into_pyobject(vm));
391431
}
392432
Ok(vm.ctx.not_implemented())
@@ -499,7 +539,7 @@ impl PyRef<PyDict> {
499539
let other_dict: Result<Self, _> = other.downcast();
500540
if let Ok(other) = other_dict {
501541
let other_cp = other.copy();
502-
other_cp.merge_dict(self, vm)?;
542+
other_cp.merge_dict(self, true, vm)?;
503543
return Ok(other_cp.into_pyobject(vm));
504544
}
505545
Ok(vm.ctx.not_implemented())

0 commit comments

Comments
 (0)