Skip to content

Commit b94825e

Browse files
authored
Merge pull request #724 from skirpichev/bernoulli-plus
Add plus flag to select the B_1 sign convention for bernoulli/bernfrac
2 parents 6c0bb28 + 21f791e commit b94825e

7 files changed

Lines changed: 68 additions & 52 deletions

File tree

docs/functions/numtheory.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ Bernoulli numbers and polynomials
1818

1919
:func:`~mpmath.bernoulli`
2020
^^^^^^^^^^^^^^^^^^^^^^^^^^
21-
.. autofunction:: mpmath.bernoulli(n)
21+
.. autofunction:: mpmath.bernoulli
2222

2323
:func:`~mpmath.bernfrac`
2424
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
25-
.. autofunction:: mpmath.bernfrac(n)
25+
.. autofunction:: mpmath.bernfrac
2626

2727
:func:`~mpmath.bernpoly`
2828
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

mpmath/ctx_fp.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import cmath
2+
import functools
23
import math
34
import sys
45

@@ -18,7 +19,6 @@ def __init__(ctx):
1819

1920
# Override SpecialFunctions implementation
2021
ctx.loggamma = libfp.loggamma
21-
ctx._bernoulli_cache = {}
2222
ctx.pretty = False
2323

2424
ctx._init_aliases()
@@ -56,12 +56,9 @@ def f_wrapped(ctx, *args, **kwargs):
5656
f_wrapped.__doc__ = function_docs.__dict__.get(name, f.__doc__)
5757
setattr(cls, name, f_wrapped)
5858

59-
def bernoulli(ctx, n):
60-
cache = ctx._bernoulli_cache
61-
if n in cache:
62-
return cache[n]
63-
cache[n] = to_float(mpf_bernoulli(n, 53, 'n'), strict=True)
64-
return cache[n]
59+
@functools.lru_cache
60+
def bernoulli(ctx, n, plus=False):
61+
return to_float(mpf_bernoulli(n, 53, 'n', plus=plus), strict=True)
6562

6663
pi = libfp.pi
6764
e = math.e

mpmath/ctx_mp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ def _agm(ctx, a, b=1):
219219
else: b = b._mpc_
220220
return ctx.make_mpc(libmp.mpc_agm(a, b, prec, rounding))
221221

222-
def bernoulli(ctx, n):
223-
return ctx.make_mpf(libmp.mpf_bernoulli(int(n), *ctx._prec_rounding))
222+
def bernoulli(ctx, n, plus=False):
223+
return ctx.make_mpf(libmp.mpf_bernoulli(int(n), *ctx._prec_rounding, plus=plus))
224224

225225
def _zeta_int(ctx, n):
226226
return ctx.make_mpf(libmp.mpf_zeta_int(int(n), *ctx._prec_rounding))

mpmath/function_docs.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2228,6 +2228,9 @@
22282228
returns a floating-point approximation. To obtain an exact
22292229
fraction, use :func:`~mpmath.bernfrac` instead.
22302230
2231+
Optional ``plus`` flag (default: False) control the sign choice of
2232+
the `B_1` value (default: `-0.5`).
2233+
22312234
**Examples**
22322235
22332236
Numerical values of the first few Bernoulli numbers::
@@ -2282,6 +2285,11 @@
22822285
22832286
For larger `n`, `B_n` is evaluated in terms of the Riemann zeta
22842287
function.
2288+
2289+
**References**
2290+
2291+
1. https://en.wikipedia.org/wiki/Bernoulli_number
2292+
22852293
"""
22862294

22872295
stieltjes = r"""

mpmath/libmp/gammazeta.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -380,15 +380,15 @@ def bernoulli_size(n):
380380

381381
BERNOULLI_PREC_CUTOFF = bernoulli_size(MAX_BERNOULLI_CACHE)
382382

