Skip to content

Commit e7de2be

Browse files
committed
slots
1 parent 38716b2 commit e7de2be

File tree

8 files changed

+197
-22
lines changed

8 files changed

+197
-22
lines changed

crates/vm/src/builtins/complex.rs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,7 @@ use crate::{
55
class::PyClassImpl,
66
common::format::FormatSpec,
77
convert::{IntoPyException, ToPyObject, ToPyResult},
8-
function::{
9-
FuncArgs, OptionalArg, OptionalOption,
10-
PyArithmeticValue::{self, *},
11-
PyComparisonValue,
12-
},
8+
function::{FuncArgs, OptionalArg, PyComparisonValue},
139
protocol::PyNumberMethods,
1410
stdlib::warnings,
1511
types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable},
@@ -367,7 +363,12 @@ impl AsNumber for PyComplex {
367363
}),
368364
absolute: Some(|number, vm| {
369365
let value = PyComplex::number_downcast(number).value;
370-
value.norm().to_pyresult(vm)
366+
let result = value.norm();
367+
// Check for overflow: hypot returns inf for finite inputs that overflow
368+
if result.is_infinite() && value.re.is_finite() && value.im.is_finite() {
369+
return Err(vm.new_overflow_error("absolute value too large".to_owned()));
370+
}
371+
result.to_pyresult(vm)
371372
}),
372373
boolean: Some(|number, _vm| Ok(!PyComplex::number_downcast(number).value.is_zero())),
373374
true_divide: Some(|a, b, vm| PyComplex::number_op(a, b, inner_div, vm)),

crates/vm/src/builtins/descriptor.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,12 @@ pub enum SlotFunc {
437437
MapAssSubscript(MapAssSubscriptFunc),
438438

439439
// Number sub-slots (nb_*) - grouped by signature
440-
NumBoolean(PyNumberUnaryFunc<bool>), // __bool__
441-
NumUnary(PyNumberUnaryFunc), // __int__, __float__, __index__
442-
NumBinary(PyNumberBinaryFunc), // __add__, __sub__, __mul__, etc.
443-
NumBinaryRight(PyNumberBinaryFunc), // __radd__, __rsub__, etc. (swapped args)
444-
NumTernary(PyNumberTernaryFunc), // __pow__
440+
NumBoolean(PyNumberUnaryFunc<bool>), // __bool__
441+
NumUnary(PyNumberUnaryFunc), // __int__, __float__, __index__
442+
NumBinary(PyNumberBinaryFunc), // __add__, __sub__, __mul__, etc.
443+
NumBinaryRight(PyNumberBinaryFunc), // __radd__, __rsub__, etc. (swapped args)
444+
NumTernary(PyNumberTernaryFunc), // __pow__
445+
NumTernaryRight(PyNumberTernaryFunc), // __rpow__ (swapped first two args)
445446
}
446447

447448
impl std::fmt::Debug for SlotFunc {
@@ -479,6 +480,7 @@ impl std::fmt::Debug for SlotFunc {
479480
SlotFunc::NumBinary(_) => write!(f, "SlotFunc::NumBinary(...)"),
480481
SlotFunc::NumBinaryRight(_) => write!(f, "SlotFunc::NumBinaryRight(...)"),
481482
SlotFunc::NumTernary(_) => write!(f, "SlotFunc::NumTernary(...)"),
483+
SlotFunc::NumTernaryRight(_) => write!(f, "SlotFunc::NumTernaryRight(...)"),
482484
}
483485
}
484486
}
@@ -649,6 +651,12 @@ impl SlotFunc {
649651
let z = z.unwrap_or_else(|| vm.ctx.none());
650652
func(&obj, &y, &z, vm)
651653
}
654+
SlotFunc::NumTernaryRight(func) => {
655+
let (y, z): (PyObjectRef, crate::function::OptionalArg<PyObjectRef>) =
656+
args.bind(vm)?;
657+
let z = z.unwrap_or_else(|| vm.ctx.none());
658+
func(&y, &obj, &z, vm) // Swapped: y ** obj % z
659+
}
652660
}
653661
}
654662
}

