@@ -5,7 +5,7 @@ use crate::vm::{VirtualMachine, builtins::PyBaseExceptionRef};
55#[ pymodule]
66mod math {
77 use crate :: vm:: {
8- PyObject , PyObjectRef , PyRef , PyResult , VirtualMachine ,
8+ AsObject , PyObject , PyObjectRef , PyRef , PyResult , VirtualMachine ,
99 builtins:: { PyFloat , PyInt , PyIntRef , PyStrInterned , try_bigint_to_f64, try_f64_to_bigint} ,
1010 function:: { ArgIndex , ArgIntoFloat , ArgIterable , Either , OptionalArg , PosArgs } ,
1111 identifier,
@@ -534,10 +534,11 @@ mod math {
534534
535535 if let Some ( int_item) = item. downcast_ref :: < PyInt > ( )
536536 && let Ok ( b) = int_item. as_bigint ( ) . try_into ( ) as Result < i64 , _ >
537- && let Some ( product) = int_result. checked_mul ( b) {
538- int_result = product;
539- continue ;
540- }
537+ && let Some ( product) = int_result. checked_mul ( b)
538+ {
539+ int_result = product;
540+ continue ;
541+ }
541542
542543 // Overflow or non-int: restore to PyObject and continue
543544 obj_result = Some ( vm. ctx . new_int ( int_result) . into ( ) ) ;
@@ -590,10 +591,11 @@ mod math {
590591 continue ;
591592 }
592593 if let Some ( i) = item. downcast_ref :: < PyInt > ( )
593- && let Ok ( v) = i. as_bigint ( ) . try_into ( ) as Result < i64 , _ > {
594- flt_result *= v as f64 ;
595- continue ;
596- }
594+ && let Ok ( v) = i. as_bigint ( ) . try_into ( ) as Result < i64 , _ >
595+ {
596+ flt_result *= v as f64 ;
597+ continue ;
598+ }
597599
598600 // Non-float/int: restore and continue with generic path
599601 obj_result = Some ( vm. ctx . new_float ( flt_result) . into ( ) ) ;
@@ -646,12 +648,14 @@ mod math {
646648 _ => return Err ( vm. new_value_error ( "Inputs are not the same length" ) ) ,
647649 } ;
648650
649- // Integer fast path
651+ // Integer fast path (only for exact int types, not subclasses)
650652 if int_path_enabled {
651653 if !finished {
652654 let ( p_i, q_i) = ( p_i. as_ref ( ) . unwrap ( ) , q_i. as_ref ( ) . unwrap ( ) ) ;
653- if let ( Some ( p_int) , Some ( q_int) ) =
654- ( p_i. downcast_ref :: < PyInt > ( ) , q_i. downcast_ref :: < PyInt > ( ) )
655+ if p_i. class ( ) . is ( vm. ctx . types . int_type )
656+ && q_i. class ( ) . is ( vm. ctx . types . int_type )
657+ && let ( Some ( p_int) , Some ( q_int) ) =
658+ ( p_i. downcast_ref :: < PyInt > ( ) , q_i. downcast_ref :: < PyInt > ( ) )
655659 && let ( Ok ( p_val) , Ok ( q_val) ) = (
656660 p_int. as_bigint ( ) . try_into ( ) as Result < i64 , _ > ,
657661 q_int. as_bigint ( ) . try_into ( ) as Result < i64 , _ > ,
@@ -677,16 +681,22 @@ mod math {
677681 }
678682 }
679683
680- // Float fast path - only when at least one value is float
684+ // Float fast path - only when at least one value is exact float type
685+ // (not subclasses, to preserve custom __mul__/__add__ behavior)
681686 if flt_path_enabled {
682687 if !finished {
683688 let ( p_i, q_i) = ( p_i. as_ref ( ) . unwrap ( ) , q_i. as_ref ( ) . unwrap ( ) ) ;
684689
685- let p_is_float = p_i. downcast_ref :: < PyFloat > ( ) . is_some ( ) ;
686- let q_is_float = q_i. downcast_ref :: < PyFloat > ( ) . is_some ( ) ;
690+ let p_is_exact_float = p_i. class ( ) . is ( vm. ctx . types . float_type ) ;
691+ let q_is_exact_float = q_i. class ( ) . is ( vm. ctx . types . float_type ) ;
692+ let p_is_exact_int = p_i. class ( ) . is ( vm. ctx . types . int_type ) ;
693+ let q_is_exact_int = q_i. class ( ) . is ( vm. ctx . types . int_type ) ;
694+ let p_is_exact_numeric = p_is_exact_float || p_is_exact_int;
695+ let q_is_exact_numeric = q_is_exact_float || q_is_exact_int;
696+ let has_exact_float = p_is_exact_float || q_is_exact_float;
687697
688- // Only use float path if at least one is float (like CPython)
689- if p_is_float || q_is_float {
698+ // Only use float path if at least one is exact float and both are exact int/float
699+ if has_exact_float && p_is_exact_numeric && q_is_exact_numeric {
690700 let p_flt = if let Some ( f) = p_i. downcast_ref :: < PyFloat > ( ) {
691701 Some ( f. to_f64 ( ) )
692702 } else if let Some ( i) = p_i. downcast_ref :: < PyInt > ( ) {
@@ -854,11 +864,13 @@ mod math {
854864
855865 // Fast path: n fits in i64
856866 if let Some ( ni) = n_big. to_i64 ( )
857- && ni >= 0 && ki > 1 {
858- let result = pymath:: math:: integer:: perm ( ni, Some ( ki as i64 ) )
859- . map_err ( |_| vm. new_value_error ( "perm() error" ) ) ?;
860- return Ok ( result. into ( ) ) ;
861- }
867+ && ni >= 0
868+ && ki > 1
869+ {
870+ let result = pymath:: math:: integer:: perm ( ni, Some ( ki as i64 ) )
871+ . map_err ( |_| vm. new_value_error ( "perm() error" ) ) ?;
872+ return Ok ( result. into ( ) ) ;
873+ }
862874
863875 // BigInt path: use perm_bigint
864876 let result = pymath:: math:: perm_bigint ( n_big, ki) ;
@@ -881,25 +893,26 @@ mod math {
881893
882894 // Fast path: n fits in i64
883895 if let Some ( ni) = n_big. to_i64 ( )
884- && ni >= 0 {
885- // k overflow or k > n means result is 0
886- let ki = match k_big. to_i64 ( ) {
887- Some ( k) if k >= 0 && k <= ni => k,
888- _ => return Ok ( BigInt :: from ( 0u8 ) ) ,
889- } ;
890- // Apply symmetry: use min(k, n-k)
891- let ki = ki. min ( ni - ki) ;
892- if ki > 1 {
893- let result = pymath:: math:: integer:: comb ( ni, ki)
894- . map_err ( |_| vm. new_value_error ( "comb() error" ) ) ?;
895- return Ok ( result. into ( ) ) ;
896- }
897- // ki <= 1 cases
898- if ki == 0 {
899- return Ok ( BigInt :: from ( 1u8 ) ) ;
900- }
901- return Ok ( n_big. clone ( ) ) ; // ki == 1
896+ && ni >= 0
897+ {
898+ // k overflow or k > n means result is 0
899+ let ki = match k_big. to_i64 ( ) {
900+ Some ( k) if k >= 0 && k <= ni => k,
901+ _ => return Ok ( BigInt :: from ( 0u8 ) ) ,
902+ } ;
903+ // Apply symmetry: use min(k, n-k)
904+ let ki = ki. min ( ni - ki) ;
905+ if ki > 1 {
906+ let result = pymath:: math:: integer:: comb ( ni, ki)
907+ . map_err ( |_| vm. new_value_error ( "comb() error" ) ) ?;
908+ return Ok ( result. into ( ) ) ;
902909 }
910+ // ki <= 1 cases
911+ if ki == 0 {
912+ return Ok ( BigInt :: from ( 1u8 ) ) ;
913+ }
914+ return Ok ( n_big. clone ( ) ) ; // ki == 1
915+ }
903916
904917 // BigInt path: n doesn't fit in i64
905918 // Apply symmetry: k = min(k, n - k)
0 commit comments