Skip to content

Commit b0ee5ef

Browse files
committed
FIX: showstopper bug with NA handling and data-independent builders
Some formulas, like "~ 1", don't actually depend on the passed-in data. So given a formula like "y ~ 1", we build the y matrix for the LHS, then build a 1x1 matrix for the RHS, and then broadcast the RHS matrix to match the number of rows in the LHS. However, this broadcasting logic was broadcasting to the *original* size of the LHS matrix, *before* NA removal. One result: expressions like dmatrices("y ~ 1") were returning matrices with different numbers of rows whenever y had missing values. Another result: dmatrix(..., result_type="dataframe") was totally broken in the presence of NAs, because RHS-only formulas still have an invisible zero-column LHS which is calculated and then discarded. But this zero-column LHS was calculated with the wrong number of rows, and then when we tried to convert it to a zero-column DataFrame, its shape didn't match the index that we tried to put on it. (Because NA removal *was* correctly affecting the index.) Fixes pydatagh-22.
1 parent 3048399 commit b0ee5ef

File tree

2 files changed

+37
-9
lines changed

2 files changed

+37
-9
lines changed

patsy/build.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,13 +918,19 @@ def build_design_matrices(builders, data,
918918
# Handle NAs
919919
values = evaluator_to_values.values()
920920
is_NAs = evaluator_to_isNAs.values()
921+
# num_rows is None iff evaluator_to_values (and associated sets like
922+
# 'values') are empty, i.e., we have no actual evaluators involved
923+
# (formulas like "~ 1").
921924
if return_type == "dataframe" and num_rows is not None:
922925
if pandas_index is None:
923926
pandas_index = np.arange(num_rows)
924927
values.append(pandas_index)
925928
is_NAs.append(np.zeros(len(pandas_index), dtype=bool))
926929
origins = [evaluator.factor.origin for evaluator in evaluator_to_values]
927930
new_values = NA_action.handle_NA(values, is_NAs, origins)
931+
# NA_action may have changed the number of rows.
932+
if num_rows is not None:
933+
num_rows = new_values[0].shape[0]
928934
if return_type == "dataframe" and num_rows is not None:
929935
pandas_index = new_values.pop()
930936
evaluator_to_values = dict(zip(evaluator_to_values, new_values))

patsy/test_highlevel.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -645,12 +645,34 @@ def raise_patsy_error(x):
645645
def test_dmatrix_NA_action():
646646
data = {"x": [1, 2, 3, np.nan], "y": [np.nan, 20, 30, 40]}
647647

648-
mat = dmatrix("x + y", data=data)
649-
assert np.array_equal(mat, [[1, 2, 20],
650-
[1, 3, 30]])
651-
assert_raises(PatsyError, dmatrix, "x + y", data=data, NA_action="raise")
652-
653-
lmat, rmat = dmatrices("y ~ x", data=data)
654-
assert np.array_equal(lmat, [[20], [30]])
655-
assert np.array_equal(rmat, [[1, 2], [1, 3]])
656-
assert_raises(PatsyError, dmatrices, "y ~ x", data=data, NA_action="raise")
648+
for return_type in ["matrix", "dataframe"]:
649+
mat = dmatrix("x + y", data=data, return_type=return_type)
650+
assert np.array_equal(mat, [[1, 2, 20],
651+
[1, 3, 30]])
652+
if return_type == "dataframe":
653+
assert mat.index.equals([1, 2])
654+
assert_raises(PatsyError, dmatrix, "x + y", data=data,
655+
return_type=return_type,
656+
NA_action="raise")
657+
658+
lmat, rmat = dmatrices("y ~ x", data=data, return_type=return_type)
659+
assert np.array_equal(lmat, [[20], [30]])
660+
assert np.array_equal(rmat, [[1, 2], [1, 3]])
661+
if return_type == "dataframe":
662+
assert lmat.index.equals([1, 2])
663+
assert rmat.index.equals([1, 2])
664+
assert_raises(PatsyError,
665+
dmatrices, "y ~ x", data=data, return_type=return_type,
666+
NA_action="raise")
667+
668+
# Initial release for the NA handling code had problems with
669+
# non-data-dependent matrices like "~ 1".
670+
lmat, rmat = dmatrices("y ~ 1", data=data, return_type=return_type)
671+
assert np.array_equal(lmat, [[20], [30], [40]])
672+
assert np.array_equal(rmat, [[1], [1], [1]])
673+
if return_type == "dataframe":
674+
assert lmat.index.equals([1, 2, 3])
675+
assert rmat.index.equals([1, 2, 3])
676+
assert_raises(PatsyError,
677+
dmatrices, "y ~ 1", data=data, return_type=return_type,
678+
NA_action="raise")

0 commit comments

Comments
 (0)