Skip to content

Commit 9ddd07a

Browse files
authored
Preserve imaginary zero signs when adding real values to complex numbers (RustPython#7421)
* Preserve imaginary zero signs when adding real values to complex numbers * Refactor complex_add with match expression * Correct complex real subtract op * Remove unnecessary vm arugment
1 parent 9a297aa commit 9ddd07a

2 files changed

Lines changed: 58 additions & 3 deletions

File tree

Lib/test/test_builtin.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1971,7 +1971,6 @@ def test_setattr(self):
19711971

19721972
# test_str(): see test_str.py and test_bytes.py for str() tests.
19731973

1974-
@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: floats 0.0 and -0.0 are not identical: zeros have different signs
19751974
def test_sum(self):
19761975
self.assertEqual(sum([]), 0)
19771976
self.assertEqual(sum(list(range(2,8))), 27)

crates/vm/src/builtins/complex.rs

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,44 @@ impl PyComplex {
287287
Ok(vm.ctx.not_implemented())
288288
}
289289
}
290+
291+
fn complex_real_binop<CCF, RCF, CRF, R>(
292+
a: &PyObject,
293+
b: &PyObject,
294+
cc_op: CCF,
295+
cr_op: CRF,
296+
rc_op: RCF,
297+
vm: &VirtualMachine,
298+
) -> PyResult
299+
where
300+
CCF: FnOnce(Complex64, Complex64) -> R,
301+
CRF: FnOnce(Complex64, f64) -> R,
302+
RCF: FnOnce(f64, Complex64) -> R,
303+
R: ToPyResult,
304+
{
305+
let value = match (a.downcast_ref::<PyComplex>(), b.downcast_ref::<PyComplex>()) {
306+
// complex + complex
307+
(Some(a_complex), Some(b_complex)) => cc_op(a_complex.value, b_complex.value),
308+
(Some(a_complex), None) => {
309+
let Some(b_real) = float::to_op_float(b, vm)? else {
310+
return Ok(vm.ctx.not_implemented());
311+
};
312+
313+
// complex + real
314+
cr_op(a_complex.value, b_real)
315+
}
316+
(None, Some(b_complex)) => {
317+
let Some(a_real) = float::to_op_float(a, vm)? else {
318+
return Ok(vm.ctx.not_implemented());
319+
};
320+
321+
// real + complex
322+
rc_op(a_real, b_complex.value)
323+
}
324+
(None, None) => return Ok(vm.ctx.not_implemented()),
325+
};
326+
value.to_pyresult(vm)
327+
}
290328
}
291329

292330
#[pyclass(
@@ -396,8 +434,26 @@ impl Hashable for PyComplex {
396434
impl AsNumber for PyComplex {
397435
fn as_number() -> &'static PyNumberMethods {
398436
static AS_NUMBER: PyNumberMethods = PyNumberMethods {
399-
add: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a + b, vm)),
400-
subtract: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a - b, vm)),
437+
add: Some(|a, b, vm| {
438+
PyComplex::complex_real_binop(
439+
a,
440+
b,
441+
|a, b| a + b,
442+
|a_complex, b_real| Complex64::new(a_complex.re + b_real, a_complex.im),
443+
|a_real, b_complex| Complex64::new(a_real + b_complex.re, b_complex.im),
444+
vm,
445+
)
446+
}),
447+
subtract: Some(|a, b, vm| {
448+
PyComplex::complex_real_binop(
449+
a,
450+
b,
451+
|a, b| a - b,
452+
|a_complex, b_real| Complex64::new(a_complex.re - b_real, a_complex.im),
453+
|a_real, b_complex| Complex64::new(a_real - b_complex.re, -b_complex.im),
454+
vm,
455+
)
456+
}),
401457
multiply: Some(|a, b, vm| PyComplex::number_op(a, b, |a, b, _vm| a * b, vm)),
402458
power: Some(|a, b, c, vm| {
403459
if vm.is_none(c) {

0 commit comments

Comments
 (0)