Skip to content

Commit 97ae02b

Browse files
committed
clean up trace processing + shape checks
1 parent 44274c3 commit 97ae02b

2 files changed

Lines changed: 76 additions & 25 deletions

File tree

control/iosys.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1572,7 +1572,7 @@ def input_output_response(
15721572
u = U[i] if len(U.shape) == 1 else U[:, i]
15731573
y[:, i] = sys._out(T[i], [], u)
15741574
return TimeResponseData(
1575-
T, y, np.zeros((0, 0, np.asarray(T).size)), None, sys=sys,
1575+
T, y, None, None, sys=sys,
15761576
transpose=transpose, return_x=return_x, squeeze=squeeze)
15771577

15781578
# create X0 if not given, test if X0 has correct shape

control/timeresp.py

Lines changed: 75 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def __init__(
219219
220220
states : array, optional
221221
Individual response of each state variable. This should be a 2D
222-
array indexed by the state index and time (for single input
222+
array indexed by the state index and time (for single trace
223223
systems) or a 3D array indexed by state, trace, and time.
224224
225225
inputs : array, optional
@@ -281,50 +281,101 @@ def __init__(
281281
if len(self.t.shape) != 1:
282282
raise ValueError("Time vector must be 1D array")
283283

284+
#
284285
# Output vector (and number of traces)
286+
#
285287
self.y = np.array(outputs)
286-
if multi_trace or len(self.y.shape) == 3:
287-
if len(self.y.shape) < 2:
288-
raise ValueError("Output vector is the wrong shape")
289-
self.ntraces = self.y.shape[-2]
290-
self.noutputs = 1 if len(self.y.shape) < 2 else \
291-
self.y.shape[0]
292-
else:
288+
289+
if len(self.y.shape) == 3:
290+
multi_trace = True
291+
self.noutputs = self.y.shape[0]
292+
self.ntraces = self.y.shape[1]
293+
294+
elif multi_trace and len(self.y.shape) == 2:
295+
self.noutputs = 1
296+
self.ntraces = self.y.shape[0]
297+
298+
elif not multi_trace and len(self.y.shape) == 2:
299+
self.noutputs = self.y.shape[0]
300+
self.ntraces = 1
301+
302+
elif not multi_trace and len(self.y.shape) == 1:
303+
self.nouptuts = 1
293304
self.ntraces = 1
294-
self.noutputs = 1 if len(self.y.shape) < 2 else \
295-
self.y.shape[0]
296305

297-
# Make sure time dimension of output is OK
306+
else:
307+
raise ValueError("Output vector is the wrong shape")
308+
309+
# Make sure time dimension of output is the right length
298310
if self.t.shape[-1] != self.y.shape[-1]:
299311
raise ValueError("Output vector does not match time vector")
300312

301-
# State vector
302-
self.x = np.array(states)
303-
self.nstates = 0 if self.x is None else self.x.shape[0]
304-
if self.t.shape[-1] != self.x.shape[-1]:
305-
raise ValueError("State vector does not match time vector")
313+
#
314+
# State vector (optional)
315+
#
316+
# If present, the shape of the state vector should be consistent
317+
# with the multi-trace nature of the data.
318+
#
319+
if states is None:
320+
self.x = None
321+
self.nstates = 0
322+
else:
323+
self.x = np.array(states)
324+
self.nstates = self.x.shape[0]
325+
326+
# Make sure the shape is OK
327+
if multi_trace and len(self.x.shape) != 3 or \
328+
not multi_trace and len(self.x.shape) != 2:
329+
raise ValueError("State vector is the wrong shape")
306330

307-
# Input vector
308-
# If no input is present, return an empty array
331+
# Make sure time dimension of state is the right length
332+
if self.t.shape[-1] != self.x.shape[-1]:
333+
raise ValueError("State vector does not match time vector")
334+
335+
#
336+
# Input vector (optional)
337+
#
338+
# If present, the shape and dimensions of the input vector should be
339+
# consistent with the trace count computed above.
340+
#
309341
if inputs is None:
310342
self.u = None
343+
self.ninputs = 0
344+
311345
else:
312346
self.u = np.array(inputs)
313347

314-
if self.u is not None:
315-
self.ninputs = 1 if len(self.u.shape) < 2 \
316-
else self.u.shape[-2]
348+
# Make sure the shape is OK and figure out the nuumber of inputs
349+
if multi_trace and len(self.u.shape) == 3 and \
350+
self.u.shape[1] == self.ntraces:
351+
self.ninputs = self.u.shape[0]
352+
353+
elif multi_trace and len(self.u.shape) == 2 and \
354+
self.u.shape[0] == self.ntraces:
355+
self.ninputs = 1
356+
357+
elif not multi_trace and len(self.u.shape) == 2 and \
358+
self.ntraces == 1:
359+
self.ninputs = self.u.shape[0]
360+
361+
elif not multi_trace and len(self.u.shape) == 1:
362+
self.ninputs = 1
363+
364+
else:
365+
raise ValueError("Input vector is the wrong shape")
366+
367+
# Make sure time dimension of output is the right length
317368
if self.t.shape[-1] != self.u.shape[-1]:
318369
raise ValueError("Input vector does not match time vector")
319-
else:
320-
self.ninputs = 0
321370

322371
# If the system was specified, make sure it is compatible
323372
if sys is not None:
324373
if sys.noutputs != self.noutputs:
325374
ValueError("System outputs do not match response data")
326-
if sys.nstates != self.nstates:
375+
if self.x is not None and sys.nstates != self.nstates:
327376
ValueError("System states do not match response data")
377+
if self.u is not None and sys.ninputs != self.ninputs:
378+
ValueError("System inputs do not match response data")
328379
self.sys = sys
329380

330381
# Keep track of whether to squeeze inputs, outputs, and states

0 commit comments

Comments
 (0)