@@ -603,4 +603,61 @@ def test_contrast():
603603 [8 , - 1 ],
604604 [7 , 12 ],
605605 [2 , 13 ]])
606-
606+
def test_term_subset_builder():
    # A builder restricted to some terms must produce exactly the matching
    # columns of the matrix built from every term. The restriction is
    # spelled four ways: as a formula, as term names, as term objects, and
    # as a mixture of names and objects. Only the variables the subset
    # uses are supplied, so dropped variables must not be required.
    all_data = {
        "x": [1, 2],
        "y": [[3.1, 3.2],
              [4.1, 4.2]],
        "z": [5, 6],
    }
    all_terms = make_termlist("x", "y", "z")

    def iter_maker():
        yield all_data

    all_builder = design_matrix_builders([all_terms], iter_maker)[0]
    full_matrix = build_design_matrices([all_builder], all_data)[0]

    def check(which_terms, variables, columns):
        sub_builder = all_builder.term_subset_builder(which_terms)
        # Hand over nothing but the variables named by the caller.
        sub_data = dict((name, all_data[name]) for name in variables)
        sub_matrix = build_design_matrices([sub_builder], sub_data)[0]
        expected = full_matrix[:, columns]
        # The term count is only compared for list-like specifications;
        # the length of a formula string says nothing about its terms.
        if not isinstance(which_terms, basestring):
            assert len(which_terms) == len(sub_builder.design_info.terms)
        assert np.array_equal(sub_matrix, expected)

    term_x, term_y, term_z = all_terms[0], all_terms[1], all_terms[2]
    every_column = slice(None)
    xyz = ["x", "y", "z"]
    xz = ["x", "z"]

    # Every term, in the original order.
    check("~ 0 + x + y + z", xyz, every_column)
    check(["x", "y", "z"], xyz, every_column)
    check([unicode("x"), unicode("y"), unicode("z")],
          xyz, every_column)
    check(all_terms, xyz, every_column)
    check([term_x, "y", term_z], xyz, every_column)

    # x and z only: y's columns (1 and 2) are left out.
    x_then_z = [0, 3]
    check("~ 0 + x + z", xz, x_then_z)
    check(["x", "z"], xz, x_then_z)
    check([unicode("x"), unicode("z")], xz, x_then_z)
    check([term_x, term_z], xz, x_then_z)
    check([term_x, "z"], xz, x_then_z)

    # Same two terms requested in reverse order.
    z_then_x = [3, 0]
    check("~ 0 + z + x", xz, z_then_x)
    check(["z", "x"], xz, z_then_x)
    check([unicode("z"), unicode("x")], xz, z_then_x)
    check([term_z, term_x], xz, z_then_x)
    check([term_z, "x"], xz, z_then_x)

    # y alone, which occupies columns 1 and 2.
    y_columns = [1, 2]
    check("~ 0 + y", ["y"], y_columns)
    check(["y"], ["y"], y_columns)
    check([unicode("y")], ["y"], y_columns)
    check([term_y], ["y"], y_columns)

    make_subset = all_builder.term_subset_builder
    # A formula with a left-hand side is rejected.
    assert_raises(PatsyError, make_subset, "a ~ a")
    # Unknown terms are rejected whether given as a formula, a name, or a
    # term object.
    assert_raises(PatsyError, make_subset, "~ asdf")
    assert_raises(PatsyError, make_subset, ["asdf"])
    assert_raises(PatsyError,
                  make_subset, [Term(["asdf"])])
0 commit comments