Skip to content

Commit c4de2a7

Browse files
CPython Developersyouknowone
authored andcommitted
Update random from v3.14.3
1 parent cc4a7bb commit c4de2a7

2 files changed

Lines changed: 348 additions & 199 deletions

File tree

Lib/random.py

Lines changed: 109 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -50,39 +50,18 @@
5050
# Adrian Baddeley. Adapted by Raymond Hettinger for use with
5151
# the Mersenne Twister and os.urandom() core generators.
5252

53-
from warnings import warn as _warn
5453
from math import log as _log, exp as _exp, pi as _pi, e as _e, ceil as _ceil
5554
from math import sqrt as _sqrt, acos as _acos, cos as _cos, sin as _sin
5655
from math import tau as TWOPI, floor as _floor, isfinite as _isfinite
5756
from math import lgamma as _lgamma, fabs as _fabs, log2 as _log2
58-
try:
59-
from os import urandom as _urandom
60-
except ImportError:
61-
# XXX RUSTPYTHON
62-
# On wasm, _random.Random.random() does give a proper random value, but
63-
# we don't have the os module
64-
def _urandom(*args, **kwargs):
65-
raise NotImplementedError("urandom")
66-
_os = None
67-
from _collections_abc import Set as _Set, Sequence as _Sequence
57+
from os import urandom as _urandom
58+
from _collections_abc import Sequence as _Sequence
6859
from operator import index as _index
6960
from itertools import accumulate as _accumulate, repeat as _repeat
7061
from bisect import bisect as _bisect
71-
try:
72-
import os as _os
73-
except ImportError:
74-
# XXX RUSTPYTHON
75-
# On wasm, we don't have the os module
76-
_os = None
62+
import os as _os
7763
import _random
7864

