5050# Adrian Baddeley. Adapted by Raymond Hettinger for use with
5151# the Mersenne Twister and os.urandom() core generators.
5252
53- from warnings import warn as _warn
5453from math import log as _log , exp as _exp , pi as _pi , e as _e , ceil as _ceil
5554from math import sqrt as _sqrt , acos as _acos , cos as _cos , sin as _sin
5655from math import tau as TWOPI , floor as _floor , isfinite as _isfinite
5756from math import lgamma as _lgamma , fabs as _fabs , log2 as _log2
58- try :
59- from os import urandom as _urandom
60- except ImportError :
61- # XXX RUSTPYTHON
62- # On wasm, _random.Random.random() does give a proper random value, but
63- # we don't have the os module
64- def _urandom (* args , ** kwargs ):
65- raise NotImplementedError ("urandom" )
66- _os = None
67- from _collections_abc import Set as _Set , Sequence as _Sequence
57+ from os import urandom as _urandom
58+ from _collections_abc import Sequence as _Sequence
6859from operator import index as _index
6960from itertools import accumulate as _accumulate , repeat as _repeat
7061from bisect import bisect as _bisect
71- try :
72- import os as _os
73- except ImportError :
74- # XXX RUSTPYTHON
75- # On wasm, we don't have the os module
76- _os = None
62+ import os as _os
7763import _random
7864
79- try :
80- # hashlib is pretty heavy to load, try lean internal module first
81- from _sha2 import sha512 as _sha512
82- except ImportError :
83- # fallback to official implementation
84- from hashlib import sha512 as _sha512
85-
8665__all__ = [
8766 "Random" ,
8867 "SystemRandom" ,
@@ -118,6 +97,7 @@ def _urandom(*args, **kwargs):
11897BPF = 53 # Number of bits in a float
11998RECIP_BPF = 2 ** - BPF
12099_ONE = 1
100+ _sha512 = None
121101
122102
123103class Random (_random .Random ):
@@ -172,13 +152,23 @@ def seed(self, a=None, version=2):
172152 a = - 2 if x == - 1 else x
173153
174154 elif version == 2 and isinstance (a , (str , bytes , bytearray )):
155+ global _sha512
156+ if _sha512 is None :
157+ try :
158+ # hashlib is pretty heavy to load, try lean internal
159+ # module first
160+ from _sha2 import sha512 as _sha512
161+ except ImportError :
162+ # fallback to official implementation
163+ from hashlib import sha512 as _sha512
164+
175165 if isinstance (a , str ):
176166 a = a .encode ()
177167 a = int .from_bytes (a + _sha512 (a ).digest ())
178168
179169 elif not isinstance (a , (type (None ), int , float , str , bytes , bytearray )):
180- raise TypeError ('The only supported seed types are: None, \n '
181- 'int, float, str, bytes, and bytearray.' )
170+ raise TypeError ('The only supported seed types are:\n '
171+ 'None, int, float, str, bytes, and bytearray.' )
182172
183173 super ().seed (a )
184174 self .gauss_next = None
@@ -255,11 +245,10 @@ def __init_subclass__(cls, /, **kwargs):
255245 def _randbelow_with_getrandbits (self , n ):
256246 "Return a random int in the range [0,n). Defined for n > 0."
257247
258- getrandbits = self .getrandbits
259248 k = n .bit_length ()
260- r = getrandbits (k ) # 0 <= r < 2**k
249+ r = self . getrandbits (k ) # 0 <= r < 2**k
261250 while r >= n :
262- r = getrandbits (k )
251+ r = self . getrandbits (k )
263252 return r
264253
265254 def _randbelow_without_getrandbits (self , n , maxsize = 1 << BPF ):
@@ -270,9 +259,10 @@ def _randbelow_without_getrandbits(self, n, maxsize=1<<BPF):
270259
271260 random = self .random
272261 if n >= maxsize :
273- _warn ("Underlying random() generator does not supply \n "
274- "enough bits to choose from a population range this large.\n "
275- "To remove the range limitation, add a getrandbits() method." )
262+ from warnings import warn
263+ warn ("Underlying random() generator does not supply \n "
264+ "enough bits to choose from a population range this large.\n "
265+ "To remove the range limitation, add a getrandbits() method." )
276266 return _floor (random () * n )
277267 rem = maxsize % n
278268 limit = (maxsize - rem ) / maxsize # int(limit * maxsize) % n == 0
@@ -345,8 +335,11 @@ def randrange(self, start, stop=None, step=_ONE):
345335 def randint (self , a , b ):
346336 """Return random integer in range [a, b], including both end points.
347337 """
348-
349- return self .randrange (a , b + 1 )
338+ a = _index (a )
339+ b = _index (b )
340+ if b < a :
341+ raise ValueError (f"empty range in randint({ a } , { b } )" )
342+ return a + self ._randbelow (b - a + 1 )
350343
351344
352345 ## -------------------- sequence methods -------------------
@@ -430,11 +423,11 @@ def sample(self, population, k, *, counts=None):
430423 cum_counts = list (_accumulate (counts ))
431424 if len (cum_counts ) != n :
432425 raise ValueError ('The number of counts does not match the population' )
433- total = cum_counts .pop ()
426+ total = cum_counts .pop () if cum_counts else 0
434427 if not isinstance (total , int ):
435428 raise TypeError ('Counts must be integers' )
436- if total <= 0 :
437- raise ValueError ('Total of counts must be greater than zero ' )
429+ if total < 0 :
430+ raise ValueError ('Counts must be non-negative ' )
438431 selections = self .sample (range (total ), k = k )
439432 bisect = _bisect
440433 return [population [bisect (cum_counts , s )] for s in selections ]
@@ -801,12 +794,18 @@ def binomialvariate(self, n=1, p=0.5):
801794
802795 sum(random() < p for i in range(n))
803796
804- Returns an integer in the range: 0 <= X <= n
797+ Returns an integer in the range:
798+
799+ 0 <= X <= n
800+
801+ The integer is chosen with the probability:
802+
803+ P(X == k) = math.comb(n, k) * p ** k * (1 - p) ** (n - k)
805804
806805 The mean (expected value) and variance of the random variable are:
807806
808807 E[X] = n * p
809- Var[x ] = n * p * (1 - p)
808+ Var[X ] = n * p * (1 - p)
810809
811810 """
812811 # Error check inputs and handle edge cases
@@ -1005,5 +1004,75 @@ def _test(N=10_000):
10051004 _os .register_at_fork (after_in_child = _inst .seed )
10061005
10071006
1007+ # ------------------------------------------------------
1008+ # -------------- command-line interface ----------------
1009+
1010+
1011+ def _parse_args (arg_list : list [str ] | None ):
1012+ import argparse
1013+ parser = argparse .ArgumentParser (
1014+ formatter_class = argparse .RawTextHelpFormatter , color = True )
1015+ group = parser .add_mutually_exclusive_group ()
1016+ group .add_argument (
1017+ "-c" , "--choice" , nargs = "+" ,
1018+ help = "print a random choice" )
1019+ group .add_argument (
1020+ "-i" , "--integer" , type = int , metavar = "N" ,
1021+ help = "print a random integer between 1 and N inclusive" )
1022+ group .add_argument (
1023+ "-f" , "--float" , type = float , metavar = "N" ,
1024+ help = "print a random floating-point number between 0 and N inclusive" )
1025+ group .add_argument (
1026+ "--test" , type = int , const = 10_000 , nargs = "?" ,
1027+ help = argparse .SUPPRESS )
1028+ parser .add_argument ("input" , nargs = "*" ,
1029+ help = """\
1030+ if no options given, output depends on the input
1031+ string or multiple: same as --choice
1032+ integer: same as --integer
1033+ float: same as --float""" )
1034+ args = parser .parse_args (arg_list )
1035+ return args , parser .format_help ()
1036+
1037+
1038+ def main (arg_list : list [str ] | None = None ) -> int | str :
1039+ args , help_text = _parse_args (arg_list )
1040+
1041+ # Explicit arguments
1042+ if args .choice :
1043+ return choice (args .choice )
1044+
1045+ if args .integer is not None :
1046+ return randint (1 , args .integer )
1047+
1048+ if args .float is not None :
1049+ return uniform (0 , args .float )
1050+
1051+ if args .test :
1052+ _test (args .test )
1053+ return ""
1054+
1055+ # No explicit argument, select based on input
1056+ if len (args .input ) == 1 :
1057+ val = args .input [0 ]
1058+ try :
1059+ # Is it an integer?
1060+ val = int (val )
1061+ return randint (1 , val )
1062+ except ValueError :
1063+ try :
1064+ # Is it a float?
1065+ val = float (val )
1066+ return uniform (0 , val )
1067+ except ValueError :
1068+ # Split in case of space-separated string: "a b c"
1069+ return choice (val .split ())
1070+
1071+ if len (args .input ) >= 2 :
1072+ return choice (args .input )
1073+
1074+ return help_text
1075+
1076+
10081077if __name__ == '__main__' :
1009- _test ( )
1078+ print ( main () )
0 commit comments