Skip to content

Commit e165dd5

Browse files
committed
add {control, disturbance}_indices to create_estimator_iosystem
1 parent 8375127 commit e165dd5

3 files changed

Lines changed: 189 additions & 30 deletions

File tree

control/namedio.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
'namedio.sampled_system_name_prefix': '',
2323
'namedio.sampled_system_name_suffix': '$sampled'
2424
}
25-
26-
25+
26+
2727
class NamedIOSystem(object):
2828
def __init__(
2929
self, name=None, inputs=None, outputs=None, states=None, **kwargs):
@@ -586,21 +586,46 @@ def _process_signal_list(signals, prefix='s'):
586586
raise TypeError("Can't parse signal list %s" % str(signals))
587587

588588

589+
#
589590
# Utility function to process signal indices
590-
def _process_indices(arg, name, labels, default=None):
591-
arg = default if arg is None else arg
592-
if arg is None:
593-
return None;
591+
#
592+
# Signal indices can be specified in one of four ways:
593+
#
594+
# 1. As a positive integer 'm', in which case we return a list
595+
# corresponding to the first 'm' elements of a range of a given length
596+
#
597+
# 2. As a negative integer '-m', in which case we return a list
598+
# corresponding to the last 'm' elements of a range of a given length
599+
#
600+
# 3. As a slice, in which case we return the a list corresponding to the
601+
# indices specified by the slice of a range of a given length
602+
#
603+
# 4. As a list of ints or strings specifying specific indices. Strings are
604+
# compared to a list of labels to determine the index.
605+
#
606+
def _process_indices(arg, name, labels, length):
607+
# Default is to return indices up to a certain length
608+
arg = length if arg is None else arg
594609

595610
if isinstance(arg, int):
596-
return range(arg)
611+
# Return the start or end of the list of possible indices
612+
return list(range(arg)) if arg > 0 else list(range(length))[arg:]
613+
597614
elif isinstance(arg, slice):
598-
return arg
615+
# Return the indices referenced by the slice
616+
return list(range(length))[arg]
617+
599618
elif isinstance(arg, list):
619+
# Make sure the length is OK
620+
if len(arg) > length:
621+
raise ValueError(
622+
f"{name}_indices list is too long; max length = {length}")
623+
624+
# Return the list, replacing strings with corresponding indices
600625
arg=arg.copy()
601626
for i, idx in enumerate(arg):
602627
if isinstance(idx, str):
603628
arg[i] = labels.index(arg[i])
604629
return arg
605-
else:
606-
raise ValueError(f"invalid argument for {name}_indices")
630+
631+
raise ValueError(f"invalid argument for {name}_indices")

control/stochsys.py

