Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions control/tests/xferfcn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def test_constructor_bad_input_type(self):

with pytest.raises(TypeError, match="unsupported data type"):
ct.tf([1j], [1, 2, 3])
for dtype in [np.complex64, np.complex128]:
with pytest.raises(TypeError, match="unsupported data type"):
ct.tf(np.array([1 + 1j], dtype=dtype), [1, 2, 3])

# good input
TransferFunction([[[0, 1], [2, 3]],
Expand Down Expand Up @@ -1550,6 +1553,19 @@ def test_zpk(zeros, poles, gain, args, kwargs):
if kwargs.get('name'):
assert sys.name == kwargs.get('name')


@pytest.mark.parametrize("dtype", [np.complex64, np.complex128])
def test_zpk_complex_dtype_real_coefficients(dtype):
zeros = np.array([1 + 1j, 1 - 1j], dtype=dtype)
poles = np.array([-1 + 1j, -1 - 1j], dtype=dtype)

sys = ct.zpk(zeros, poles, gain=1, dt=0)

assert sys.num_array[0, 0].dtype == float
assert sys.den_array[0, 0].dtype == float
assert "float32" not in str(sys)


@pytest.mark.parametrize("create, args, kwargs, convert", [
(StateSpace, ([-1], [1], [1], [0]), {}, ss2tf),
(StateSpace, ([-1], [1], [1], [0]), {}, ss),
Expand Down
25 changes: 18 additions & 7 deletions control/xferfcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1960,7 +1960,7 @@ def _clean_part(data, name="<unknown>"):
if isinstance(data, np.ndarray) and data.ndim == 2 and \
data.dtype == object and isinstance(data[0, 0], np.ndarray):
# Data is already in the right format
return data
out = data
elif isinstance(data, ndarray) and data.ndim == 3 and \
isinstance(data[0, 0, 0], valid_types):
out = np.empty(data.shape[0:2], dtype=np.ndarray)
Expand Down Expand Up @@ -1995,15 +1995,26 @@ def _clean_part(data, name="<unknown>"):
"The numerator and denominator inputs must be scalars or vectors "
"(for\nSISO), or lists of lists of vectors (for SISO or MIMO).")

# Check for coefficients that are ints and convert to floats
# Check for real numeric coefficients and normalize floating arrays
for i in range(out.shape[0]):
for j in range(out.shape[1]):
for k in range(len(out[i, j])):
if isinstance(out[i, j][k], (int, np.integer)):
out[i, j][k] = float(out[i, j][k])
elif isinstance(out[i, j][k], unsupported_types):
coefficients = np.asarray(out[i, j])
convert_to_float = np.issubdtype(coefficients.dtype, np.floating)
if np.iscomplexobj(coefficients):
real_coefficients = coefficients.real
zero_tol = 1000 * np.finfo(float).eps * max(
1, np.max(np.abs(real_coefficients)))
if np.any(np.abs(coefficients.imag) > zero_tol):
raise TypeError(
f"unsupported data type: {type(out[i, j][k])}")
f"unsupported data type: {type(coefficients.flat[0])}")
coefficients = real_coefficients
convert_to_float = True
for k in range(len(coefficients)):
if isinstance(coefficients[k], unsupported_types):
raise TypeError(
f"unsupported data type: {type(coefficients[k])}")
out[i, j] = np.asarray(
coefficients, dtype=float) if convert_to_float else coefficients
return out


Expand Down