79-
try:
80-
# hashlib is pretty heavy to load, try lean internal module first
81-
from _sha2 import sha512 as _sha512
82-
except ImportError:
83-
# fallback to official implementation
84-
from hashlib import sha512 as _sha512
85-
8665
__all__ = [
8766
"Random",
8867
"SystemRandom",
@@ -118,6 +97,7 @@ def _urandom(*args, **kwargs):
11897
BPF = 53 # Number of bits in a float
11998
RECIP_BPF = 2 ** -BPF
12099
_ONE = 1
100+
_sha512 = None
121101

122102

123103
class Random(_random.Random):
@@ -172,13 +152,23 @@ def seed(self, a=None, version=2):
172152
a = -2 if x == -1 else x
173153

174154
elif version == 2 and isinstance(a, (str, bytes, bytearray)):
155+
global _sha512
156+
if _sha512 is None:
157+
try:
158+
# hashlib is pretty heavy to load, try lean internal
159+
# module first
160+
from _sha2 import sha512 as _sha512
161+
except ImportError:
162+
# fallback to official implementation
163+
from hashlib import sha512 as _sha512
164+
175165
if isinstance(a, str):
176166
a = a.encode()
177167
a = int.from_bytes(a + _sha512(a).digest())
178168

179169
elif not isinstance(a, (type(None), int, float, str, bytes, bytearray)):
180-
raise TypeError('The only supported seed types are: None,\n'
181-
'int, float, str, bytes, and bytearray.')
170+
raise TypeError('The only supported seed types are:\n'
171+
'None, int, float, str, bytes, and bytearray.')
182172

183173
super().seed(a)
184174
self.gauss_next = None
@@ -255,11 +245,10 @@ def __init_subclass__(cls, /, **kwargs):
255245
def _randbelow_with_getrandbits(self, n):
256246
"Return a random int in the range [0,n). Defined for n > 0."
257247

258-
getrandbits = self.getrandbits
259248
k = n.bit_length()
260-
r = getrandbits(k) # 0 <= r < 2**k
249+
r = self.getrandbits(k) # 0 <= r < 2**k
261250
while r >= n:
262-
r = getrandbits(k)
251+
r = self.getrandbits(k)
263252
return r
264253

265254
def _randbelow_without_getrandbits(self, n, maxsize=1<<BPF):
@@ -270,9 +259,10 @@ def _randbelow_without_getrandbits(self, n, maxsize=1<<BPF):
270259

271260
random = self.random
272261
if n >= maxsize:
273-
_warn("Underlying random() generator does not supply \n"
274-
"enough bits to choose from a population range this large.\n"
275-
"To remove the range limitation, add a getrandbits() method.")
262+
from warnings import warn
263+
warn("Underlying random() generator does not supply \n"
264+
"enough bits to choose from a population range this large.\n"
265+
"To remove the range limitation, add a getrandbits() method.")
276266
return _floor(random() * n)
277267
rem = maxsize % n
278268
limit = (maxsize - rem) / maxsize # int(limit * maxsize) % n == 0
@@ -345,8 +335,11 @@ def randrange(self, start, stop=None, step=_ONE):
345335
def randint(self, a, b):
346336
"""Return random integer in range [a, b], including both end points.
347337
"""
348-
349-
return self.randrange(a, b+1)
338+
a = _index(a)
339+
b = _index(b)
340+
if b < a:
341+
raise ValueError(f"empty range in randint({a}, {b})")
342+
return a + self._randbelow(b - a + 1)
350343

351344

352345
## -------------------- sequence methods -------------------
@@ -430,11 +423,11 @@ def sample(self, population, k, *, counts=None):
430423
cum_counts = list(_accumulate(counts))
431424
if len(cum_counts) != n:
432425
raise ValueError('The number of counts does not match the population')
433-
total = cum_counts.pop()
426+
total = cum_counts.pop() if cum_counts else 0
434427
if not isinstance(total, int):
435428
raise TypeError('Counts must be integers')
436-
if total <= 0:
437-
raise ValueError('Total of counts must be greater than zero')
429+
if total < 0:
430+
raise ValueError('Counts must be non-negative')
438431
selections = self.sample(range(total), k=k)
439432
bisect = _bisect
440433
return [population[bisect(cum_counts, s)] for s in selections]
@@ -801,12 +794,18 @@ def binomialvariate(self, n=1, p=0.5):
801794
802795
sum(random() < p for i in range(n))
803796
804-
Returns an integer in the range: 0 <= X <= n
797+
Returns an integer in the range:
798+
799+
0 <= X <= n
800+
801+
The integer is chosen with the probability:
802+
803+
P(X == k) = math.comb(n, k) * p ** k * (1 - p) ** (n - k)
805804
806805
The mean (expected value) and variance of the random variable are:
807806
808807
E[X] = n * p
809-
Var[x] = n * p * (1 - p)
808+
Var[X] = n * p * (1 - p)
810809
811810
"""
812811
# Error check inputs and handle edge cases
@@ -1005,5 +1004,75 @@ def _test(N=10_000):
10051004
_os.register_at_fork(after_in_child=_inst.seed)
10061005

10071006

1007+
# ------------------------------------------------------
1008+
# -------------- command-line interface ----------------
1009+
1010+
1011+
def _parse_args(arg_list: list[str] | None):
1012+
import argparse
1013+
parser = argparse.ArgumentParser(
1014+
formatter_class=argparse.RawTextHelpFormatter, color=True)
1015+
group = parser.add_mutually_exclusive_group()
1016+
group.add_argument(
1017+
"-c", "--choice", nargs="+",
1018+
help="print a random choice")
1019+
group.add_argument(
1020+
"-i", "--integer", type=int, metavar="N",
1021+
help="print a random integer between 1 and N inclusive")
1022+
group.add_argument(
1023+
"-f", "--float", type=float, metavar="N",
1024+
help="print a random floating-point number between 0 and N inclusive")
1025+
group.add_argument(
1026+
"--test", type=int, const=10_000, nargs="?",
1027+
help=argparse.SUPPRESS)
1028+
parser.add_argument("input", nargs="*",
1029+
help="""\
1030+
if no options given, output depends on the input
1031+
string or multiple: same as --choice
1032+
integer: same as --integer
1033+
float: same as --float""")
1034+
args = parser.parse_args(arg_list)
1035+
return args, parser.format_help()
1036+
1037+
1038+
def main(arg_list: list[str] | None = None) -> int | str:
1039+
args, help_text = _parse_args(arg_list)
1040+
1041+
# Explicit arguments
1042+
if args.choice:
1043+
return choice(args.choice)
1044+
1045+
if args.integer is not None:
1046+
return randint(1, args.integer)
1047+
1048+
if args.float is not None:
1049+
return uniform(0, args.float)
1050+
1051+
if args.test:
1052+
_test(args.test)
1053+
return ""
1054+
1055+
# No explicit argument, select based on input
1056+
if len(args.input) == 1:
1057+
val = args.input[0]
1058+
try:
1059+
# Is it an integer?
1060+
val = int(val)
1061+
return randint(1, val)
1062+
except ValueError:
1063+
try:
1064+
# Is it a float?
1065+
val = float(val)
1066+
return uniform(0, val)
1067+
except ValueError:
1068+
# Split in case of space-separated string: "a b c"
1069+
return choice(val.split())
1070+
1071+
if len(args.input) >= 2:
1072+
return choice(args.input)
1073+
1074+
return help_text
1075+
1076+
10081077
if __name__ == '__main__':
1009-
_test()
1078+
print(main())

0 commit comments

Comments
 (0)