Skip to content

Commit c72fa91

Browse files
committed
final tweaks of arrow code
1 parent d6432ca commit c72fa91

File tree

1 file changed

+12
-22
lines changed

1 file changed

+12
-22
lines changed

control/freqplot.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1919,7 +1919,7 @@ def _parse_linestyle(style_name, allow_false=False):
19191919
# Internal function to add arrows to a curve
19201920
def _add_arrows_to_line2D(
19211921
axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8],
1922-
arrowstyle='-|>', arrowsize=1, dir=1, transform=None):
1922+
arrowstyle='-|>', arrowsize=1, dir=1):
19231923
"""
19241924
Add arrows to a matplotlib.lines.Line2D at selected locations.
19251925
@@ -1930,7 +1930,6 @@ def _add_arrows_to_line2D(
19301930
arrow_locs: list of locations where to insert arrows, % of total length
19311931
arrowstyle: style of the arrow
19321932
arrowsize: size of the arrow
1933-
transform: a matplotlib transform instance, default to data coordinates
19341933
19351934
Returns:
19361935
--------
@@ -1939,13 +1938,13 @@ def _add_arrows_to_line2D(
19391938
Based on https://stackoverflow.com/questions/26911898/
19401939
19411940
"""
1941+
# Get the coordinates of the line, in plot coordinates
19421942
if not isinstance(line, mpl.lines.Line2D):
19431943
raise ValueError("expected a matplotlib.lines.Line2D object")
19441944
x, y = line.get_xdata(), line.get_ydata()
19451945

1946-
arrow_kw = {
1947-
"arrowstyle": arrowstyle,
1948-
}
1946+
# Determine the arrow properties
1947+
arrow_kw = {"arrowstyle": arrowstyle}
19491948

19501949
color = line.get_color()
19511950
use_multicolor_lines = isinstance(color, np.ndarray)
@@ -1960,9 +1959,6 @@ def _add_arrows_to_line2D(
19601959
else:
19611960
arrow_kw['linewidth'] = linewidth
19621961

1963-
if transform is None:
1964-
transform = axes.transData
1965-
19661962
# Figure out the size of the axes (length of diagonal)
19671963
xlim, ylim = axes.get_xlim(), axes.get_ylim()
19681964
ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
@@ -1979,33 +1975,27 @@ def _add_arrows_to_line2D(
19791975
elif len(arrow_locs) and frac < 0.2:
19801976
arrow_locs = [0.5] # single arrow in the middle
19811977

1978+
# Plot the arrows (and return list if patches)
19821979
arrows = []
19831980
for loc in arrow_locs:
19841981
n = np.searchsorted(s, s[-1] * loc)
19851982

1986-
# Figure out what direction to paint the arrow
1987-
if dir == 1:
1988-
n = 1 if n == 0 else n # move arrow forward if at start
1989-
arrow_tail = (x[n - 1], y[n - 1])
1990-
arrow_head = (np.mean(x[n - 1:n + 1]), np.mean(y[n - 1:n + 1]))
1991-
1992-
elif dir == -1:
1993-
# Orient the arrow in the other direction on the segment
1994-
arrow_tail = (x[n + 1], y[n + 1])
1995-
arrow_head = (np.mean(x[n:n + 2]), np.mean(y[n:n + 2]))
1983+
if dir == 1 and n == 0:
1984+
# Move the arrow forward by one if it is at start of a segment
1985+
n = 1
19961986

1997-
else:
1998-
raise ValueError("unknown value for keyword 'dir'")
1987+
# Place the head of the arrow at the desired location
1988+
arrow_head = [x[n], y[n]]
1989+
arrow_tail = [x[n - dir], y[n - dir]]
19991990

20001991
p = mpl.patches.FancyArrowPatch(
2001-
arrow_tail, arrow_head, transform=transform, lw=0,
1992+
arrow_tail, arrow_head, transform=axes.transData, lw=0,
20021993
**arrow_kw)
20031994
axes.add_patch(p)
20041995
arrows.append(p)
20051996
return arrows
20061997

20071998

2008-
20091999
#
20102000
# Function to compute Nyquist curve offsets
20112001
#

0 commit comments

Comments
 (0)