@@ -424,6 +424,113 @@ def test_rcParams():
424424 my_rcParams ['ytick.labelsize' ]
425425 assert fig ._suptitle .get_fontsize () == my_rcParams ['figure.titlesize' ]
426426
427+
428+ @pytest .mark .parametrize ("resp_fcn" , [
429+ ct .step_response , ct .initial_response , ct .impulse_response ,
430+ ct .forced_response , ct .input_output_response ])
431+ @pytest .mark .usefixtures ("editsdefaults" )
432+ def test_timeplot_trace_labels (resp_fcn ):
433+ plt .close ('all' )
434+ sys1 = ct .rss (2 , 2 , 2 , strictly_proper = True , name = 'sys1' )
435+ sys2 = ct .rss (2 , 2 , 2 , strictly_proper = True , name = 'sys2' )
436+
437+ # Figure out the expected shape of the system
438+ match resp_fcn :
439+ case ct .step_response | ct .impulse_response :
440+ shape = (2 , 2 )
441+ kwargs = {}
442+ case ct .initial_response :
443+ shape = (2 , 1 )
444+ kwargs = {}
445+ case ct .forced_response | ct .input_output_response :
446+ shape = (4 , 1 ) # outputs and inputs both plotted
447+ T = np .linspace (0 , 10 )
448+ U = [np .sin (T ), np .cos (T )]
449+ kwargs = {'T' : T , 'U' : U }
450+
451+ # Use figure frame for suptitle to speed things up
452+ ct .set_defaults ('freqplot' , suptitle_frame = 'figure' )
453+
454+ # Make sure default labels are as expected
455+ out = resp_fcn ([sys1 , sys2 ], ** kwargs ).plot ()
456+ axs = ct .get_plot_axes (out )
457+ if axs .ndim == 1 :
458+ legend = axs [0 ].get_legend ().get_texts ()
459+ else :
460+ legend = axs [0 , - 1 ].get_legend ().get_texts ()
461+ assert legend [0 ].get_text () == 'sys1'
462+ assert legend [1 ].get_text () == 'sys2'
463+ plt .close ()
464+
465+ # Override labels all at once
466+ out = resp_fcn ([sys1 , sys2 ], ** kwargs ).plot (label = ['line1' , 'line2' ])
467+ axs = ct .get_plot_axes (out )
468+ if axs .ndim == 1 :
469+ legend = axs [0 ].get_legend ().get_texts ()
470+ else :
471+ legend = axs [0 , - 1 ].get_legend ().get_texts ()
472+ assert legend [0 ].get_text () == 'line1'
473+ assert legend [1 ].get_text () == 'line2'
474+ plt .close ()
475+
476+ # Override labels one at a time
477+ out = resp_fcn (sys1 , ** kwargs ).plot (label = 'line1' )
478+ out = resp_fcn (sys2 , ** kwargs ).plot (label = 'line2' )
479+ axs = ct .get_plot_axes (out )
480+ if axs .ndim == 1 :
481+ legend = axs [0 ].get_legend ().get_texts ()
482+ else :
483+ legend = axs [0 , - 1 ].get_legend ().get_texts ()
484+ assert legend [0 ].get_text () == 'line1'
485+ assert legend [1 ].get_text () == 'line2'
486+ plt .close ()
487+
488+
489+ def test_full_label_override ():
490+ sys1 = ct .rss (2 , 2 , 2 , strictly_proper = True , name = 'sys1' )
491+ sys2 = ct .rss (2 , 2 , 2 , strictly_proper = True , name = 'sys2' )
492+
493+ labels_2d = np .array ([
494+ ["outsys1u1y1" , "outsys1u1y2" , "outsys1u2y1" , "outsys1u2y2" ,
495+ "outsys2u1y1" , "outsys2u1y2" , "outsys2u2y1" , "outsys2u2y2" ],
496+ ["inpsys1u1y1" , "inpsys1u1y2" , "inpsys1u2y1" , "inpsys1u2y2" ,
497+ "inpsys2u1y1" , "inpsys2u1y2" , "inpsys2u2y1" , "inpsys2u2y2" ]])
498+
499+
500+ labels_4d = np .empty ((2 , 2 , 2 , 2 ), dtype = object )
501+ for i , sys in enumerate (['sys1' , 'sys2' ]):
502+ for j , trace in enumerate (['u1' , 'u2' ]):
503+ for k , out in enumerate (['y1' , 'y2' ]):
504+ labels_4d [i , j , k , 0 ] = "out" + sys + trace + out
505+ labels_4d [i , j , k , 1 ] = "inp" + sys + trace + out
506+
507+ # Test 4D labels
508+ out = ct .step_response ([sys1 , sys2 ]).plot (
509+ overlay_signals = True , overlay_traces = True , plot_inputs = True ,
510+ label = labels_4d )
511+ axs = ct .get_plot_axes (out )
512+ assert axs .shape == (2 , 1 )
513+ legend_text = axs [0 , 0 ].get_legend ().get_texts ()
514+ for i , label in enumerate (labels_2d [0 ]):
515+ assert legend_text [i ].get_text () == label
516+ legend_text = axs [1 , 0 ].get_legend ().get_texts ()
517+ for i , label in enumerate (labels_2d [1 ]):
518+ assert legend_text [i ].get_text () == label
519+
520+ # Test 2D labels
521+ out = ct .step_response ([sys1 , sys2 ]).plot (
522+ overlay_signals = True , overlay_traces = True , plot_inputs = True ,
523+ label = labels_2d )
524+ axs = ct .get_plot_axes (out )
525+ assert axs .shape == (2 , 1 )
526+ legend_text = axs [0 , 0 ].get_legend ().get_texts ()
527+ for i , label in enumerate (labels_2d [0 ]):
528+ assert legend_text [i ].get_text () == label
529+ legend_text = axs [1 , 0 ].get_legend ().get_texts ()
530+ for i , label in enumerate (labels_2d [1 ]):
531+ assert legend_text [i ].get_text () == label
532+
533+
427534def test_relabel ():
428535 sys1 = ct .rss (2 , inputs = 'u' , outputs = 'y' )
429536 sys2 = ct .rss (1 , 1 , 1 ) # uses default i/o labels
0 commit comments