383-
def mpf_bernoulli(n, prec, rnd=None):
383+
def mpf_bernoulli(n, prec, rnd=None, plus=False):
384384
"""Computation of Bernoulli numbers (numerically)"""
385385
if n < 2:
386386
if n < 0:
387387
raise ValueError("Bernoulli numbers only defined for n >= 0")
388388
if n == 0:
389389
return fone
390390
if n == 1:
391-
return mpf_neg(fhalf)
391+
return fhalf if plus else mpf_neg(fhalf)
392392
# For odd n > 1, the Bernoulli numbers are zero
393393
if n & 1:
394394
return fzero
@@ -466,13 +466,16 @@ def mpf_bernoulli_huge(n, prec, rnd=None):
466466
v = mpf_neg(v)
467467
return mpf_pos(v, prec, rnd or round_fast)
468468

469-
def bernfrac(n):
469+
def bernfrac(n, plus=False):
470470
r"""
471471
Returns a tuple of integers `(p, q)` such that `p/q = B_n` exactly,
472472
where `B_n` denotes the `n`-th Bernoulli number. The fraction is
473473
always reduced to lowest terms. Note that for `n > 1` and `n` odd,
474474
`B_n = 0`, and `(0, 1)` is returned.
475475
476+
Optional ``plus`` flag (default: False) control the sign choice of
477+
the `B_1` value (default: `(-1, 2)`).
478+
476479
**Examples**
477480
478481
The first few Bernoulli numbers are exactly::
@@ -540,10 +543,12 @@ def bernfrac(n):
540543
2. The Bernoulli Number Page:
541544
http://www.bernoulli.org/
542545
546+
3. https://en.wikipedia.org/wiki/Bernoulli_number
547+
543548
"""
544549
n = int(n)
545550
if n < 3:
546-
return [(1, 1), (-1, 2), (1, 6)][n]
551+
return [(1, 1), (1 if plus else -1, 2), (1, 6)][n]
547552
if n & 1:
548553
return (0, 1)
549554
q = 1

