Skip to content

Commit 1c1ce0c

Browse files
committed
implement {control, disturbance}_indices in oep
1 parent e165dd5 commit 1c1ce0c

7 files changed

Lines changed: 242 additions & 125 deletions

File tree

control/namedio.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def _process_signal_list(signals, prefix='s'):
587587

588588

589589
#
590-
# Utility function to process signal indices
590+
# Utility functions to process signal indices
591591
#
592592
# Signal indices can be specified in one of four ways:
593593
#
@@ -629,3 +629,40 @@ def _process_indices(arg, name, labels, length):
629629
return arg
630630

631631
raise ValueError(f"invalid argument for {name}_indices")
632+
633+
#
634+
# Process control and disturbance indices
635+
#
636+
# For systems with inputs and disturbances, the control_indices and
637+
# disturbance_indices keywords are used to specify which is which. If only
638+
# one is given, the other is assumed to be the remaining indices in the
639+
# system input. If neither is given, the disturbance inputs are assumed to
640+
# be the same as the control inputs.
641+
#
642+
def _process_control_disturbance_indices(
643+
sys, control_indices, disturbance_indices):
644+
645+
if control_indices is None and disturbance_indices is None:
646+
# Disturbances enter in the same place as the controls
647+
dist_idx = ctrl_idx = list(range(sys.ninputs))
648+
649+
elif control_indices is not None:
650+
# Process the control indices
651+
ctrl_idx = _process_indices(
652+
control_indices, 'control', sys.input_labels, sys.ninputs)
653+
654+
# Disturbance indices are the complement of control indices
655+
dist_idx = [i for i in range(sys.ninputs) if i not in ctrl_idx]
656+
657+
else: # disturbance_indices is not None
658+
# If passed an integer, count from the end of the input vector
659+
arg = -disturbance_indices if isinstance(disturbance_indices, int) \
660+
else disturbance_indices
661+
662+
dist_idx = _process_indices(
663+
arg, 'disturbance', sys.input_labels, sys.ninputs)
664+
665+
# Set control indices to complement disturbance indices
666+
ctrl_idx = [i for i in range(sys.ninputs) if i not in dist_idx]
667+
668+
return ctrl_idx, dist_idx

control/optimal.py

Lines changed: 160 additions & 86 deletions
Large diffs are not rendered by default.

control/stochsys.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@
2222

2323
from .iosys import InputOutputSystem, LinearIOSystem, NonlinearIOSystem
2424
from .lti import LTI
25-
from .namedio import isctime, isdtime, _process_indices
25+
from .namedio import isctime, isdtime
26+
from .namedio import _process_indices, _process_control_disturbance_indices
2627
from .mateqn import care, dare, _check_shape
2728
from .statesp import StateSpace, _ssmatrix
2829
from .exception import ControlArgument, ControlNotImplemented
@@ -308,9 +309,8 @@ def dlqe(*args, **kwargs):
308309

