@@ -6,73 +6,37 @@ pub(crate) use _random::make_module;
66mod _random {
77 use crate :: common:: lock:: PyMutex ;
88 use crate :: vm:: {
9- builtins:: { PyInt , PyTypeRef } ,
9+ builtins:: { PyInt , PyTupleRef } ,
10+ convert:: ToPyException ,
1011 function:: OptionalOption ,
11- types:: Constructor ,
12- PyObjectRef , PyPayload , PyResult , VirtualMachine ,
12+ types:: { Constructor , Initializer } ,
13+ PyObjectRef , PyPayload , PyRef , PyResult , VirtualMachine ,
1314 } ;
15+ use itertools:: Itertools ;
1416 use malachite_bigint:: { BigInt , BigUint , Sign } ;
17+ use mt19937:: MT19937 ;
1518 use num_traits:: { Signed , Zero } ;
16- use rand:: { rngs:: StdRng , RngCore , SeedableRng } ;
17-
18- #[ derive( Debug ) ]
19- enum PyRng {
20- Std ( Box < StdRng > ) ,
21- MT ( Box < mt19937:: MT19937 > ) ,
22- }
23-
24- impl Default for PyRng {
25- fn default ( ) -> Self {
26- PyRng :: Std ( Box :: new ( StdRng :: from_os_rng ( ) ) )
27- }
28- }
29-
30- impl RngCore for PyRng {
31- fn next_u32 ( & mut self ) -> u32 {
32- match self {
33- Self :: Std ( s) => s. next_u32 ( ) ,
34- Self :: MT ( m) => m. next_u32 ( ) ,
35- }
36- }
37- fn next_u64 ( & mut self ) -> u64 {
38- match self {
39- Self :: Std ( s) => s. next_u64 ( ) ,
40- Self :: MT ( m) => m. next_u64 ( ) ,
41- }
42- }
43- fn fill_bytes ( & mut self , dest : & mut [ u8 ] ) {
44- match self {
45- Self :: Std ( s) => s. fill_bytes ( dest) ,
46- Self :: MT ( m) => m. fill_bytes ( dest) ,
47- }
48- }
49- }
19+ use rand:: { RngCore , SeedableRng } ;
20+ use rustpython_vm:: types:: DefaultConstructor ;
5021
5122 #[ pyattr]
5223 #[ pyclass( name = "Random" ) ]
53- #[ derive( Debug , PyPayload ) ]
24+ #[ derive( Debug , PyPayload , Default ) ]
5425 struct PyRandom {
55- rng : PyMutex < PyRng > ,
26+ rng : PyMutex < MT19937 > ,
5627 }
5728
58- impl Constructor for PyRandom {
59- type Args = OptionalOption < PyObjectRef > ;
29+ impl DefaultConstructor for PyRandom { }
6030
61- fn py_new (
62- cls : PyTypeRef ,
63- // TODO: use x as the seed.
64- _x : Self :: Args ,
65- vm : & VirtualMachine ,
66- ) -> PyResult {
67- PyRandom {
68- rng : PyMutex :: default ( ) ,
69- }
70- . into_ref_with_type ( vm, cls)
71- . map ( Into :: into)
31+ impl Initializer for PyRandom {
32+ type Args = OptionalOption ;
33+
34+ fn init ( zelf : PyRef < Self > , x : Self :: Args , vm : & VirtualMachine ) -> PyResult < ( ) > {
35+ zelf. seed ( x, vm)
7236 }
7337 }
7438
75- #[ pyclass( flags( BASETYPE ) , with( Constructor ) ) ]
39+ #[ pyclass( flags( BASETYPE ) , with( Constructor , Initializer ) ) ]
7640 impl PyRandom {
7741 #[ pymethod]
7842 fn random ( & self ) -> f64 {
@@ -82,9 +46,8 @@ mod _random {
8246
8347 #[ pymethod]
8448 fn seed ( & self , n : OptionalOption < PyObjectRef > , vm : & VirtualMachine ) -> PyResult < ( ) > {
85- let new_rng = n
86- . flatten ( )
87- . map ( |n| {
49+ * self . rng . lock ( ) = match n. flatten ( ) {
50+ Some ( n) => {
8851 // Fallback to using hash if object isn't Int-like.
8952 let ( _, mut key) = match n. downcast :: < PyInt > ( ) {
9053 Ok ( n) => n. as_bigint ( ) . abs ( ) ,
@@ -95,27 +58,21 @@ mod _random {
9558 key. reverse ( ) ;
9659 }
9760 let key = if key. is_empty ( ) { & [ 0 ] } else { key. as_slice ( ) } ;
98- Ok ( PyRng :: MT ( Box :: new ( mt19937:: MT19937 :: new_with_slice_seed (
99- key,
100- ) ) ) )
101- } )
102- . transpose ( ) ?
103- . unwrap_or_default ( ) ;
104-
105- * self . rng . lock ( ) = new_rng;
61+ MT19937 :: new_with_slice_seed ( key)
62+ }
63+ None => MT19937 :: try_from_os_rng ( )
64+ . map_err ( |e| std:: io:: Error :: from ( e) . to_pyexception ( vm) ) ?,
65+ } ;
10666 Ok ( ( ) )
10767 }
10868
10969 #[ pymethod]
11070 fn getrandbits ( & self , k : isize , vm : & VirtualMachine ) -> PyResult < BigInt > {
11171 match k {
112- k if k < 0 => {
113- Err ( vm. new_value_error ( "number of bits must be non-negative" . to_owned ( ) ) )
114- }
72+ ..0 => Err ( vm. new_value_error ( "number of bits must be non-negative" . to_owned ( ) ) ) ,
11573 0 => Ok ( BigInt :: zero ( ) ) ,
116- _ => {
74+ mut k => {
11775 let mut rng = self . rng . lock ( ) ;
118- let mut k = k;
11976 let mut gen_u32 = |k| {
12077 let r = rng. next_u32 ( ) ;
12178 if k < 32 {
@@ -145,5 +102,40 @@ mod _random {
145102 }
146103 }
147104 }
105+
106+ #[ pymethod]
107+ fn getstate ( & self , vm : & VirtualMachine ) -> PyTupleRef {
108+ let rng = self . rng . lock ( ) ;
109+ vm. new_tuple (
110+ rng. get_state ( )
111+ . iter ( )
112+ . copied ( )
113+ . chain ( [ rng. get_index ( ) as u32 ] )
114+ . map ( |i| vm. ctx . new_int ( i) . into ( ) )
115+ . collect :: < Vec < PyObjectRef > > ( ) ,
116+ )
117+ }
118+
119+ #[ pymethod]
120+ fn setstate ( & self , state : PyTupleRef , vm : & VirtualMachine ) -> PyResult < ( ) > {
121+ let state: & [ _ ; mt19937:: N + 1 ] = state
122+ . as_slice ( )
123+ . try_into ( )
124+ . map_err ( |_| vm. new_value_error ( "state vector is the wrong size" . to_owned ( ) ) ) ?;
125+ let ( index, state) = state. split_last ( ) . unwrap ( ) ;
126+ let index: usize = index. try_to_value ( vm) ?;
127+ if index > mt19937:: N {
128+ return Err ( vm. new_value_error ( "invalid state" . to_owned ( ) ) ) ;
129+ }
130+ let state: [ u32 ; mt19937:: N ] = state
131+ . iter ( )
132+ . map ( |i| i. try_to_value ( vm) )
133+ . process_results ( |it| it. collect_array ( ) ) ?
134+ . unwrap ( ) ;
135+ let mut rng = self . rng . lock ( ) ;
136+ rng. set_state ( & state) ;
137+ rng. set_index ( index) ;
138+ Ok ( ( ) )
139+ }
148140 }
149141}
0 commit comments