Skip to content

Commit 2b1740d

Browse files
committed
Add ability to build a subset of terms.
Esp. useful for things like graphing fitted splines.
1 parent 9df3873 commit 2b1740d

File tree

2 files changed

+132
-1
lines changed

2 files changed

+132
-1
lines changed

patsy/build.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from patsy.design_info import DesignMatrix, DesignInfo
2020
from patsy.redundancy import pick_contrasts_for_term
2121
from patsy.desc import ModelDesc
22+
from patsy.eval import EvalEnvironment
2223
from patsy.contrasts import code_contrast_matrix, Treatment
2324
from patsy.compat import itertools_product, OrderedDict
2425
from patsy.missing import NAAction
@@ -718,6 +719,79 @@ def design_info(self):
718719
return DesignInfo(self._column_names, self._term_slices,
719720
builder=self)
720721

722+
def term_subset_builder(self, which_terms):
    """Return a new :class:`DesignMatrixBuilder` restricted to a subset of
    the terms handled by this one.

    For example, if `builder` has terms `x`, `y`, and `z`, then::

        builder2 = builder.term_subset_builder(["x", "z"])

    gives a builder that produces design matrices holding only the
    columns that belong to the terms `x` and `z`. In general these two
    expressions produce the same thing, where `columns_of_x_and_z` stands
    for the indices of the columns that `x` and `z` occupy in the full
    matrix::

        build_design_matrices([builder], data)[0][:, columns_of_x_and_z]
        build_design_matrices([builder2], data)[0]

    The important difference is that in the second case `data` does not
    need to contain any values for `y`. This is handy when predicting
    from only part of a model, a situation in which R usually forces you
    to supply dummy values for `y`.

    When the subset is given as a formula, keep in mind that -- as with
    any formula -- an intercept term is included by default; write `0` or
    `-1` in the formula to leave it out.

    :arg which_terms: The terms to keep in the new
      :class:`DesignMatrixBuilder`. A string is parsed as a formula, and
      the names of the terms on its right-hand side are the ones kept. A
      list may freely mix term names (strings) and :class:`Term` objects.
      The kept terms are handed to the new builder in the order given.
    :raises PatsyError: if a formula with a left-hand side is given, or if
      a requested term is not one of this builder's terms.
    """
    evaluator_for_factor = dict((ev.factor, ev) for ev in self._evaluators)
    info = self.design_info
    term_for_name = dict(zip(info.term_names, info.terms))
    if isinstance(which_terms, basestring):
        # The EvalEnvironment built here is never really used. We only
        # want to find matching terms, and we cannot do that using == on
        # Term objects: that calls == on the factor objects, which in turn
        # compares EvalEnvironments. So the only thing taken from the
        # parsed formula is the term *names*, which the EvalEnvironment
        # does not affect. The empty environment is merely a placeholder
        # that allows the ModelDesc to be created.
        parsed = ModelDesc.from_formula(which_terms, EvalEnvironment({}))
        if parsed.lhs_termlist:
            raise PatsyError("right-hand-side-only formula required")
        which_terms = [rhs_term.name() for rhs_term in parsed.rhs_termlist]
    kept_terms = []
    kept_evaluators = set()
    kept_column_builders = {}
    for requested in which_terms:
        term = requested
        if isinstance(requested, basestring):
            # Names are resolved to this builder's own Term objects.
            if requested not in term_for_name:
                raise PatsyError("requested term %r not found in "
                                 "this DesignMatrixBuilder"
                                 % (requested,))
            term = term_for_name[requested]
        if term not in self._termlist:
            raise PatsyError("requested term '%s' not found in this "
                             "DesignMatrixBuilder" % (term,))
        # Only the evaluators for factors of kept terms are carried over;
        # this is what lets the new builder run without data for the
        # dropped terms.
        kept_evaluators.update(evaluator_for_factor[factor]
                               for factor in term.factors)
        kept_terms.append(term)
        kept_column_builders[term] = self._term_to_column_builders[term]
    return DesignMatrixBuilder(kept_terms,
                               kept_evaluators,
                               kept_column_builders)
794+
721795
def _build(self, evaluator_to_values, dtype):
722796
factor_to_values = {}
723797
need_reshape = False

patsy/test_build.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,4 +603,61 @@ def test_contrast():
603603
[8, -1],
604604
[7, 12],
605605
[2, 13]])
606-
606+
607+
def test_term_subset_builder():
    # Exercise term_subset_builder with every way of naming the subset
    # (formula string, term names, unicode term names, Term objects, and a
    # mix of names and Term objects). In each case the sub-builder's
    # output must equal the matching columns of the full design matrix,
    # even though the data for the dropped variables is withheld.
    all_data = {"x": [1, 2],
                "y": [[3.1, 3.2],
                      [4.1, 4.2]],
                "z": [5, 6]}
    all_terms = make_termlist("x", "y", "z")
    def iter_maker():
        yield all_data
    all_builder = design_matrix_builders([all_terms], iter_maker)[0]
    full_matrix = build_design_matrices([all_builder], all_data)[0]

    def check(which_terms, variables, columns):
        # Build with only `variables` present in the data and compare
        # against `columns` of the full matrix.
        sub_builder = all_builder.term_subset_builder(which_terms)
        sub_data = dict((name, all_data[name]) for name in variables)
        sub_matrix = build_design_matrices([sub_builder], sub_data)[0]
        expected = full_matrix[:, columns]
        if not isinstance(which_terms, basestring):
            assert len(which_terms) == len(sub_builder.design_info.terms)
        assert np.array_equal(sub_matrix, expected)

    # Each group: (variables to supply, expected columns of the full
    # matrix, list of equivalent ways to request the same subset).
    groups = [
        # Everything, in the original order.
        (["x", "y", "z"], slice(None),
         ["~ 0 + x + y + z",
          ["x", "y", "z"],
          [unicode("x"), unicode("y"), unicode("z")],
          all_terms,
          [all_terms[0], "y", all_terms[2]]]),
        # Drop the middle (two-column) term.
        (["x", "z"], [0, 3],
         ["~ 0 + x + z",
          ["x", "z"],
          [unicode("x"), unicode("z")],
          [all_terms[0], all_terms[2]],
          [all_terms[0], "z"]]),
        # Same subset, requested in reverse order.
        (["x", "z"], [3, 0],
         ["~ 0 + z + x",
          ["z", "x"],
          [unicode("z"), unicode("x")],
          [all_terms[2], all_terms[0]],
          [all_terms[2], "x"]]),
        # Only the two-column term.
        (["y"], [1, 2],
         ["~ 0 + y",
          ["y"],
          [unicode("y")],
          [all_terms[1]]]),
    ]
    for variables, columns, requests in groups:
        for which_terms in requests:
            check(which_terms, variables, columns)

    # Formula can't have a LHS
    assert_raises(PatsyError, all_builder.term_subset_builder, "a ~ a")
    # Term must exist
    assert_raises(PatsyError, all_builder.term_subset_builder, "~ asdf")
    assert_raises(PatsyError, all_builder.term_subset_builder, ["asdf"])
    assert_raises(PatsyError,
                  all_builder.term_subset_builder, [Term(["asdf"])])

0 commit comments

Comments
 (0)