Skip to content

Commit c030aff

Browse files
committed
Handle 0d categorical data similarly to 0d numerical data (see statsmodels/statsmodels#1881)
1 parent 9eb464a commit c030aff

File tree

2 files changed

+62
-11
lines changed

2 files changed

+62
-11
lines changed

patsy/categorical.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,20 @@ def test_guess_categorical():
135135
assert not guess_categorical([1.0, 2.0, 3.0])
136136
assert not guess_categorical([1.0, 2.0, 3.0, np.nan])
137137

138+
def _categorical_shape_fix(data):
139+
# helper function
140+
# data should not be a _CategoricalBox or pandas Categorical or anything
141+
# -- it should be an actual iterable of data, but which might have the
142+
# wrong shape.
143+
if hasattr(data, "ndim") and data.ndim > 1:
144+
raise PatsyError("categorical data cannot be >1-dimensional")
145+
# coerce scalars into 1d, which is consistent with what we do for numeric
146+
# factors. (See statsmodels/statsmodels#1881)
147+
if (not iterable(data)
148+
or isinstance(data, (six.text_type, six.binary_type))):
149+
data = [data]
150+
return data
151+
138152
class CategoricalSniffer(object):
139153
def __init__(self, NA_action, origin=None):
140154
self._NA_action = NA_action
@@ -171,6 +185,9 @@ def sniff(self, data):
171185
if hasattr(data, "dtype") and np.issubdtype(data.dtype, np.bool_):
172186
self._level_set = set([True, False])
173187
return True
188+
189+
data = _categorical_shape_fix(data)
190+
174191
for value in data:
175192
if self._NA_action.is_categorical_NA(value):
176193
continue
@@ -245,15 +262,27 @@ def t(NA_types, datas, exp_finish_fast, exp_levels, exp_contrast=None):
245262
# contrasts
246263
t([], [C([10, 20], contrast="FOO")], False, (10, 20), "FOO")
247264

248-
# unhashable level error:
265+
# no box
266+
t([], [[10, 30], [20]], False, (10, 20, 30))
267+
t([], [["b", "a"], ["a"]], False, ("a", "b"))
268+
269+
# 0d
270+
t([], ["b"], False, ("b",))
271+
249272
from nose.tools import assert_raises
273+
274+
# unhashable level error:
250275
sniffer = CategoricalSniffer(NAAction())
251276
assert_raises(PatsyError, sniffer.sniff, [{}])
252277

278+
# >1d is illegal
279+
assert_raises(PatsyError, sniffer.sniff, np.asarray([["b"]]))
280+
253281
# returns either a 1d ndarray or a pandas.Series
254282
def categorical_to_int(data, levels, NA_action, origin=None):
255283
assert isinstance(levels, tuple)
256284
# In this function, missing values are always mapped to -1
285+
257286
if have_pandas_categorical and isinstance(data, pandas.Categorical):
258287
data_levels_tuple = tuple(data.levels)
259288
if not data_levels_tuple == levels:
@@ -262,22 +291,21 @@ def categorical_to_int(data, levels, NA_action, origin=None):
262291
# pandas.Categorical also uses -1 to indicate NA, and we don't try to
263292
# second-guess its NA detection, so we can just pass it back.
264293
return data.labels
294+
265295
if isinstance(data, _CategoricalBox):
266296
if data.levels is not None and tuple(data.levels) != levels:
267297
raise PatsyError("mismatching levels: expected %r, got %r"
268298
% (levels, tuple(data.levels)), origin)
269299
data = data.data
270-
if hasattr(data, "shape") and len(data.shape) > 1:
271-
raise PatsyError("categorical data must be 1-dimensional",
272-
origin)
273-
if (not iterable(data)
274-
or isinstance(data, (six.text_type, six.binary_type))):
275-
raise PatsyError("categorical data must be an iterable container")
300+
301+
data = _categorical_shape_fix(data)
302+
276303
try:
277304
level_to_int = dict(zip(levels, range(len(levels))))
278305
except TypeError:
279306
raise PatsyError("Error interpreting categorical data: "
280307
"all items must be hashable", origin)
308+
281309
# fastpath to avoid doing an item-by-item iteration over boolean arrays,
282310
# as requested by #44
283311
if hasattr(data, "dtype") and np.issubdtype(data.dtype, np.bool_):
@@ -371,14 +399,15 @@ def t(data, levels, expected, NA_action=NAAction()):
371399
C(["a", "b", "a"], levels=["a", "b"]),
372400
("b", "a"), NAAction())
373401

402+
# ndim == 0 is okay
403+
t("a", ("a", "b"), [0])
404+
t("b", ("a", "b"), [1])
405+
t(True, (False, True), [1])
406+
374407
# ndim == 2 is disallowed
375408
assert_raises(PatsyError, categorical_to_int,
376409
np.asarray([["a", "b"], ["b", "a"]]),
377410
("a", "b"), NAAction())
378-
# ndim == 0 is disallowed likewise
379-
assert_raises(PatsyError, categorical_to_int,
380-
"a",
381-
("a", "b"), NAAction())
382411

383412
# levels must be hashable
384413
assert_raises(PatsyError, categorical_to_int,

patsy/test_highlevel.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,3 +680,25 @@ def test_dmatrix_NA_action():
680680
assert_raises(PatsyError,
681681
dmatrices, "y ~ 1", data=data, return_type=return_type,
682682
NA_action="raise")
683+
684+
def test_0d_data():
685+
# Use case from statsmodels/statsmodels#1881
686+
data_0d = {"x1": 1.1, "x2": 1.2, "a": "a1"}
687+
688+
for formula, expected in [
689+
("x1 + x2", [[1, 1.1, 1.2]]),
690+
("C(a, levels=('a1', 'a2')) + x1", [[1, 0, 1.1]]),
691+
]:
692+
mat = dmatrix(formula, data_0d)
693+
assert np.allclose(mat, expected)
694+
695+
assert np.allclose(build_design_matrices([mat.design_info.builder],
696+
data_0d)[0],
697+
expected)
698+
if have_pandas:
699+
data_series = pandas.Series(data_0d)
700+
assert np.allclose(dmatrix(formula, data_series), expected)
701+
702+
assert np.allclose(build_design_matrices([mat.design_info.builder],
703+
data_series)[0],
704+
expected)

0 commit comments

Comments
 (0)