|
36 | 36 | from scipy.integrate import odeint |
37 | 37 |
|
38 | 38 | from . import config |
39 | | -from .ctrlplot import ControlPlot, _add_arrows_to_line2D, _process_ax_keyword |
| 39 | +from .ctrlplot import ControlPlot, _add_arrows_to_line2D, \ |
| 40 | + _ctrlplot_rcParams, _process_ax_keyword, suptitle |
40 | 41 | from .exception import ControlNotImplemented |
41 | 42 | from .nlsys import NonlinearIOSystem, find_eqpt, input_output_response |
42 | 43 |
|
43 | 44 | __all__ = ['phase_plane_plot', 'phase_plot', 'box_grid'] |
44 | 45 |
|
45 | 46 | # Default values for module parameter variables |
46 | 47 | _phaseplot_defaults = { |
| 48 | + 'phaseplot.rcParams': _ctrlplot_rcParams, |
47 | 49 | 'phaseplot.arrows': 2, # number of arrows around curve |
48 | 50 | 'phaseplot.arrow_size': 8, # pixel size for arrows |
49 | 51 | 'phaseplot.separatrices_radius': 0.1 # initial radius for separatrices |
@@ -139,15 +141,12 @@ def phase_plane_plot( |
139 | 141 | params = kwargs.get('params', None) |
140 | 142 | sys = _create_system(sys, params) |
141 | 143 | pointdata = [-1, 1, -1, 1] if pointdata is None else pointdata |
| 144 | + rcParams = config._get_param( |
| 145 | + 'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True) |
142 | 146 |
|
143 | 147 | # Create axis if needed |
144 | 148 | user_ax = ax |
145 | | - # TODO: make use of _process_ax_keyword |
146 | | - # fig, ax = _process_ax_keyword(user_ax, squeeze=True) |
147 | | - if ax is None: |
148 | | - fig, ax = plt.gcf(), plt.gca() |
149 | | - else: |
150 | | - fig = None # don't modify figure |
| 149 | + fig, ax = _process_ax_keyword(user_ax, squeeze=True, rcParams=rcParams) |
151 | 150 |
|
152 | 151 | # Create copy of kwargs for later checking to find unused arguments |
153 | 152 | initial_kwargs = dict(kwargs) |
@@ -217,10 +216,12 @@ def _create_kwargs(global_kwargs, local_kwargs, **other_kwargs): |
217 | 216 |
|
218 | 217 | # TODO: update to common code pattern |
219 | 218 | if user_ax is None: |
220 | | - ax.set_title(f"Phase portrait for {sys.name}") |
221 | | - ax.set_xlabel(sys.state_labels[0]) |
222 | | - ax.set_ylabel(sys.state_labels[1]) |
| 219 | + with plt.rc_context(rcParams): |
| 220 | + suptitle(f"Phase portrait for {sys.name}") |
| 221 | + ax.set_xlabel(sys.state_labels[0]) |
| 222 | + ax.set_ylabel(sys.state_labels[1]) |
223 | 223 |
|
| 224 | + plt.tight_layout() |
224 | 225 | return ControlPlot(out, ax, fig) |
225 | 226 |
|
226 | 227 |
|
@@ -273,6 +274,10 @@ def vectorfield( |
273 | 274 | If set to `True`, suppress warning messages in generating trajectories. |
274 | 275 |
|
275 | 276 | """ |
| 277 | + # Process keywords |
| 278 | + rcParams = config._get_param( |
| 279 | + 'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True) |
| 280 | + |
276 | 281 | # Get system parameters |
277 | 282 | params = kwargs.pop('params', None) |
278 | 283 |
|
@@ -303,9 +308,10 @@ def vectorfield( |
303 | 308 | vfdata[i, :2] = x |
304 | 309 | vfdata[i, 2:] = sys._rhs(0, x, 0) |
305 | 310 |
|
306 | | - out = ax.quiver( |
307 | | - vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3], |
308 | | - angles='xy', color=color) |
| 311 | + with plt.rc_context(rcParams): |
| 312 | + out = ax.quiver( |
| 313 | + vfdata[:, 0], vfdata[:, 1], vfdata[:, 2], vfdata[:, 3], |
| 314 | + angles='xy', color=color) |
309 | 315 |
|
310 | 316 | return out |
311 | 317 |
|
@@ -362,6 +368,10 @@ def streamlines( |
362 | 368 | If set to `True`, suppress warning messages in generating trajectories. |
363 | 369 |
|
364 | 370 | """ |
| 371 | + # Process keywords |
| 372 | + rcParams = config._get_param( |
| 373 | + 'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True) |
| 374 | + |
365 | 375 | # Get system parameters |
366 | 376 | params = kwargs.pop('params', None) |
367 | 377 |
|
@@ -411,13 +421,13 @@ def streamlines( |
411 | 421 |
|
412 | 422 | # Plot the trajectory (if there is one) |
413 | 423 | if traj.shape[1] > 1: |
414 | | - out.append( |
415 | | - ax.plot(traj[0], traj[1], color=color)) |
416 | | - |
417 | | - # Add arrows to the lines at specified intervals |
418 | | - _add_arrows_to_line2D( |
419 | | - ax, out[-1][0], arrow_pos, arrowstyle=arrow_style, dir=1) |
| 424 | + with plt.rc_context(rcParams): |
| 425 | + out.append( |
| 426 | + ax.plot(traj[0], traj[1], color=color)) |
420 | 427 |
|
| 428 | + # Add arrows to the lines at specified intervals |
| 429 | + _add_arrows_to_line2D( |
| 430 | + ax, out[-1][0], arrow_pos, arrowstyle=arrow_style, dir=1) |
421 | 431 | return out |
422 | 432 |
|
423 | 433 |
|
@@ -464,6 +474,10 @@ def equilpoints( |
464 | 474 | out : list of Line2D objects |
465 | 475 |
|
466 | 476 | """ |
| 477 | + # Process keywords |
| 478 | + rcParams = config._get_param( |
| 479 | + 'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True) |
| 480 | + |
467 | 481 | # Get system parameters |
468 | 482 | params = kwargs.pop('params', None) |
469 | 483 |
|
@@ -491,9 +505,9 @@ def equilpoints( |
491 | 505 | # Plot the equilibrium points |
492 | 506 | out = [] |
493 | 507 | for xeq in equilpts: |
494 | | - out.append( |
495 | | - ax.plot(xeq[0], xeq[1], marker='o', color=color)) |
496 | | - |
| 508 | + with plt.rc_context(rcParams): |
| 509 | + out.append( |
| 510 | + ax.plot(xeq[0], xeq[1], marker='o', color=color)) |
497 | 511 | return out |
498 | 512 |
|
499 | 513 |
|
@@ -549,6 +563,10 @@ def separatrices( |
549 | 563 | If set to `True`, suppress warning messages in generating trajectories. |
550 | 564 |
|
551 | 565 | """ |
| 566 | + # Process keywords |
| 567 | + rcParams = config._get_param( |
| 568 | + 'timeplot', 'rcParams', kwargs, _phaseplot_defaults, pop=True) |
| 569 | + |
552 | 570 | # Get system parameters |
553 | 571 | params = kwargs.pop('params', None) |
554 | 572 |
|
@@ -598,8 +616,9 @@ def separatrices( |
598 | 616 | out = [] |
599 | 617 | for i, xeq in enumerate(equilpts): |
600 | 618 | # Plot the equilibrium points |
601 | | - out.append( |
602 | | - ax.plot(xeq[0], xeq[1], marker='o', color='k')) |
| 619 | + with plt.rc_context(rcParams): |
| 620 | + out.append( |
| 621 | + ax.plot(xeq[0], xeq[1], marker='o', color='k')) |
603 | 622 |
|
604 | 623 | # Figure out the linearization and eigenvectors |
605 | 624 | evals, evecs = np.linalg.eig(sys.linearize(xeq, 0, params=params).A) |
@@ -639,14 +658,15 @@ def separatrices( |
639 | 658 |
|
640 | 659 | # Plot the trajectory (if there is one) |
641 | 660 | if traj.shape[1] > 1: |
642 | | - out.append(ax.plot( |
643 | | - traj[0], traj[1], color=color, linestyle=linestyle)) |
| 661 | + with plt.rc_context(rcParams): |
| 662 | + out.append(ax.plot( |
| 663 | + traj[0], traj[1], color=color, linestyle=linestyle)) |
644 | 664 |
|
645 | 665 | # Add arrows to the lines at specified intervals |
646 | | - _add_arrows_to_line2D( |
647 | | - ax, out[-1][0], arrow_pos, arrowstyle=arrow_style, |
648 | | - dir=1) |
649 | | - |
| 666 | + with plt.rc_context(rcParams): |
| 667 | + _add_arrows_to_line2D( |
| 668 | + ax, out[-1][0], arrow_pos, arrowstyle=arrow_style, |
| 669 | + dir=1) |
650 | 670 | return out |
651 | 671 |
|
652 | 672 |
|
@@ -903,6 +923,7 @@ def _parse_arrow_keywords(kwargs): |
903 | 923 | return arrow_pos, arrow_style |
904 | 924 |
|
905 | 925 |
|
| 926 | +# TODO: move to ctrlplot? |
906 | 927 | def _get_color(kwargs, ax=None): |
907 | 928 | if 'color' in kwargs: |
908 | 929 | return kwargs.pop('color') |
|
0 commit comments