mpmath/tests/test_fp.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
2222
"""
2323

24+
import pytest
25+
2426
from mpmath import fp
2527

2628

@@ -122,12 +124,13 @@ def test_fp_expj():
122124
assert ae(fp.expjpi(0.75), (-0.7071067811865475244 + 0.7071067811865475244j))
123125
assert ae(fp.expjpi(2+3j), (0.000080699517570304599239 + 0.0j))
124126

125-
def test_fp_bernoulli():
126-
assert ae(fp.bernoulli(0), 1.0)
127-
assert ae(fp.bernoulli(1), -0.5)
128-
assert ae(fp.bernoulli(2), 0.16666666666666666667)
129-
assert ae(fp.bernoulli(10), 0.075757575757575757576)
130-
assert ae(fp.bernoulli(11), 0.0)
127+
@pytest.mark.parametrize('plus', [True, False])
128+
def test_fp_bernoulli(plus):
129+
assert ae(fp.bernoulli(0, plus), 1.0)
130+
assert ae(fp.bernoulli(1, plus), 0.5 if plus else -0.5)
131+
assert ae(fp.bernoulli(2, plus), 0.16666666666666666667)
132+
assert ae(fp.bernoulli(10, plus), 0.075757575757575757576)
133+
assert ae(fp.bernoulli(11, plus), 0.0)
131134

132135
def test_fp_gamma():
133136
assert ae(fp.gamma(1), 1.0)

mpmath/tests/test_gammazeta.py

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pytest
2+
13
from mpmath import (altzeta, apery, barnesg, bell, bernfrac, bernoulli,
24
bernpoly, beta, binomial, catalan, digamma, e, euler,
35
eulerpoly, fac, fac2, factorial, fadd, ff, findroot, fp,
@@ -11,48 +13,49 @@
1113
def test_zeta_int_bug():
1214
assert mpf_zeta_int(0, 10) == from_float(-0.5)
1315

14-
def test_bernoulli():
15-
assert bernfrac(0) == (1,1)
16-
assert bernfrac(1) == (-1,2)
17-
assert bernfrac(2) == (1,6)
18-
assert bernfrac(3) == (0,1)
19-
assert bernfrac(4) == (-1,30)
20-
assert bernfrac(5) == (0,1)
21-
assert bernfrac(6) == (1,42)
22-
assert bernfrac(8) == (-1,30)
23-
assert bernfrac(10) == (5,66)
24-
assert bernfrac(12) == (-691,2730)
25-
assert bernfrac(18) == (43867,798)
26-
p, q = bernfrac(228)
16+
@pytest.mark.parametrize('plus', [True, False])
17+
def test_bernoulli(plus):
18+
assert bernfrac(0, plus) == (1,1)
19+
assert bernfrac(1, plus) == (1,2) if plus else (-1,2)
20+
assert bernfrac(2, plus) == (1,6)
21+
assert bernfrac(3, plus) == (0,1)
22+
assert bernfrac(4, plus) == (-1,30)
23+
assert bernfrac(5, plus) == (0,1)
24+
assert bernfrac(6, plus) == (1,42)
25+
assert bernfrac(8, plus) == (-1,30)
26+
assert bernfrac(10, plus) == (5,66)
27+
assert bernfrac(12, plus) == (-691,2730)
28+
assert bernfrac(18, plus) == (43867,798)
29+
p, q = bernfrac(228, plus)
2730
assert p % 10**10 == 164918161
2831
assert q == 625170
29-
p, q = bernfrac(1000)
32+
p, q = bernfrac(1000, plus)
3033
assert p % 10**10 == 7950421099
3134
assert q == 342999030
3235
mp.dps = 15
33-
assert bernoulli(0) == 1
34-
assert bernoulli(1) == -0.5
35-
assert bernoulli(2).ae(1./6)
36-
assert bernoulli(3) == 0
37-
assert bernoulli(4).ae(-1./30)
38-
assert bernoulli(5) == 0
39-
assert bernoulli(6).ae(1./42)
40-
assert str(bernoulli(10)) == '0.0757575757575758'
41-
assert str(bernoulli(234)) == '7.62772793964344e+267'
42-
assert str(bernoulli(10**5)) == '-5.82229431461335e+376755'
43-
assert str(bernoulli(10**8+2)) == '1.19570355039953e+676752584'
36+
assert bernoulli(0, plus) == 1
37+
assert bernoulli(1, plus) == 0.5 if plus else -0.5
38+
assert bernoulli(2, plus).ae(1./6)
39+
assert bernoulli(3, plus) == 0
40+
assert bernoulli(4, plus).ae(-1./30)
41+
assert bernoulli(5, plus) == 0
42+
assert bernoulli(6, plus).ae(1./42)
43+
assert str(bernoulli(10, plus)) == '0.0757575757575758'
44+
assert str(bernoulli(234, plus)) == '7.62772793964344e+267'
45+
assert str(bernoulli(10**5, plus)) == '-5.82229431461335e+376755'
46+
assert str(bernoulli(10**8+2, plus)) == '1.19570355039953e+676752584'
4447

4548
mp.dps = 50
46-
assert str(bernoulli(10)) == '0.075757575757575757575757575757575757575757575757576'
47-
assert str(bernoulli(234)) == '7.6277279396434392486994969020496121553385863373331e+267'
48-
assert str(bernoulli(10**5)) == '-5.8222943146133508236497045360612887555320691004308e+376755'
49-
assert str(bernoulli(10**8+2)) == '1.1957035503995297272263047884604346914602088317782e+676752584'
49+
assert str(bernoulli(10, plus)) == '0.075757575757575757575757575757575757575757575757576'
50+
assert str(bernoulli(234, plus)) == '7.6277279396434392486994969020496121553385863373331e+267'
51+
assert str(bernoulli(10**5, plus)) == '-5.8222943146133508236497045360612887555320691004308e+376755'
52+
assert str(bernoulli(10**8+2, plus)) == '1.1957035503995297272263047884604346914602088317782e+676752584'
5053

5154
mp.dps = 1000
52-
assert bernoulli(10).ae(mpf(5)/66)
55+
assert bernoulli(10, plus).ae(mpf(5)/66)
5356

5457
mp.dps = 50000
55-
assert bernoulli(10).ae(mpf(5)/66)
58+
assert bernoulli(10, plus).ae(mpf(5)/66)
5659

5760
mp.dps = 15
5861

0 commit comments

Comments
 (0)