Skip to content

Commit f133d7e

Browse files
committed
temp
1 parent ea1e750 commit f133d7e

File tree

3 files changed

+55
-43
lines changed

3 files changed

+55
-43
lines changed

Cargo.lock

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ once_cell = "1.20.3"
184184
parking_lot = "0.12.3"
185185
paste = "1.0.15"
186186
proc-macro2 = "1.0.105"
187-
pymath = { version = "0.1.3", features = ["mul_add", "malachite-bigint", "complex"] }
187+
pymath = { git = "https://github.com/RustPython/pymath.git", rev = "564ebb2780a05a9460ad12a2752a67244dc5f89e", features = ["mul_add", "malachite-bigint", "complex"] }
188188
quote = "1.0.43"
189189
radium = "1.1.1"
190190
rand = "0.9"

crates/stdlib/src/math.rs

Lines changed: 53 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::vm::{VirtualMachine, builtins::PyBaseExceptionRef};
55
#[pymodule]
66
mod math {
77
use crate::vm::{
8-
PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine,
8+
AsObject, PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine,
99
builtins::{PyFloat, PyInt, PyIntRef, PyStrInterned, try_bigint_to_f64, try_f64_to_bigint},
1010
function::{ArgIndex, ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs},
1111
identifier,
@@ -534,10 +534,11 @@ mod math {
534534

535535
if let Some(int_item) = item.downcast_ref::<PyInt>()
536536
&& let Ok(b) = int_item.as_bigint().try_into() as Result<i64, _>
537-
&& let Some(product) = int_result.checked_mul(b) {
538-
int_result = product;
539-
continue;
540-
}
537+
&& let Some(product) = int_result.checked_mul(b)
538+
{
539+
int_result = product;
540+
continue;
541+
}
541542

542543
// Overflow or non-int: restore to PyObject and continue
543544
obj_result = Some(vm.ctx.new_int(int_result).into());
@@ -590,10 +591,11 @@ mod math {
590591
continue;
591592
}
592593
if let Some(i) = item.downcast_ref::<PyInt>()
593-
&& let Ok(v) = i.as_bigint().try_into() as Result<i64, _> {
594-
flt_result *= v as f64;
595-
continue;
596-
}
594+
&& let Ok(v) = i.as_bigint().try_into() as Result<i64, _>
595+
{
596+
flt_result *= v as f64;
597+
continue;
598+
}
597599

598600
// Non-float/int: restore and continue with generic path
599601
obj_result = Some(vm.ctx.new_float(flt_result).into());
@@ -646,12 +648,14 @@ mod math {
646648
_ => return Err(vm.new_value_error("Inputs are not the same length")),
647649
};
648650

649-
// Integer fast path
651+
// Integer fast path (only for exact int types, not subclasses)
650652
if int_path_enabled {
651653
if !finished {
652654
let (p_i, q_i) = (p_i.as_ref().unwrap(), q_i.as_ref().unwrap());
653-
if let (Some(p_int), Some(q_int)) =
654-
(p_i.downcast_ref::<PyInt>(), q_i.downcast_ref::<PyInt>())
655+
if p_i.class().is(vm.ctx.types.int_type)
656+
&& q_i.class().is(vm.ctx.types.int_type)
657+
&& let (Some(p_int), Some(q_int)) =
658+
(p_i.downcast_ref::<PyInt>(), q_i.downcast_ref::<PyInt>())
655659
&& let (Ok(p_val), Ok(q_val)) = (
656660
p_int.as_bigint().try_into() as Result<i64, _>,
657661
q_int.as_bigint().try_into() as Result<i64, _>,
@@ -677,16 +681,22 @@ mod math {
677681
}
678682
}
679683

680-
// Float fast path - only when at least one value is float
684+
// Float fast path - only when at least one value is exact float type
685+
// (not subclasses, to preserve custom __mul__/__add__ behavior)
681686
if flt_path_enabled {
682687
if !finished {
683688
let (p_i, q_i) = (p_i.as_ref().unwrap(), q_i.as_ref().unwrap());
684689

685-
let p_is_float = p_i.downcast_ref::<PyFloat>().is_some();
686-
let q_is_float = q_i.downcast_ref::<PyFloat>().is_some();
690+
let p_is_exact_float = p_i.class().is(vm.ctx.types.float_type);
691+
let q_is_exact_float = q_i.class().is(vm.ctx.types.float_type);
692+
let p_is_exact_int = p_i.class().is(vm.ctx.types.int_type);
693+
let q_is_exact_int = q_i.class().is(vm.ctx.types.int_type);
694+
let p_is_exact_numeric = p_is_exact_float || p_is_exact_int;
695+
let q_is_exact_numeric = q_is_exact_float || q_is_exact_int;
696+
let has_exact_float = p_is_exact_float || q_is_exact_float;
687697

688-
// Only use float path if at least one is float (like CPython)
689-
if p_is_float || q_is_float {
698+
// Only use float path if at least one is exact float and both are exact int/float
699+
if has_exact_float && p_is_exact_numeric && q_is_exact_numeric {
690700
let p_flt = if let Some(f) = p_i.downcast_ref::<PyFloat>() {
691701
Some(f.to_f64())
692702
} else if let Some(i) = p_i.downcast_ref::<PyInt>() {
@@ -854,11 +864,13 @@ mod math {
854864

855865
// Fast path: n fits in i64
856866
if let Some(ni) = n_big.to_i64()
857-
&& ni >= 0 && ki > 1 {
858-
let result = pymath::math::integer::perm(ni, Some(ki as i64))
859-
.map_err(|_| vm.new_value_error("perm() error"))?;
860-
return Ok(result.into());
861-
}
867+
&& ni >= 0
868+
&& ki > 1
869+
{
870+
let result = pymath::math::integer::perm(ni, Some(ki as i64))
871+
.map_err(|_| vm.new_value_error("perm() error"))?;
872+
return Ok(result.into());
873+
}
862874

863875
// BigInt path: use perm_bigint
864876
let result = pymath::math::perm_bigint(n_big, ki);
@@ -881,25 +893,26 @@ mod math {
881893

882894
// Fast path: n fits in i64
883895
if let Some(ni) = n_big.to_i64()
884-
&& ni >= 0 {
885-
// k overflow or k > n means result is 0
886-
let ki = match k_big.to_i64() {
887-
Some(k) if k >= 0 && k <= ni => k,
888-
_ => return Ok(BigInt::from(0u8)),
889-
};
890-
// Apply symmetry: use min(k, n-k)
891-
let ki = ki.min(ni - ki);
892-
if ki > 1 {
893-
let result = pymath::math::integer::comb(ni, ki)
894-
.map_err(|_| vm.new_value_error("comb() error"))?;
895-
return Ok(result.into());
896-
}
897-
// ki <= 1 cases
898-
if ki == 0 {
899-
return Ok(BigInt::from(1u8));
900-
}
901-
return Ok(n_big.clone()); // ki == 1
896+
&& ni >= 0
897+
{
898+
// k overflow or k > n means result is 0
899+
let ki = match k_big.to_i64() {
900+
Some(k) if k >= 0 && k <= ni => k,
901+
_ => return Ok(BigInt::from(0u8)),
902+
};
903+
// Apply symmetry: use min(k, n-k)
904+
let ki = ki.min(ni - ki);
905+
if ki > 1 {
906+
let result = pymath::math::integer::comb(ni, ki)
907+
.map_err(|_| vm.new_value_error("comb() error"))?;
908+
return Ok(result.into());
902909
}
910+
// ki <= 1 cases
911+
if ki == 0 {
912+
return Ok(BigInt::from(1u8));
913+
}
914+
return Ok(n_big.clone()); // ki == 1
915+
}
903916

904917
// BigInt path: n doesn't fit in i64
905918
// Apply symmetry: k = min(k, n - k)

0 commit comments

Comments
 (0)