@@ -219,7 +219,7 @@ def __init__(
219219
220220 states : array, optional
221221 Individual response of each state variable. This should be a 2D
222- array indexed by the state index and time (for single input
222+ array indexed by the state index and time (for single trace
223223 systems) or a 3D array indexed by state, trace, and time.
224224
225225 inputs : array, optional
@@ -281,50 +281,101 @@ def __init__(
281281 if len (self .t .shape ) != 1 :
282282 raise ValueError ("Time vector must be 1D array" )
283283
284+ #
284285 # Output vector (and number of traces)
286+ #
285287 self .y = np .array (outputs )
286- if multi_trace or len (self .y .shape ) == 3 :
287- if len (self .y .shape ) < 2 :
288- raise ValueError ("Output vector is the wrong shape" )
289- self .ntraces = self .y .shape [- 2 ]
290- self .noutputs = 1 if len (self .y .shape ) < 2 else \
291- self .y .shape [0 ]
292- else :
288+
289+ if len (self .y .shape ) == 3 :
290+ multi_trace = True
291+ self .noutputs = self .y .shape [0 ]
292+ self .ntraces = self .y .shape [1 ]
293+
294+ elif multi_trace and len (self .y .shape ) == 2 :
295+ self .noutputs = 1
296+ self .ntraces = self .y .shape [0 ]
297+
298+ elif not multi_trace and len (self .y .shape ) == 2 :
299+ self .noutputs = self .y .shape [0 ]
300+ self .ntraces = 1
301+
302+ elif not multi_trace and len (self .y .shape ) == 1 :
303+ self .nouptuts = 1
293304 self .ntraces = 1
294- self .noutputs = 1 if len (self .y .shape ) < 2 else \
295- self .y .shape [0 ]
296305
297- # Make sure time dimension of output is OK
306+ else :
307+ raise ValueError ("Output vector is the wrong shape" )
308+
309+ # Make sure time dimension of output is the right length
298310 if self .t .shape [- 1 ] != self .y .shape [- 1 ]:
299311 raise ValueError ("Output vector does not match time vector" )
300312
301- # State vector
302- self .x = np .array (states )
303- self .nstates = 0 if self .x is None else self .x .shape [0 ]
304- if self .t .shape [- 1 ] != self .x .shape [- 1 ]:
305- raise ValueError ("State vector does not match time vector" )
313+ #
314+ # State vector (optional)
315+ #
316+ # If present, the shape of the state vector should be consistent
317+ # with the multi-trace nature of the data.
318+ #
319+ if states is None :
320+ self .x = None
321+ self .nstates = 0
322+ else :
323+ self .x = np .array (states )
324+ self .nstates = self .x .shape [0 ]
325+
326+ # Make sure the shape is OK
327+ if multi_trace and len (self .x .shape ) != 3 or \
328+ not multi_trace and len (self .x .shape ) != 2 :
329+ raise ValueError ("State vector is the wrong shape" )
306330
307- # Input vector
308- # If no input is present, return an empty array
331+ # Make sure time dimension of state is the right length
332+ if self .t .shape [- 1 ] != self .x .shape [- 1 ]:
333+ raise ValueError ("State vector does not match time vector" )
334+
335+ #
336+ # Input vector (optional)
337+ #
338+ # If present, the shape and dimensions of the input vector should be
339+ # consistent with the trace count computed above.
340+ #
309341 if inputs is None :
310342 self .u = None
343+ self .ninputs = 0
344+
311345 else :
312346 self .u = np .array (inputs )
313347
314- if self .u is not None :
315- self .ninputs = 1 if len (self .u .shape ) < 2 \
316- else self .u .shape [- 2 ]
348+ # Make sure the shape is OK and figure out the nuumber of inputs
349+ if multi_trace and len (self .u .shape ) == 3 and \
350+ self .u .shape [1 ] == self .ntraces :
351+ self .ninputs = self .u .shape [0 ]
352+
353+ elif multi_trace and len (self .u .shape ) == 2 and \
354+ self .u .shape [0 ] == self .ntraces :
355+ self .ninputs = 1
356+
357+ elif not multi_trace and len (self .u .shape ) == 2 and \
358+ self .ntraces == 1 :
359+ self .ninputs = self .u .shape [0 ]
360+
361+ elif not multi_trace and len (self .u .shape ) == 1 :
362+ self .ninputs = 1
363+
364+ else :
365+ raise ValueError ("Input vector is the wrong shape" )
366+
367+ # Make sure time dimension of output is the right length
317368 if self .t .shape [- 1 ] != self .u .shape [- 1 ]:
318369 raise ValueError ("Input vector does not match time vector" )
319- else :
320- self .ninputs = 0
321370
322371 # If the system was specified, make sure it is compatible
323372 if sys is not None :
324373 if sys .noutputs != self .noutputs :
325374 ValueError ("System outputs do not match response data" )
326- if sys .nstates != self .nstates :
375+ if self . x is not None and sys .nstates != self .nstates :
327376 ValueError ("System states do not match response data" )
377+ if self .u is not None and sys .ninputs != self .ninputs :
378+ ValueError ("System inputs do not match response data" )
328379 self .sys = sys
329380
330381 # Keep track of whether to squeeze inputs, outputs, and states
0 commit comments