Skip to content

Commit 26db44a

Browse files
committed
small cleanups, improve test coverage slightly
1 parent cdd2ad5 commit 26db44a

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

patsy/categorical.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
# Copyright (C) 2011-2013 Nathaniel Smith <njs@pobox.com>
33
# See file COPYING for license information.
44

5-
__all__ = ["C", "guess_categorical", "CategoricalSniffer", "categorical_to_int"]
5+
__all__ = ["C", "guess_categorical", "CategoricalSniffer",
6+
"categorical_to_int"]
67

78
# How we handle categorical data: the big picture
89
# -----------------------------------------------
@@ -37,9 +38,8 @@
3738
from patsy.state import stateful_transform
3839
from patsy.util import (SortAnythingKey,
3940
have_pandas, have_pandas_categorical,
40-
asarray_or_pandas,
41-
pandas_friendly_reshape,
42-
safe_scalar_isnan)
41+
safe_scalar_isnan,
42+
iterable)
4343

4444
if have_pandas:
4545
import pandas
@@ -92,6 +92,25 @@ def C(data, contrast=None, levels=None):
9292
data = data.data
9393
return _CategoricalBox(data, contrast, levels)
9494

95+
def test_C():
    """Check how C() boxes data and how re-wrapping a box overrides fields."""
    # Bare data: contrast and levels are left unset.
    c1 = C("asdf")
    assert isinstance(c1, _CategoricalBox)
    assert c1.data == "asdf"
    assert c1.levels is None
    assert c1.contrast is None
    # Positional arguments are (data, contrast, levels), in that order.
    c2 = C("DATA", "CONTRAST", "LEVELS")
    assert c2.data == "DATA"
    assert c2.contrast == "CONTRAST"
    assert c2.levels == "LEVELS"
    # Wrapping an existing box with new levels keeps its data and contrast.
    c3 = C(c2, levels="NEW LEVELS")
    assert c3.data == "DATA"
    assert c3.contrast == "CONTRAST"
    assert c3.levels == "NEW LEVELS"
    # Wrapping an existing box with a new contrast keeps its data and levels.
    c4 = C(c2, "NEW CONTRAST")
    assert c4.data == "DATA"
    assert c4.contrast == "NEW CONTRAST"
    assert c4.levels == "LEVELS"
113+
95114
def guess_categorical(data):
96115
if have_pandas_categorical and isinstance(data, pandas.Categorical):
97116
return True
@@ -198,7 +217,14 @@ def t(NA_types, datas, exp_finish_fast, exp_levels, exp_contrast=None):
198217
t(["None", "NaN"], [C([1, np.nan]), C([10, None])],
199218
False, (1, 10))
200219
# But 'None' can be a type if we don't make it represent NA:
201-
t(["NaN"], [C([1, np.nan, None])], False, (None, 1))
220+
sniffer = CategoricalSniffer(NAAction(NA_types=["NaN"]))
221+
sniffer.sniff(C([1, np.nan, None]))
222+
# The level order here is different on py2 and py3 :-( Because there's no
223+
# consistent way to sort mixed-type values on both py2 and py3. Honestly
224+
# people probably shouldn't use this, but I don't know how to give a
225+
# sensible error.
226+
levels, _ = sniffer.levels_contrast()
227+
assert set(levels) == set([None, 1])
202228

203229
# bool special case
204230
t(["None", "NaN"], [C([True, np.nan, None])],
@@ -236,10 +262,10 @@ def categorical_to_int(data, levels, NA_action, origin=None):
236262
% (levels, tuple(data.levels)), origin)
237263
data = data.data
238264
if hasattr(data, "shape") and len(data.shape) > 1:
239-
raise PatsyError("categorical data must be at most 1-dimensional",
265+
raise PatsyError("categorical data must be 1-dimensional",
240266
origin)
241-
if hasattr(data, "shape") and len(data.shape) < 1:
242-
data.resize((-1,))
267+
if not iterable(data) or isinstance(data, basestring):
268+
raise PatsyError("categorical data must be an iterable container")
243269
try:
244270
level_to_int = dict(zip(levels, xrange(len(levels))))
245271
except TypeError:
@@ -337,8 +363,10 @@ def t(data, levels, expected, NA_action=NAAction()):
337363
assert_raises(PatsyError, categorical_to_int,
338364
np.asarray([["a", "b"], ["b", "a"]]),
339365
("a", "b"), NAAction())
340-
# ndim == 0 is okay (and coerced into ndim == 1)
341-
t("b", ("a", "b"), [1])
366+
# ndim == 0 is disallowed likewise
367+
assert_raises(PatsyError, categorical_to_int,
368+
"a",
369+
("a", "b"), NAAction())
342370

343371
# levels must be hashable
344372
assert_raises(PatsyError, categorical_to_int,

patsy/util.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"widest_float", "widest_complex", "wide_dtype_for", "widen",
99
"repr_pretty_delegate", "repr_pretty_impl",
1010
"SortAnythingKey", "safe_scalar_isnan", "safe_isnan",
11+
"iterable",
1112
]
1213

1314
import sys
@@ -535,3 +536,16 @@ def test_safe_isnan():
535536
# raw isnan raises a *different* error for strings than for objects:
536537
assert not safe_isnan("asdf")
537538

539+
def iterable(obj):
    """Return True if ``iter(obj)`` succeeds, and False otherwise.

    Strings count as iterable here, so callers that want to reject them
    must check for that separately.
    """
    try:
        iter(obj)
    except Exception:
        # Any failure while obtaining an iterator is treated as "not
        # iterable", not only the TypeError raised for plain non-iterables.
        return False
    return True
545+
546+
def test_iterable():
    """Check iterable() on containers, a scalar, and a function object."""
    # Strings, lists and dicts can all be passed to iter().
    assert iterable("asdf")
    assert iterable([])
    assert iterable({"a": 1})
    # An int and a plain function cannot.
    assert not iterable(1)
    assert not iterable(iterable)

0 commit comments

Comments
 (0)