Lines changed: 93 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from .iosys import InputOutputSystem, LinearIOSystem, NonlinearIOSystem
2424
from .lti import LTI
25-
from .namedio import isctime, isdtime
25+
from .namedio import isctime, isdtime, _process_indices
2626
from .mateqn import care, dare, _check_shape
2727
from .statesp import StateSpace, _ssmatrix
2828
from .exception import ControlArgument, ControlNotImplemented
@@ -314,6 +314,7 @@ def dlqe(*args, **kwargs):
314314
#
315315
def create_estimator_iosystem(
316316
sys, QN, RN, P0=None, G=None, C=None,
317+
control_indices=None, disturbance_indices=None,
317318
state_labels='xhat[{i}]', output_labels='xhat[{i}]',
318319
covariance_labels='P[{i},{j}]', sensor_labels=None):
319320
r"""Create an I/O system implementing a linear quadratic estimator
@@ -347,9 +348,10 @@ def create_estimator_iosystem(
347348
348349
Parameters
349350
----------
350-
sys : InputOutputSystem
351-
The I/O system that represents the process dynamics. If no estimator
352-
is given, the output of this system should represent the full state.
351+
sys : LinearIOSystem
352+
The linear I/O system that represents the process dynamics. If no
353+
estimator is given, the output of this system should represent the
354+
full state.
353355
QN, RN : ndarray
354356
Process and sensor noise covariance matrices.
355357
P0 : ndarray, optional
@@ -362,14 +364,6 @@ def create_estimator_iosystem(
362364
If the system has full state output, define the measured values to
363365
be used by the estimator. Otherwise, use the system output as the
364366
measured values.
365-
{state, covariance, sensor, output}_labels : str or list of str, optional
366-
Set the name of the signals to use for the internal state, covariance,
367-
sensors, and outputs (state estimate). If a single string is
368-
specified, it should be a format string using the variable `i` as an
369-
index (or `i` and `j` for covariance). Otherwise, a list of
370-
strings matching the size of the respective signal should be used.
371-
Default is ``'xhat[{i}]'`` for state and output labels, ``'y[{i}]'``
372-
for output labels and ``'P[{i},{j}]'`` for covariance labels.
373367
374368
Returns
375369
-------
@@ -378,6 +372,47 @@ def create_estimator_iosystem(
378372
the system output y and input u and generates the estimated state
379373
xhat.
380374
375+
Other Parameters
376+
----------------
377+
control_indices : int, slice, or list of int or string, optional
378+
Specify the indices in the system input vector that correspond to
379+
the control inputs. These inputs will be used as known control
380+
inputs for the estimator. If value is an integer `m`, the first `m`
381+
system inputs are used. Otherwise, the value should be a slice or
382+
a list of indices. The list of indices can be specified as either
383+
integer offsets or as system input signal names. If not specified,
384+
defaults to the system inputs.
385+
disturbance_indices : int, slice, or list of int or string, optional
386+
Specify the indices in the system input vector that correspond to
387+
the unknown disturbances. These inputs are assumed to be white
388+
noise with noise intensity QN. If value is an integer `m`, the
389+
last `m` system inputs are used. Otherwise, the value should be a
390+
slice or a list of indices. The list of indices can be specified
391+
as either integer offsets or as system input signal names. If not
392+
specified, the disturbances are assumed to be added to the system
393+
inputs.
394+
state_labels : str or list of str, optional
395+
Set the names of the internal state estimate variables. If a
396+
single string is specified, it should be a format string using the
397+
variable `i` as an index. Otherwise, a list of strings matching
398+
the number of system states should be used. Default is "xhat[{i}]".
399+
covariance_labels : str or list of str, optional
400+
Set the name of the the covariance state variables. If a single
401+
string is specified, it should be a format string using the
402+
variables `i` and `j` as indices. Otherwise, a list of strings
403+
matching the size of the covariance matrix should be used. Default
404+
is "P[{i},{j}]".
405+
sensor_labels : str or list of str, optional
406+
Set the names of the sensor signals (estimator inputs). If a string is
407+
specified, it should be a format string using the variable `i` as
408+
an index. Otherwise, a list of strings matching the size of the
409+
measured system outputs should be used. Default is "y[{i}]".
410+
output_labels : str or list of str, optional
411+
Set the name of the estimator outputs (state estimate). If a
412+
single string is specified, it should be a format string using the
413+
variable `i` as an index. Otherwise, a list of strings matching
414+
the size of the system state should be used. Default is "xhat[{i}]".
415+
381416
Notes
382417
-----
383418
This function can be used with the ``create_statefbk_iosystem()`` function
@@ -403,11 +438,45 @@ def create_estimator_iosystem(
403438
if not isinstance(sys, LinearIOSystem):
404439
raise ControlArgument("Input system must be a linear I/O system")
405440

406-
# Extract the matrices that we need for easy reference
407-
A, B = sys.A, sys.B
441+
# Set the state matrix for later use
442+
A = sys.A
443+
444+
# Set the disturbance matrices (indices take priority over G)
445+
ctrl_idx = _process_indices(
446+
control_indices, 'control', sys.input_labels, sys.ninputs)
447+
448+
if disturbance_indices is None and control_indices is not None:
449+
# Disturbance indices are the complement of control indices
450+
dist_idx = [i for i in range(sys.ninputs) if i not in ctrl_idx]
451+
if G is not None:
452+
warn("'control_indices' and 'G' both specified; ignoring 'G'")
453+
G = sys.B[:, dist_idx]
454+
455+
elif disturbance_indices is not None:
456+
if G is not None:
457+
warn("'disturbance_indices' and 'G' both specified; ignoring 'G'")
458+
459+
# If passed an integer, count from the end of the input vector
460+
arg = -disturbance_indices if isinstance(disturbance_indices, int) \
461+
else disturbance_indices
408462

409-
# Set the disturbance and output matrices
410-
G = sys.B if G is None else G
463+
dist_idx = _process_indices(
464+
arg, 'disturbance', sys.input_labels, sys.ninputs)
465+
G = sys.B[:, dist_idx]
466+
467+
# Set control indices to complement disturbance indices, if needed
468+
if control_indices is None:
469+
ctrl_idx = [i for i in range(sys.ninputs) if i not in dist_idx]
470+
471+
elif G is None:
472+
G = sys.B
473+
474+
# Set the input and direct matrices
475+
B = sys.B[:, ctrl_idx]
476+
if not np.allclose(sys.D, 0):
477+
raise NotImplemented("nonzero 'D' matrix not yet implemented")
478+
479+
# Set the output matrices
411480
if C is not None:
412481
# Make sure that we have the full system output
413482
if not np.array_equal(sys.C, np.eye(sys.nstates)):
@@ -425,7 +494,7 @@ def create_estimator_iosystem(
425494
# Initialize the covariance matrix
426495
if P0 is None:
427496
# Initialize P0 to the steady state value
428-
L0, P0, _ = lqe(A, G, C, QN, RN)
497+
_, P0, _ = lqe(A, G, C, QN, RN)
429498

430499
# Figure out the labels to use
431500
if isinstance(state_labels, str):
@@ -447,6 +516,10 @@ def create_estimator_iosystem(
447516
# Generate the list of labels using the argument as a format string
448517
sensor_labels = [sensor_labels.format(i=i) for i in range(C.shape[0])]
449518

519+
# Set the input labels based on the system input
520+
# TODO: allow these to be overridden
521+
input_labels = [sys.input_labels[i] for i in ctrl_idx]
522+
450523
if isctime(sys):
451524
# Create an I/O system for the state feedback gains
452525
# Note: reshape vectors into column vectors for legacy np.matrix
@@ -470,7 +543,7 @@ def _estim_update(t, x, u, params):
470543
L = P @ C.T @ R_inv
471544

472545
# Update the state estimate
473-
dxhat = A @ xhat + B @ u # prediction
546+
dxhat = A @ xhat + B @ u # prediction
474547
if correct:
475548
dxhat -= L @ (C @ xhat - y) # correction
476549

@@ -500,7 +573,7 @@ def _estim_update(t, x, u, params):
500573
L = A @ P @ C.T @ Reps_inv
501574

502575
# Update the state estimate
503-
dxhat = A @ xhat + B @ u # prediction
576+
dxhat = A @ xhat + B @ u # prediction
504577
if correct:
505578
dxhat -= L @ (C @ xhat - y) # correction
506579

@@ -518,7 +591,7 @@ def _estim_output(t, x, u, params):
518591
# Define the estimator system
519592
return NonlinearIOSystem(
520593
_estim_update, _estim_output, states=state_labels + covariance_labels,
521-
inputs=sensor_labels + sys.input_labels, outputs=output_labels,
594+
inputs=sensor_labels + input_labels, outputs=output_labels,
522595
dt=sys.dt)
523596

524597

control/tests/stochsys_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ def test_correlation():
319319
T = np.logspace(0, 2, T.size)
320320
tau, Rtau = ct.correlation(T, V)
321321

322+
@pytest.mark.slow
322323
@pytest.mark.parametrize('dt', [0, 1])
323324
def test_oep(dt):
324325
# Define the system to test, with additional input
@@ -455,6 +456,7 @@ def test_oep(dt):
455456
est3.states[:, -1], res3.states[:, -1], atol=1e-1, rtol=1e-2)
456457

457458

459+
@pytest.mark.slow
458460
def test_mhe():
459461
# Define the system to test, with additional input
460462
csys = ct.ss(
@@ -495,3 +497,62 @@ def test_mhe():
495497

496498
# Make sure the estimated state is close to the actual state
497499
np.testing.assert_allclose(estp.outputs, resp.states, atol=1e-2, rtol=1e-4)
500+
501+
@pytest.mark.parametrize("ctrl_indices, dist_indices", [
    (slice(0, 3), None),
    (3, None),
    (None, 2),
    ([0, 1, 4], None),
    (['u[0]', 'u[1]', 'u[4]'], None),
    (['u[0]', 'u[1]', 'u[4]'], ['u[1]', 'u[3]']),
    (slice(0, 3), slice(3, 5))
])
def test_indices(ctrl_indices, dist_indices):
    """Check estimator creation using control and disturbance indices.

    A reference system is built from the columns of the full system that
    correspond to the control inputs.  The estimator, run with its
    correction step disabled, is then required to reproduce the states of
    that reference system.
    """
    # Define a system with control inputs (0:3) and disturbances (3:5);
    # no additional noise inputs are included (nnoises = 0)
    ninputs = 3
    nstates = ninputs + 1
    ndisturbances = 2
    noutputs = 2
    nnoises = 0
    # TODO: remove strictly proper
    sys = ct.rss(
        nstates, noutputs, ninputs + ndisturbances + nnoises,
        strictly_proper=True)

    # Create a system whose state we want to estimate
    if ctrl_indices is not None:
        ctrl_idx = ct.namedio._process_indices(
            ctrl_indices, 'control', sys.input_labels, sys.ninputs)
    else:
        # An integer disturbance count is taken from the end of the inputs
        arg = -dist_indices if isinstance(dist_indices, int) else dist_indices
        dist_idx = ct.namedio._process_indices(
            arg, 'disturbance', sys.input_labels, sys.ninputs)
        ctrl_idx = [i for i in range(sys.ninputs) if i not in dist_idx]
    sysm = ct.ss(sys.A, sys.B[:, ctrl_idx], sys.C, sys.D[:, ctrl_idx])

    # Set the simulation time based on the slowest system pole
    T = 10 / min(-sys.poles().real)

    # Use identical solver settings for both simulations so that the two
    # trajectories are computed the same way
    solver_opts = {'method': 'RK45', 'max_step': 0.01, 'atol': 1, 'rtol': 1}

    # Generate a system response with no disturbances
    timepts = np.linspace(0, T, 50)
    U = np.vstack([np.sin(timepts + i) for i in range(ninputs)])
    resp = ct.input_output_response(
        sysm, timepts, U, np.zeros(nstates), solve_ivp_kwargs=solver_opts)
    Y = resp.outputs

    # Create an estimator
    QN = np.eye(ndisturbances)
    RN = np.eye(noutputs)
    P0 = np.eye(nstates)
    estim = ct.create_estimator_iosystem(
        sys, QN, RN, control_indices=ctrl_indices,
        disturbance_indices=dist_indices)

    # Run estimator (no correction + same solve_ivp params => should be exact)
    resp_estim = ct.input_output_response(
        estim, timepts, [Y, U], [np.zeros(nstates), P0],
        solve_ivp_kwargs=solver_opts, params={'correct': False})
    np.testing.assert_allclose(resp.states, resp_estim.outputs, rtol=1e-2)

0 commit comments

Comments
 (0)