Skip to content

Commit 97a5230

Browse files
committed
update phaseplot to use common ax, rcParams processing
1 parent fb5c194 commit 97a5230

1 file changed

Lines changed: 51 additions & 30 deletions

File tree

control/phaseplot.py

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,16 @@
3636
from scipy.integrate import odeint
3737

3838
from . import config
39-
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _process_ax_keyword
39+
from .ctrlplot import ControlPlot, _add_arrows_to_line2D, \
40+
_ctrlplot_rcParams, _process_ax_keyword, suptitle
4041
from .exception import ControlNotImplemented
4142
from .nlsys import NonlinearIOSystem, find_eqpt, input_output_response
4243

4344
__all__ = ['phase_plane_plot', 'phase_plot', 'box_grid']
4445

4546
# Default values for module parameter variables
4647
_phaseplot_defaults = {
48+
'phaseplot.rcParams': _ctrlplot_rcParams,
4749
'phaseplot.arrows': 2, # number of arrows around curve
4850
'phaseplot.arrow_size': 8, # pixel size for arrows
4951
'phaseplot.separatrices_radius': 0.1 # initial radius for separatrices
@@ -139,15 +141,12 @@ def phase_plane_plot(
139141
params = kwargs.get('params', None)
140142
sys = _create_system(sys, params)
141143
pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata
144+
rcParams = config._get_param(
145+
'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True)
142146

143147
# Create axis if needed
144148
user_ax = ax
145-
# TODO: make use of _process_ax_keyword
146-
# fig, ax = _process_ax_keyword(user_ax, squeeze=True)
147-
if ax is None:
148-
fig, ax = plt.gcf(), plt.gca()
149-
else:
150-
fig = None # don't modify figure
149+
fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams)
151150

152151
# Create copy of kwargs for later checking to find unused arguments
153152
initial_kwargs = dict(kwargs)
@@ -217,10 +216,12 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs):
217216

218217
# TODO: update to common code pattern
219218
if user_ax is None:
220-
ax.set_title(f"Phase portrait for {sys.name}")
221-
ax.set_xlabel(sys.state_labels[0])
222-
ax.set_ylabel(sys.state_labels[1])
219+
with plt.rc_context(rcParams):
220+
suptitle(f"Phase portrait for {sys.name}")
221+
ax.set_xlabel(sys.state_labels[0])
222+
ax.set_ylabel(sys.state_labels[1])
223223

224+
plt.tight_layout()
224225
return ControlPlot(out, ax, fig)
225226

226227

@@ -273,6 +274,10 @@ def vectorfield(
273274
If set to `True`, suppress warning messages in generating trajectories.
274275
275276
"""
277+
# Process keywords
278+
rcParams = config._get_param(
279+
'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True)
280+
276281
# Get system parameters
277282
params = kwargs.pop('params', None)
278283

@@ -303,9 +308,10 @@ def vectorfield(
303308
vfdata[i, :2] = x
304309
vfdata[i, 2:] = sys._rhs(0, x, 0)
305310

306-
out = ax.quiver(
307-
vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
308-
angles='xy', color=color)
311+
with plt.rc_context(rcParams):
312+
out = ax.quiver(
313+
vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3],
314+
angles='xy', color=color)
309315

310316
return out
311317

@@ -362,6 +368,10 @@ def streamlines(
362368
If set to `True`, suppress warning messages in generating trajectories.
363369
364370
"""
371+
# Process keywords
372+
rcParams = config._get_param(
373+
'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True)
374+
365375
# Get system parameters
366376
params = kwargs.pop('params', None)
367377

@@ -411,13 +421,13 @@ def streamlines(
411421

412422
# Plot the trajectory (if there is one)
413423
if traj.shape[1] > 1:
414-
out.append(
415-
ax.plot(traj[0], traj[1], color=color))
416-
417-
# Add arrows to the lines at specified intervals
418-
_add_arrows_to_line2D(
419-
ax, out[-1][0], arrow_pos, arrowstyle=arrow_style, dir=1)
424+
with plt.rc_context(rcParams):
425+
out.append(
426+
ax.plot(traj[0], traj[1], color=color))
420427

428+
# Add arrows to the lines at specified intervals
429+
_add_arrows_to_line2D(
430+
ax, out[-1][0], arrow_pos, arrowstyle=arrow_style, dir=1)
421431
return out
422432

423433

@@ -464,6 +474,10 @@ def equilpoints(
464474
out : list of Line2D objects
465475
466476
"""
477+
# Process keywords
478+
rcParams = config._get_param(
479+
'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True)
480+
467481
# Get system parameters
468482
params = kwargs.pop('params', None)
469483

@@ -491,9 +505,9 @@ def equilpoints(
491505
# Plot the equilibrium points
492506
out = []
493507
for xeq in equilpts:
494-
out.append(
495-
ax.plot(xeq[0], xeq[1], marker='o', color=color))
496-
508+
with plt.rc_context(rcParams):
509+
out.append(
510+
ax.plot(xeq[0], xeq[1], marker='o', color=color))
497511
return out
498512

499513

@@ -549,6 +563,10 @@ def separatrices(
549563
If set to `True`, suppress warning messages in generating trajectories.
550564
551565
"""
566+
# Process keywords
567+
rcParams = config._get_param(
568+
'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True)
569+
552570
# Get system parameters
553571
params = kwargs.pop('params', None)
554572

@@ -598,8 +616,9 @@ def separatrices(
598616
out = []
599617
for i, xeq in enumerate(equilpts):
600618
# Plot the equilibrium points
601-
out.append(
602-
ax.plot(xeq[0], xeq[1], marker='o', color='k'))
619+
with plt.rc_context(rcParams):
620+
out.append(
621+
ax.plot(xeq[0], xeq[1], marker='o', color='k'))
603622

604623
# Figure out the linearization and eigenvectors
605624
evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A)
@@ -639,14 +658,15 @@ def separatrices(
639658

640659
# Plot the trajectory (if there is one)
641660
if traj.shape[1] > 1:
642-
out.append(ax.plot(
643-
traj[0], traj[1], color=color, linestyle=linestyle))
661+
with plt.rc_context(rcParams):
662+
out.append(ax.plot(
663+
traj[0], traj[1], color=color, linestyle=linestyle))
644664

645665
# Add arrows to the lines at specified intervals
646-
_add_arrows_to_line2D(
647-
ax, out[-1][0], arrow_pos, arrowstyle=arrow_style,
648-
dir=1)
649-
666+
with plt.rc_context(rcParams):
667+
_add_arrows_to_line2D(
668+
ax, out[-1][0], arrow_pos, arrowstyle=arrow_style,
669+
dir=1)
650670
return out
651671

652672

@@ -903,6 +923,7 @@ def _parse_arrow_keywords(kwargs):
903923
return arrow_pos, arrow_style
904924

905925

926+
# TODO: move to ctrlplot?
906927
def _get_color(kwargs, ax=None):
907928
if 'color' in kwargs:
908929
return kwargs.pop('color')

0 commit comments

Comments
 (0)