Skip to content

Commit bab117d

Browse files
committed
clean up _process_time_response + use ndim
1 parent 97ae02b commit bab117d

2 files changed

Lines changed: 43 additions & 54 deletions

File tree

control/optimal.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import logging
1717
import time
1818

19-
from .timeresp import _process_time_response
19+
from .timeresp import TimeResponseData
2020

2121
__all__ = ['find_optimal_input']
2222

@@ -826,13 +826,14 @@ def __init__(
826826
else:
827827
states = None
828828

829-
retval = _process_time_response(
830-
ocp.system, ocp.timepts, inputs, states,
829+
# Process data as a time response (with "outputs" = inputs)
830+
response = TimeResponseData(
831+
ocp.timepts, inputs, states, sys=ocp.system,
831832
transpose=transpose, return_x=return_states, squeeze=squeeze)
832833

833-
self.time = retval[0]
834-
self.inputs = retval[1]
835-
self.states = None if states is None else retval[2]
834+
self.time = response.time
835+
self.inputs = response.outputs
836+
self.states = response.states
836837

837838

838839
# Compute the input for a nonlinear, (constrained) optimal control problem

control/timeresp.py

Lines changed: 36 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
general function for simulating LTI systems the
1111
:func:`forced_response` function, which has the form::
1212
13-
t, y = forced_response(sys, T, U, X0)
13+
response = forced_response(sys, T, U, X0)
14+
t, y = response.time, response.outputs
1415
1516
where `T` is a vector of times at which the response should be
1617
evaluated, `U` is a vector of inputs (one for each time point) and
@@ -106,7 +107,7 @@ class TimeResponseData:
106107
107108
Time responses are access through either the raw data, stored as ``t``,
108109
``y``, ``x``, ``u``, or using a set of properties ``time``, ``outputs``,
109-
``states``, ``inputs``. When access time responses via their
110+
``states``, ``inputs``. When accessing time responses via their
110111
properties, squeeze processing is applied so that (by default)
111112
single-input, single-output systems will have the output and input
112113
indices supressed. This behavior is set using the ``squeeze`` keyword.
@@ -278,28 +279,28 @@ def __init__(
278279

279280
# Time vector
280281
self.t = np.atleast_1d(time)
281-
if len(self.t.shape) != 1:
282+
if self.t.ndim != 1:
282283
raise ValueError("Time vector must be 1D array")
283284

284285
#
285286
# Output vector (and number of traces)
286287
#
287288
self.y = np.array(outputs)
288289

289-
if len(self.y.shape) == 3:
290+
if self.y.ndim == 3:
290291
multi_trace = True
291292
self.noutputs = self.y.shape[0]
292293
self.ntraces = self.y.shape[1]
293294

294-
elif multi_trace and len(self.y.shape) == 2:
295+
elif multi_trace and self.y.ndim == 2:
295296
self.noutputs = 1
296297
self.ntraces = self.y.shape[0]
297298

298-
elif not multi_trace and len(self.y.shape) == 2:
299+
elif not multi_trace and self.y.ndim == 2:
299300
self.noutputs = self.y.shape[0]
300301
self.ntraces = 1
301302

302-
elif not multi_trace and len(self.y.shape) == 1:
303+
elif not multi_trace and self.y.ndim == 1:
303304
self.nouptuts = 1
304305
self.ntraces = 1
305306

@@ -324,8 +325,8 @@ def __init__(
324325
self.nstates = self.x.shape[0]
325326

326327
# Make sure the shape is OK
327-
if multi_trace and len(self.x.shape) != 3 or \
328-
not multi_trace and len(self.x.shape) != 2:
328+
if multi_trace and self.x.ndim != 3 or \
329+
not multi_trace and self.x.ndim != 2:
329330
raise ValueError("State vector is the wrong shape")
330331

331332
# Make sure time dimension of state is the right length
@@ -346,19 +347,19 @@ def __init__(
346347
self.u = np.array(inputs)
347348

348349
# Make sure the shape is OK and figure out the nuumber of inputs
349-
if multi_trace and len(self.u.shape) == 3 and \
350+
if multi_trace and self.u.ndim == 3 and \
350351
self.u.shape[1] == self.ntraces:
351352
self.ninputs = self.u.shape[0]
352353

353-
elif multi_trace and len(self.u.shape) == 2 and \
354+
elif multi_trace and self.u.ndim == 2 and \
354355
self.u.shape[0] == self.ntraces:
355356
self.ninputs = 1
356357

357-
elif not multi_trace and len(self.u.shape) == 2 and \
358+
elif not multi_trace and self.u.ndim == 2 and \
358359
self.ntraces == 1:
359360
self.ninputs = self.u.shape[0]
360361

361-
elif not multi_trace and len(self.u.shape) == 1:
362+
elif not multi_trace and self.u.ndim == 1:
362363
self.ninputs = 1
363364

364365
else:
@@ -396,21 +397,30 @@ def time(self):
396397
@property
397398
def outputs(self):
398399
t, y = _process_time_response(
399-
self.sys, self.t, self.y, None,
400-
transpose=self.transpose, return_x=False, squeeze=self.squeeze,
400+
self.sys, self.t, self.y,
401+
transpose=self.transpose, squeeze=self.squeeze,
401402
input=self.input_index, output=self.output_index)
402403
return y
403404

404-
# Getter for state (implements squeeze processing)
405+
# Getter for state (implements non-standard squeeze processing)
405406
@property
406407
def states(self):
407408
if self.x is None:
408409
return None
409410

410-
t, y, x = _process_time_response(
411-
self.sys, self.t, self.y, self.x,
412-
transpose=self.transpose, return_x=True, squeeze=self.squeeze,
413-
input=self.input_index, output=self.output_index)
411+
elif self.ninputs == 1 and self.noutputs == 1 and \
412+
self.ntraces == 1 and self.x.ndim == 3:
413+
# Single-input, single-output system with single trace
414+
x = self.x[:, 0, :]
415+
416+
else:
417+
# Return the full set of data
418+
x = self.x
419+
420+
# Transpose processing
421+
if self.transpose:
422+
x = np.transpose(x, np.roll(range(x.ndim), 1))
423+
414424
return x
415425

416426
# Getter for state (implements squeeze processing)
@@ -420,8 +430,8 @@ def inputs(self):
420430
return None
421431

422432
t, u = _process_time_response(
423-
self.sys, self.t, self.u, None,
424-
transpose=self.transpose, return_x=False, squeeze=self.squeeze,
433+
self.sys, self.t, self.u,
434+
transpose=self.transpose, squeeze=self.squeeze,
425435
input=self.input_index, output=self.output_index)
426436
return u
427437

@@ -765,7 +775,7 @@ def forced_response(sys, T=None, U=0., X0=0., transpose=False,
765775
# General algorithm that interpolates U in between output points
766776
else:
767777
# convert input from 1D array to 2D array with only one row
768-
if len(U.shape) == 1:
778+
if U.ndim == 1:
769779
U = U.reshape(1, -1) # pylint: disable=E1103
770780

771781
# Algorithm: to integrate from time 0 to time dt, with linear
@@ -856,7 +866,7 @@ def forced_response(sys, T=None, U=0., X0=0., transpose=False,
856866

857867
# Process time responses in a uniform way
858868
def _process_time_response(
859-
sys, tout, yout, xout, transpose=None, return_x=False,
869+
sys, tout, yout, transpose=None,
860870
squeeze=None, input=None, output=None):
861871
"""Process time response signals.
862872
@@ -877,20 +887,11 @@ def _process_time_response(
877887
systems with no input indexing, such as initial_response or forced
878888
response) or a 3D array indexed by output, input, and time.
879889
880-
xout : array, optional
881-
Individual response of each x variable (if return_x is True). For a
882-
SISO system (or if a single input is specified), this should be a 2D
883-
array indexed by the state index and time (for single input systems)
884-
or a 3D array indexed by state, input, and time. Ignored if None.
885-
886890
transpose : bool, optional
887891
If True, transpose all input and output arrays (for backward
888892
compatibility with MATLAB and :func:`scipy.signal.lsim`). Default
889893
value is False.
890894
891-
return_x : bool, optional
892-
If True, return the state vector (default = False).
893-
894895
squeeze : bool, optional
895896
By default, if a system is single-input, single-output (SISO) then the
896897
output response is returned as a 1D array (indexed by time). If
@@ -917,13 +918,6 @@ def _process_time_response(
917918
squeeze is False, the array is either 2D (indexed by output and time)
918919
or 3D (indexed by input, output, and time).
919920
920-
xout : array, optional
921-
Individual response of each x variable (if return_x is True). For a
922-
SISO system (or if a single input is specified), xout is a 2D array
923-
indexed by the state index and time. For a non-SISO system, xout is a
924-
3D array indexed by the state, the input, and time. The shape of xout
925-
is not affected by the ``squeeze`` keyword.
926-
927921
"""
928922
# If squeeze was not specified, figure out the default (might remain None)
929923
if squeeze is None:
@@ -939,29 +933,23 @@ def _process_time_response(
939933
pass
940934
elif squeeze is None: # squeeze signals if SISO
941935
if issiso:
942-
if len(yout.shape) == 3:
936+
if yout.ndim == 3:
943937
yout = yout[0][0] # remove input and output
944938
else:
945939
yout = yout[0] # remove input
946940
else:
947941
raise ValueError("unknown squeeze value")
948942

949-
# Figure out whether and how to squeeze the state data
950-
if issiso and xout is not None and len(xout.shape) > 2:
951-
xout = xout[:, 0, :] # remove input
952-
953943
# See if we need to transpose the data back into MATLAB form
954944
if transpose:
955945
# Transpose time vector in case we are using np.matrix
956946
tout = np.transpose(tout)
957947

958948
# For signals, put the last index (time) into the first slot
959949
yout = np.transpose(yout, np.roll(range(yout.ndim), 1))
960-
if xout is not None:
961-
xout = np.transpose(xout, np.roll(range(xout.ndim), 1))
962950

963951
# Return time, output, and (optionally) state
964-
return (tout, yout, xout) if return_x else (tout, yout)
952+
return tout, yout
965953

966954

967955
def _get_ss_simo(sys, input=None, output=None, squeeze=None):

0 commit comments

Comments
 (0)