crates/vm/src/builtins/float.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ use crate::{
88
common::{float_ops, format::FormatSpec, hash},
99
convert::{IntoPyException, ToPyObject, ToPyResult},
1010
function::{
11-
ArgBytesLike, FuncArgs, OptionalArg, OptionalOption,
12-
PyArithmeticValue::{self, *},
11+
ArgBytesLike, FuncArgs, OptionalArg, OptionalOption, PyArithmeticValue::*,
1312
PyComparisonValue,
1413
},
1514
protocol::PyNumberMethods,

crates/vm/src/builtins/int.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ use crate::{
1212
},
1313
convert::{IntoPyException, ToPyObject, ToPyResult},
1414
function::{
15-
ArgByteOrder, ArgIntoBool, FuncArgs, OptionalArg, OptionalOption, PyArithmeticValue,
16-
PyComparisonValue,
15+
ArgByteOrder, ArgIntoBool, FuncArgs, OptionalArg, PyArithmeticValue, PyComparisonValue,
1716
},
1817
protocol::{PyNumberMethods, handle_bytes_to_int_err},
1918
types::{AsNumber, Comparable, Constructor, Hashable, PyComparisonOp, Representable},

crates/vm/src/class.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use rustpython_common::static_cell;
1515
/// Iterates SLOT_DEFS and creates a PyWrapper for each slot that:
1616
/// 1. Has a function set in the type's slots
1717
/// 2. Doesn't already have an attribute in the type's dict
18-
fn add_operators(class: &'static Py<PyType>, ctx: &Context) {
18+
pub fn add_operators(class: &'static Py<PyType>, ctx: &Context) {
1919
for def in SLOT_DEFS.iter() {
2020
// Skip __new__ - it has special handling
2121
if def.name == "__new__" {

crates/vm/src/protocol/number.rs

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,115 @@ impl From<&PyNumberMethods> for PyNumberSlots {
352352
}
353353

354354
impl PyNumberSlots {
355+
/// Copy from static PyNumberMethods
356+
pub fn copy_from(&self, methods: &PyNumberMethods) {
357+
if let Some(f) = methods.add {
358+
self.add.store(Some(f));
359+
}
360+
if let Some(f) = methods.subtract {
361+
self.subtract.store(Some(f));
362+
}
363+
if let Some(f) = methods.multiply {
364+
self.multiply.store(Some(f));
365+
}
366+
if let Some(f) = methods.remainder {
367+
self.remainder.store(Some(f));
368+
}
369+
if let Some(f) = methods.divmod {
370+
self.divmod.store(Some(f));
371+
}
372+
if let Some(f) = methods.power {
373+
self.power.store(Some(f));
374+
}
375+
if let Some(f) = methods.negative {
376+
self.negative.store(Some(f));
377+
}
378+
if let Some(f) = methods.positive {
379+
self.positive.store(Some(f));
380+
}
381+
if let Some(f) = methods.absolute {
382+
self.absolute.store(Some(f));
383+
}
384+
if let Some(f) = methods.boolean {
385+
self.boolean.store(Some(f));
386+
}
387+
if let Some(f) = methods.invert {
388+
self.invert.store(Some(f));
389+
}
390+
if let Some(f) = methods.lshift {
391+
self.lshift.store(Some(f));
392+
}
393+
if let Some(f) = methods.rshift {
394+
self.rshift.store(Some(f));
395+
}
396+
if let Some(f) = methods.and {
397+
self.and.store(Some(f));
398+
}
399+
if let Some(f) = methods.xor {
400+
self.xor.store(Some(f));
401+
}
402+
if let Some(f) = methods.or {
403+
self.or.store(Some(f));
404+
}
405+
if let Some(f) = methods.int {
406+
self.int.store(Some(f));
407+
}
408+
if let Some(f) = methods.float {
409+
self.float.store(Some(f));
410+
}
411+
if let Some(f) = methods.inplace_add {
412+
self.inplace_add.store(Some(f));
413+
}
414+
if let Some(f) = methods.inplace_subtract {
415+
self.inplace_subtract.store(Some(f));
416+
}
417+
if let Some(f) = methods.inplace_multiply {
418+
self.inplace_multiply.store(Some(f));
419+
}
420+
if let Some(f) = methods.inplace_remainder {
421+
self.inplace_remainder.store(Some(f));
422+
}
423+
if let Some(f) = methods.inplace_power {
424+
self.inplace_power.store(Some(f));
425+
}
426+
if let Some(f) = methods.inplace_lshift {
427+
self.inplace_lshift.store(Some(f));
428+
}
429+
if let Some(f) = methods.inplace_rshift {
430+
self.inplace_rshift.store(Some(f));
431+
}
432+
if let Some(f) = methods.inplace_and {
433+
self.inplace_and.store(Some(f));
434+
}
435+
if let Some(f) = methods.inplace_xor {
436+
self.inplace_xor.store(Some(f));
437+
}
438+
if let Some(f) = methods.inplace_or {
439+
self.inplace_or.store(Some(f));
440+
}
441+
if let Some(f) = methods.floor_divide {
442+
self.floor_divide.store(Some(f));
443+
}
444+
if let Some(f) = methods.true_divide {
445+
self.true_divide.store(Some(f));
446+
}
447+
if let Some(f) = methods.inplace_floor_divide {
448+
self.inplace_floor_divide.store(Some(f));
449+
}
450+
if let Some(f) = methods.inplace_true_divide {
451+
self.inplace_true_divide.store(Some(f));
452+
}
453+
if let Some(f) = methods.index {
454+
self.index.store(Some(f));
455+
}
456+
if let Some(f) = methods.matrix_multiply {
457+
self.matrix_multiply.store(Some(f));
458+
}
459+
if let Some(f) = methods.inplace_matrix_multiply {
460+
self.inplace_matrix_multiply.store(Some(f));
461+
}
462+
}
463+
355464
pub fn left_binary_op(&self, op_slot: PyNumberBinaryOp) -> Option<PyNumberBinaryFunc> {
356465
use PyNumberBinaryOp::*;
357466
match op_slot {

crates/vm/src/types/slot.rs

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,7 +689,42 @@ impl PyType {
689689
// === Rich compare (__lt__, __le__, __eq__, __ne__, __gt__, __ge__) ===
690690
SlotAccessor::TpRichcompare => {
691691
if ADD {
692-
if let Some(func) = self.lookup_slot_in_mro(name, ctx, |sf| {
692+
// Check if self or any class in MRO has a Python-defined comparison method
693+
// All comparison ops share the same slot, so if any is overridden anywhere
694+
// in the hierarchy with a Python function, we need to use the wrapper
695+
let cmp_names = [
696+
identifier!(ctx, __eq__),
697+
identifier!(ctx, __ne__),
698+
identifier!(ctx, __lt__),
699+
identifier!(ctx, __le__),
700+
identifier!(ctx, __gt__),
701+
identifier!(ctx, __ge__),
702+
];
703+
704+
let has_python_cmp = {
705+
// Check self first
706+
let attrs = self.attributes.read();
707+
let in_self = cmp_names.iter().any(|n| attrs.contains_key(*n));
708+
drop(attrs);
709+
710+
in_self
711+
|| self.mro.read().iter().any(|cls| {
712+
let attrs = cls.attributes.read();
713+
cmp_names.iter().any(|n| {
714+
if let Some(attr) = attrs.get(*n) {
715+
// Check if it's a Python function (not wrapper_descriptor)
716+
!attr.class().is(ctx.types.wrapper_descriptor_type)
717+
} else {
718+
false
719+
}
720+
})
721+
})
722+
};
723+
724+
if has_python_cmp {
725+
// Use wrapper to call the Python method
726+
self.slots.richcompare.store(Some(richcompare_wrapper));
727+
} else if let Some(func) = self.lookup_slot_in_mro(name, ctx, |sf| {
693728
if let SlotFunc::RichCompare(f, _) = sf {
694729
Some(*f)
695730
} else {
@@ -1710,6 +1745,10 @@ pub trait AsNumber: PyPayload {
17101745
#[pyslot]
17111746
fn as_number() -> &'static PyNumberMethods;
17121747

1748+
fn extend_slots(slots: &mut PyTypeSlots) {
1749+
slots.as_number.copy_from(Self::as_number());
1750+
}
1751+
17131752
fn clone_exact(_zelf: &Py<Self>, _vm: &VirtualMachine) -> PyRef<Self> {
17141753
// not all AsNumber requires this implementation.
17151754
unimplemented!()

crates/vm/src/types/slot_defs.rs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,11 @@ impl SlotAccessor {
800800
if op == Some(SlotOp::Right) {
801801
match self {
802802
Self::NbAdd => {
803-
return slots.as_number.right_add.load().map(SlotFunc::NumBinaryRight);
803+
return slots
804+
.as_number
805+
.right_add
806+
.load()
807+
.map(SlotFunc::NumBinaryRight);
804808
}
805809
Self::NbSubtract => {
806810
return slots
@@ -831,7 +835,11 @@ impl SlotAccessor {
831835
.map(SlotFunc::NumBinaryRight);
832836
}
833837
Self::NbPower => {
834-
return slots.as_number.right_power.load().map(SlotFunc::NumTernary);
838+
return slots
839+
.as_number
840+
.right_power
841+
.load()
842+
.map(SlotFunc::NumTernaryRight);
835843
}
836844
Self::NbLshift => {
837845
return slots
@@ -848,13 +856,25 @@ impl SlotAccessor {
848856
.map(SlotFunc::NumBinaryRight);
849857
}
850858
Self::NbAnd => {
851-
return slots.as_number.right_and.load().map(SlotFunc::NumBinaryRight);
859+
return slots
860+
.as_number
861+
.right_and
862+
.load()
863+
.map(SlotFunc::NumBinaryRight);
852864
}
853865
Self::NbXor => {
854-
return slots.as_number.right_xor.load().map(SlotFunc::NumBinaryRight);
866+
return slots
867+
.as_number
868+
.right_xor
869+
.load()
870+
.map(SlotFunc::NumBinaryRight);
855871
}
856872
Self::NbOr => {
857-
return slots.as_number.right_or.load().map(SlotFunc::NumBinaryRight);
873+
return slots
874+
.as_number
875+
.right_or
876+
.load()
877+
.map(SlotFunc::NumBinaryRight);
858878
}
859879
Self::NbFloorDivide => {
860880
return slots

0 commit comments

Comments
 (0)