Skip to content

Commit 3fdfa93

Browse files
Faster bezier root finding on [0, 1]
1 parent 74fa286 commit 3fdfa93

File tree

1 file changed

+106
-11
lines changed

1 file changed

+106
-11
lines changed

lib/matplotlib/bezier.py

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,83 @@
99
import numpy as np
1010

1111
from matplotlib import _api
12+
from numpy.polynomial.polynomial import polyval as _polyval
13+
14+
15+
def _bisect(f, a, b, tol=1e-12, max_iter=53):
16+
"""Find root of f in [a, b] using bisection. Assumes sign change exists.
17+
Note that the default max_iter of 53 reflects float64 fractional precision.
18+
"""
19+
fa = f(a)
20+
for _ in range(max_iter):
21+
mid = (a + b) * 0.5
22+
fm = f(mid)
23+
if abs(fm) < tol or (b - a) < tol:
24+
return mid
25+
if fa * fm < 0:
26+
b = mid
27+
else:
28+
a, fa = mid, fm
29+
return (a + b) * 0.5
30+
31+
32+
def _real_roots_in_01(coeffs):
33+
"""
34+
Find real roots of polynomial in [0, 1] using sampling and bisection.
35+
coeffs in ascending order: c0 + c1*x + c2*x**2 + ...
36+
"""
37+
deg = len(coeffs) - 1
38+
n_samples = max(8, deg * 2)
39+
ts = np.linspace(0, 1, n_samples)
40+
vals = _polyval(ts, coeffs)
41+
42+
signs = np.sign(vals)
43+
sign_changes = np.where((signs[:-1] != signs[1:]) & (signs[:-1] != 0))[0]
44+
45+
roots = []
46+
47+
def f(t):
48+
return _polyval(t, coeffs)
49+
50+
for i in sign_changes:
51+
roots.append(_bisect(f, ts[i], ts[i + 1]))
52+
53+
# Check endpoints
54+
if abs(vals[0]) < 1e-12:
55+
roots.insert(0, 0.0)
56+
if abs(vals[-1]) < 1e-12 and (not roots or abs(roots[-1] - 1.0) > 1e-10):
57+
roots.append(1.0)
58+
59+
return np.asarray(roots)
60+
61+
62+
def _quadratic_roots_in_01(c0, c1, c2):
63+
"""Real roots of c0 + c1*x + c2*x**2 in [0, 1]."""
64+
if abs(c2) < 1e-12: # Linear
65+
if abs(c1) < 1e-12:
66+
return np.array([])
67+
root = -c0 / c1
68+
return np.array([root]) if 0 <= root <= 1 else np.array([])
69+
70+
disc = c1 * c1 - 4 * c2 * c0
71+
if disc < 0:
72+
return np.array([])
73+
74+
sqrt_disc = np.sqrt(disc)
75+
# Numerically stable quadratic formula
76+
if c1 >= 0:
77+
q = -0.5 * (c1 + sqrt_disc)
78+
else:
79+
q = -0.5 * (c1 - sqrt_disc)
80+
81+
roots = []
82+
if abs(q) > 1e-12:
83+
roots.append(c0 / q)
84+
if abs(c2) > 1e-12:
85+
roots.append(q / c2)
86+
87+
roots = np.asarray(roots)
88+
return roots[(roots >= 0) & (roots <= 1)]
1289

1390

1491
# same algorithm as 3.8's math.comb
@@ -22,7 +99,7 @@ def _comb(n, k):
2299
return np.prod((n + 1 - i)/i).astype(int)
23100

24101

25-
# Precomputed matrices for converting Bézier control points to polynomial
102+
# Precomputed matrices for converting Bezier control points to polynomial
26103
# coefficients. _COEFF_MATRICES[n] @ control_points gives coefficients.
27104
# These avoid the slow _comb vectorized function for common cases.
28105
_COEFF_MATRICES = {
@@ -322,17 +399,35 @@ def axis_aligned_extrema(self):
322399
if n <= 1:
323400
return np.array([]), np.array([])
324401
Cj = self.polynomial_coefficients
325-
dCj = np.arange(1, n+1)[:, None] * Cj[1:]
326-
dims = []
327-
roots = []
402+
dCj = np.arange(1, n + 1)[:, None] * Cj[1:]
403+
404+
all_dims = []
405+
all_roots = []
406+
328407
for i, pi in enumerate(dCj.T):
329-
r = np.roots(pi[::-1])
330-
roots.append(r)
331-
dims.append(np.full_like(r, i))
332-
roots = np.concatenate(roots)
333-
dims = np.concatenate(dims)
334-
in_range = np.isreal(roots) & (roots >= 0) & (roots <= 1)
335-
return dims[in_range], np.real(roots)[in_range]
408+
# Trim trailing near-zeros to get actual degree
409+
deg = len(pi) - 1
410+
while deg > 0 and abs(pi[deg]) < 1e-12:
411+
deg -= 1
412+
413+
if deg == 0:
414+
continue
415+
elif deg == 1:
416+
root = -pi[0] / pi[1]
417+
r = np.array([root]) if 0 <= root <= 1 else np.array([])
418+
elif deg == 2:
419+
r = _quadratic_roots_in_01(pi[0], pi[1], pi[2])
420+
else:
421+
r = _real_roots_in_01(pi[:deg + 1])
422+
423+
if len(r) > 0:
424+
all_roots.append(r)
425+
all_dims.append(np.full(len(r), i))
426+
427+
if not all_roots:
428+
return np.array([]), np.array([])
429+
430+
return np.concatenate(all_dims), np.concatenate(all_roots)
336431

337432

338433
def split_bezier_intersecting_with_closedpath(

0 commit comments

Comments
 (0)