309310
# Function to create an estimator
310311
#
311-
# TODO: add `control_indices` keyword to match create_mhe_iosystem (?)
312-
# TODO: change name to create_kalmanestimaor_iosystem (?)
313312
# TODO: create predictor/corrector, UKF, and other variants (?)
313+
# TODO: change *_labels to *_fmtstr and use signal keywords instead
314314
#
315315
def create_estimator_iosystem(
316316
sys, QN, RN, P0=None, G=None, C=None,
@@ -441,35 +441,9 @@ def create_estimator_iosystem(
441441
# Set the state matrix for later use
442442
A = sys.A
443443

444-
# Set the disturbance matrices (indices take priority over G)
445-
ctrl_idx = _process_indices(
446-
control_indices, 'control', sys.input_labels, sys.ninputs)
447-
448-
if disturbance_indices is None and control_indices is not None:
449-
# Disturbance indices are the complement of control indices
450-
dist_idx = [i for i in range(sys.ninputs) if i not in ctrl_idx]
451-
if G is not None:
452-
warn("'control_indices' and 'G' both specified; ignoring 'G'")
453-
G = sys.B[:, dist_idx]
454-
455-
elif disturbance_indices is not None:
456-
if G is not None:
457-
warn("'disturbance_indices' and 'G' both specified; ignoring 'G'")
458-
459-
# If passed an integer, count from the end of the input vector
460-
arg = -disturbance_indices if isinstance(disturbance_indices, int) \
461-
else disturbance_indices
462-
463-
dist_idx = _process_indices(
464-
arg, 'disturbance', sys.input_labels, sys.ninputs)
465-
G = sys.B[:, dist_idx]
466-
467-
# Set control indices to complement disturbance indices, if needed
468-
if control_indices is None:
469-
ctrl_idx = [i for i in range(sys.ninputs) if i not in dist_idx]
470-
471-
elif G is None:
472-
G = sys.B
444+
# Determine the control and disturbance indices
445+
ctrl_idx, dist_idx = _process_control_disturbance_indices(
446+
sys, control_indices, disturbance_indices)
473447

474448
# Set the input and direct matrices
475449
B = sys.B[:, ctrl_idx]
@@ -491,6 +465,10 @@ def create_estimator_iosystem(
491465
if sensor_labels is None:
492466
sensor_labels = sys.output_labels
493467

468+
# Generate the disturbance matrix (G)
469+
if G is None:
470+
G = sys.B if len(dist_idx) == 0 else sys.B[:, dist_idx]
471+
494472
# Initialize the covariance matrix
495473
if P0 is None:
496474
# Initalize P0 to the steady state value

control/tests/kwargs_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ def test_matplotlib_kwargs(function, nsysargs, moreargs, kwargs, mplcleanup):
221221
optimal_test.test_ocp_argument_errors,
222222
'optimal.OptimalEstimationProblem.__init__':
223223
optimal_test.test_oep_argument_errors,
224+
'optimal.OptimalEstimationProblem.create_mhe_iosystem':
225+
optimal_test.test_oep_argument_errors,
224226
}
225227

226228
#

control/tests/optimal_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,3 +786,8 @@ def test_oep_argument_errors():
786786

787787
with pytest.raises(TypeError, match="unrecognized keyword"):
788788
oep = opt.OptimalEstimationProblem(sys, timepts, cost, unknown=True)
789+
790+
with pytest.raises(TypeError, match="unrecognized keyword"):
791+
sys = ct.rss(4, 2, 2, dt=True)
792+
oep = opt.OptimalEstimationProblem(sys, timepts, cost)
793+
oep.create_mhe_iosystem(unknown=True)

control/tests/stochsys_test.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,24 @@ def test_oep(dt):
430430
np.testing.assert_allclose(
431431
est2.states[:, -1], res1.states[:, -1], atol=1e-1, rtol=1e-2)
432432

433+
# Change around the inputs and disturbances
434+
sys2 = ct.ss(sys.A, sys.B[:, ::-1], sys.C, sys.D[::-1], sys.dt)
435+
oep2a = opt.OptimalEstimationProblem(
436+
sys2, timepts, traj_cost, terminal_cost=init_cost,
437+
control_indices=[1])
438+
est2a = oep2a.compute_estimate(
439+
Y1, U, initial_guess=(est2.states, est2.inputs))
440+
np.testing.assert_allclose(
441+
est2a.states[:, -1], res1.states[:, -1], atol=1e-1, rtol=1e-2)
442+
443+
oep2b = opt.OptimalEstimationProblem(
444+
sys2, timepts, traj_cost, terminal_cost=init_cost,
445+
disturbance_indices=[0])
446+
est2b = oep2b.compute_estimate(
447+
Y1, U, initial_guess=(est2.states, est2.inputs))
448+
np.testing.assert_allclose(
449+
est2b.states[:, -1], res1.states[:, -1], atol=1e-1, rtol=1e-2)
450+
433451
#
434452
# Disturbance constraints
435453
#
@@ -483,8 +501,9 @@ def test_mhe():
483501
traj_cost = opt.gaussian_likelihood_cost(sys, Rv, Rw)
484502
init_cost = lambda xhat, x: (xhat - x) @ P0 @ (xhat - x)
485503
oep = opt.OptimalEstimationProblem(
486-
sys, mhe_timepts, traj_cost, terminal_cost=init_cost)
487-
mhe = oep.create_mhe_iosystem(1)
504+
sys, mhe_timepts, traj_cost, terminal_cost=init_cost,
505+
disturbance_indices=1)
506+
mhe = oep.create_mhe_iosystem()
488507

489508
# Generate system data
490509
U = 10 * np.sin(timepts / (4*dt))

examples/mhe-pvtol.ipynb

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -617,8 +617,9 @@
617617
"source": [
618618
"mhe_timepts = timepts[0:10]\n",
619619
"oep = opt.OptimalEstimationProblem(\n",
620-
" dsys, mhe_timepts, traj_cost, terminal_cost=init_cost)\n",
621-
"mhe = oep.create_mhe_iosystem(2)\n",
620+
" dsys, mhe_timepts, traj_cost, terminal_cost=init_cost,\n",
621+
" disturbance_indices=slice(2, 4))\n",
622+
"mhe = oep.create_mhe_iosystem()\n",
622623
" \n",
623624
"mhe_resp = ct.input_output_response(\n",
624625
" mhe, timepts, [Y, U], X0=x0, \n",
@@ -650,8 +651,9 @@
650651
"source": [
651652
"mhe_timepts = timepts[0:8]\n",
652653
"oep = opt.OptimalEstimationProblem(\n",
653-
" dsys, mhe_timepts, traj_cost, terminal_cost=init_cost)\n",
654-
"mhe = oep.create_mhe_iosystem(2)\n",
654+
" dsys, mhe_timepts, traj_cost, terminal_cost=init_cost,\n",
655+
" disturbance_indices=slice(2, 4))\n",
656+
"mhe = oep.create_mhe_iosystem()\n",
655657
" \n",
656658
"mhe_resp = ct.input_output_response(\n",
657659
" mhe, timepts, [Y, U],\n",

0 commit comments

Comments
 (0)