Skip to content

Commit 21082ab

Browse files
committed
MNT: simplify and add tests for cbook._check_shape
1 parent 351f47e commit 21082ab

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

lib/matplotlib/cbook/__init__.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2197,24 +2197,20 @@ def _check_shape(_shape, **kwargs):
21972197
for k, v in kwargs.items():
21982198
data_shape = v.shape
21992199

2200-
if not (
2201-
len(data_shape) == len(target_shape)
2202-
and all(
2203-
(t == s if t is not None else True)
2200+
if len(target_shape) != len(data_shape) or any(
2201+
t not in [s, None]
22042202
for t, s in zip(target_shape, data_shape)
2205-
)
22062203
):
2207-
def format_dims(target_shape):
2208-
dim_labels = iter(itertools.chain(
2209-
'MNLIJKLH',
2210-
(f"D{i}" for i in itertools.count())))
2211-
text_shape = tuple(n if n is not None else next(dim_labels)
2212-
for n in target_shape)
2213-
return '(' + ", ".join(str(_) for _ in text_shape) + ')'
2204+
dim_labels = iter(itertools.chain(
2205+
'MNLIJKLH',
2206+
(f"D{i}" for i in itertools.count())))
2207+
text_shape = ", ".join(str(_) for _ in
2208+
(n if n is not None else next(dim_labels)
2209+
for n in target_shape))
22142210

22152211
raise ValueError(
22162212
f"{k!r} must be {len(target_shape)}D "
2217-
f"with shape {format_dims(target_shape)}. "
2213+
f"with shape ({text_shape}). "
22182214
f"Your input has shape {v.shape}."
22192215
)
22202216

lib/matplotlib/tests/test_cbook.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import itertools
22
import pickle
3+
import re
4+
35
from weakref import ref
46
from unittest.mock import patch, Mock
57

@@ -633,3 +635,19 @@ def divisors(n):
633635
for rstride, cstride in itertools.product(divisors(rows - 1),
634636
divisors(cols - 1)):
635637
check(x, rstride=rstride, cstride=cstride)
638+
639+
640+
@pytest.mark.parametrize('target,test_shape',
641+
[((None, ), (1, 3)),
642+
((None, 3), (1,)),
643+
((None, 3), (1, 2)),
644+
((1, 5), (1, 9)),
645+
((None, 2, None), (1, 3, 1))
646+
])
647+
def test_check_shape(target, test_shape):
648+
error_pattern = (f"^'aardvark' must be {len(target)}D.*" +
649+
re.escape(f'has shape {test_shape}'))
650+
data = np.zeros(test_shape)
651+
with pytest.raises(ValueError,
652+
match=error_pattern):
653+
cbook._check_shape(target, aardvark=data)

0 commit comments

Comments
 (0)