|
6 | 6 | # patsy.util, which is misc. utilities useful for implementing patsy). |
7 | 7 |
|
8 | 8 | # These are made available in the patsy.* namespace |
9 | | -__all__ = ["balanced", "demo_data"] |
| 9 | +__all__ = ["balanced", "demo_data", "LookupFactor"] |
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 | from patsy import PatsyError |
13 | 13 | from patsy.compat import itertools_product |
| 14 | +from patsy.categorical import C |
14 | 15 |
|
15 | 16 | def balanced(**kwargs): |
16 | 17 | """balanced(factor_name=num_levels, [factor_name=num_levels, ..., repeat=1]) |
@@ -136,3 +137,89 @@ def test_demo_data(): |
136 | 137 | from nose.tools import assert_raises |
137 | 138 | assert_raises(PatsyError, demo_data, "a", "b", "__123") |
138 | 139 | assert_raises(TypeError, demo_data, "a", "b", asdfasdf=123) |
| 140 | + |
| 141 | +class LookupFactor(object): |
| 142 | + """A simple factor class that simply looks up a named entry in the given |
| 143 | + data. |
| 144 | +
|
| 145 | + Useful for programatically constructing formulas, and as a simple example |
| 146 | + of the factor protocol. For details see |
| 147 | + :ref:`expert-model-specification`. |
| 148 | +
|
| 149 | + Example:: |
| 150 | +
|
| 151 | + dmatrix(ModelDesc([], [Term([LookupFactor("x")])]), {"x": [1, 2, 3]}) |
| 152 | + """ |
| 153 | + def __init__(self, varname, |
| 154 | + force_categorical=False, contrast=None, levels=None, |
| 155 | + origin=None): |
| 156 | + self._varname = varname |
| 157 | + self._force_categorical = force_categorical |
| 158 | + self._contrast = contrast |
| 159 | + self._levels = levels |
| 160 | + self.origin = origin |
| 161 | + if not self._force_categorical: |
| 162 | + if contrast is not None: |
| 163 | + raise ValueError("contrast= requires force_categorical=True") |
| 164 | + if levels is not None: |
| 165 | + raise ValueError("levels= requires force_categorical=True") |
| 166 | + |
| 167 | + def name(self): |
| 168 | + return self._varname |
| 169 | + |
| 170 | + def __repr__(self): |
| 171 | + return "%s(%r)" % (self.__class__.__name__, self._varname) |
| 172 | + |
| 173 | + def __eq__(self, other): |
| 174 | + return (isinstance(other, LookupFactor) |
| 175 | + and self._varname == other._varname |
| 176 | + and self._force_categorical == other._force_categorical |
| 177 | + and self._contrast == other._contrast |
| 178 | + and self._levels == other._levels) |
| 179 | + |
| 180 | + def __ne__(self, other): |
| 181 | + return not self == other |
| 182 | + |
| 183 | + def __hash__(self): |
| 184 | + return hash((LookupFactor, self._varname, |
| 185 | + self._force_categorical, self._contrast, self._levels)) |
| 186 | + |
| 187 | + def memorize_passes_needed(self, state): |
| 188 | + return 0 |
| 189 | + |
| 190 | + def memorize_chunk(self, state, which_pass, env): # pragma: no cover |
| 191 | + assert False |
| 192 | + |
| 193 | + def memorize_finish(self, state, which_pass): # pragma: no cover |
| 194 | + assert False |
| 195 | + |
| 196 | + def eval(self, memorize_state, data): |
| 197 | + value = data[self._varname] |
| 198 | + if self._force_categorical: |
| 199 | + value = C(value, contrast=self._contrast, levels=self._levels) |
| 200 | + return value |
| 201 | + |
| 202 | +def test_LookupFactor(): |
| 203 | + l_a = LookupFactor("a") |
| 204 | + assert l_a.name() == "a" |
| 205 | + assert l_a == LookupFactor("a") |
| 206 | + assert l_a != LookupFactor("b") |
| 207 | + assert hash(l_a) == hash(LookupFactor("a")) |
| 208 | + assert hash(l_a) != hash(LookupFactor("b")) |
| 209 | + assert l_a.eval({}, {"a": 1}) == 1 |
| 210 | + assert l_a.eval({}, {"a": 2}) == 2 |
| 211 | + assert repr(l_a) == "LookupFactor('a')" |
| 212 | + assert l_a.origin is None |
| 213 | + l_with_origin = LookupFactor("b", origin="asdf") |
| 214 | + assert l_with_origin.origin == "asdf" |
| 215 | + |
| 216 | + l_c = LookupFactor("c", force_categorical=True, |
| 217 | + contrast="CONTRAST", levels=(1, 2)) |
| 218 | + box = l_c.eval({}, {"c": [1, 1, 2]}) |
| 219 | + assert box.data == [1, 1, 2] |
| 220 | + assert box.contrast == "CONTRAST" |
| 221 | + assert box.levels == (1, 2) |
| 222 | + |
| 223 | + from nose.tools import assert_raises |
| 224 | + assert_raises(ValueError, LookupFactor, "nc", contrast="CONTRAST") |
| 225 | + assert_raises(ValueError, LookupFactor, "nc", levels=(1, 2)) |
0 commit comments