Skip to content

Commit 5707dcf

Browse files
authored
Merge pull request #1109 from bnavigator/assert_tf_close_coeff
Move _tf_close_coeff back to testing realm and make better use of assertion messages
2 parents d2f9a9c + 2359299 commit 5707dcf

File tree

7 files changed

+131
-154
lines changed

7 files changed

+131
-154
lines changed

.github/scripts/set-conda-test-matrix.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
""" set-conda-test-matrix.py
1+
"""Create test matrix for conda packages in OS/BLAS test matrix workflow."""
22

3-
Create test matrix for conda packages
4-
"""
5-
import json, re
63
from pathlib import Path
74

85
osmap = {'linux': 'ubuntu',

.github/scripts/set-pip-test-matrix.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
""" set-pip-test-matrix.py
1+
"""Create test matrix for pip wheels in OS/BLAS test matrix workflow."""
22

3-
Create test matrix for pip wheels
4-
"""
53
import json
64
from pathlib import Path
75

control/tests/bdalg_test.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
1-
"""bdalg_test.py - test suite for block diagram algebra
1+
"""bdalg_test.py - test suite for block diagram algebra.
22
33
RMM, 30 Mar 2011 (based on TestBDAlg from v0.4a)
44
"""
55

6+
import control as ctrl
67
import numpy as np
7-
from numpy import sort
88
import pytest
9-
10-
import control as ctrl
11-
from control.xferfcn import TransferFunction, _tf_close_coeff
9+
from control.bdalg import _ensure_tf, append, connect, feedback
10+
from control.lti import poles, zeros
1211
from control.statesp import StateSpace
13-
from control.bdalg import feedback, append, connect
14-
from control.lti import zeros, poles
15-
from control.bdalg import _ensure_tf
12+
from control.tests.conftest import assert_tf_close_coeff
13+
from control.xferfcn import TransferFunction
14+
from numpy import sort
1615

1716

1817
class TestFeedback:
19-
"""These are tests for the feedback function in bdalg.py. Currently, some
20-
of the tests are not implemented, or are not working properly. TODO: these
21-
need to be fixed."""
18+
"""Tests for the feedback function in bdalg.py."""
2219

2320
@pytest.fixture
2421
def tsys(self):
@@ -180,7 +177,7 @@ def testTFTF(self, tsys):
180177
[[[1., 4., 9., 8., 5.]]])
181178

182179
def testLists(self, tsys):
183-
"""Make sure that lists of various lengths work for operations"""
180+
"""Make sure that lists of various lengths work for operations."""
184181
sys1 = ctrl.tf([1, 1], [1, 2])
185182
sys2 = ctrl.tf([1, 3], [1, 4])
186183
sys3 = ctrl.tf([1, 5], [1, 6])
@@ -237,7 +234,7 @@ def testLists(self, tsys):
237234
sort(zeros(sys1 + sys2 + sys3 + sys4 + sys5)))
238235

239236
def testMimoSeries(self, tsys):
240-
"""regression: bdalg.series reverses order of arguments"""
237+
"""regression: bdalg.series reverses order of arguments."""
241238
g1 = ctrl.ss([], [], [], [[1, 2], [0, 3]])
242239
g2 = ctrl.ss([], [], [], [[1, 0], [2, 3]])
243240
ref = g2 * g1
@@ -430,9 +427,9 @@ class TestEnsureTf:
430427
],
431428
)
432429
def test_ensure(self, arraylike_or_tf, dt, tf):
433-
"""Test nominal cases"""
430+
"""Test nominal cases."""
434431
ensured_tf = _ensure_tf(arraylike_or_tf, dt)
435-
assert _tf_close_coeff(tf, ensured_tf)
432+
assert_tf_close_coeff(tf, ensured_tf)
436433

437434
@pytest.mark.parametrize(
438435
"arraylike_or_tf, dt, exception",
@@ -460,7 +457,7 @@ def test_ensure(self, arraylike_or_tf, dt, tf):
460457
],
461458
)
462459
def test_error_ensure(self, arraylike_or_tf, dt, exception):
463-
"""Test error cases"""
460+
"""Test error cases."""
464461
with pytest.raises(exception):
465462
_ensure_tf(arraylike_or_tf, dt)
466463

@@ -624,7 +621,7 @@ class TestTfCombineSplit:
624621
def test_combine_tf(self, tf_array, tf):
625622
"""Test combining transfer functions."""
626623
tf_combined = ctrl.combine_tf(tf_array)
627-
assert _tf_close_coeff(tf_combined, tf)
624+
assert_tf_close_coeff(tf_combined, tf)
628625

629626
@pytest.mark.parametrize(
630627
"tf_array, tf",
@@ -712,12 +709,12 @@ def test_split_tf(self, tf_array, tf):
712709
# Test entry-by-entry
713710
for i in range(tf_split.shape[0]):
714711
for j in range(tf_split.shape[1]):
715-
assert _tf_close_coeff(
712+
assert_tf_close_coeff(
716713
tf_split[i, j],
717714
tf_array[i, j],
718715
)
719716
# Test combined
720-
assert _tf_close_coeff(
717+
assert_tf_close_coeff(
721718
ctrl.combine_tf(tf_split),
722719
ctrl.combine_tf(tf_array),
723720
)

control/tests/conftest.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""conftest.py - pytest local plugins and fixtures"""
1+
"""conftest.py - pytest local plugins, fixtures, marks and functions."""
22

33
import os
44
from contextlib import contextmanager
@@ -9,6 +9,7 @@
99

1010
import control
1111

12+
1213
# some common pytest marks. These can be used as test decorators or in
1314
# pytest.param(marks=)
1415
slycotonly = pytest.mark.skipif(
@@ -61,7 +62,7 @@ def mplcleanup():
6162

6263
@pytest.fixture(scope="function")
6364
def legacy_plot_signature():
64-
"""Turn off warnings for calls to plotting functions with old signatures"""
65+
"""Turn off warnings for calls to plotting functions with old signatures."""
6566
import warnings
6667
warnings.filterwarnings(
6768
'ignore', message='passing systems .* is deprecated',
@@ -75,14 +76,51 @@ def legacy_plot_signature():
7576

7677
@pytest.fixture(scope="function")
7778
def ignore_future_warning():
78-
"""Turn off warnings for functions that generate FutureWarning"""
79+
"""Turn off warnings for functions that generate FutureWarning."""
7980
import warnings
8081
warnings.filterwarnings(
8182
'ignore', message='.*deprecated', category=FutureWarning)
8283
yield
8384
warnings.resetwarnings()
84-
8585

86-
# Allow pytest.mark.slow to mark slow tests (skip with pytest -m "not slow")
86+
8787
def pytest_configure(config):
88+
"""Allow pytest.mark.slow to mark slow tests.
89+
90+
skip with pytest -m "not slow"
91+
"""
8892
config.addinivalue_line("markers", "slow: mark test as slow to run")
93+
94+
95+
def assert_tf_close_coeff(actual, desired, rtol=1e-5, atol=1e-8):
96+
"""Check if two transfer functions have close coefficients.
97+
98+
Parameters
99+
----------
100+
actual, desired : TransferFunction
101+
Transfer functions to compare.
102+
rtol : float
103+
Relative tolerance for ``np.testing.assert_allclose``.
104+
atol : float
105+
Absolute tolerance for ``np.testing.assert_allclose``.
106+
107+
Raises
108+
------
109+
AssertionError
110+
"""
111+
# Check number of outputs and inputs
112+
assert actual.noutputs == desired.noutputs
113+
assert actual.ninputs == desired.ninputs
114+
# Check timestep
115+
assert actual.dt == desired.dt
116+
# Check coefficient arrays
117+
for i in range(actual.noutputs):
118+
for j in range(actual.ninputs):
119+
np.testing.assert_allclose(
120+
actual.num[i][j],
121+
desired.num[i][j],
122+
rtol=rtol, atol=atol)
123+
np.testing.assert_allclose(
124+
actual.den[i][j],
125+
desired.den[i][j],
126+
rtol=rtol, atol=atol)

0 commit comments

Comments
 (0)