Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,7 @@ Groupby/resample/rolling
- Bug in :meth:`.DataFrameGroupBy.agg` with ``engine="numba"`` failing to respect ``as_index=False`` (:issue:`51228`)
- Bug in :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, and :meth:`Resampler.agg` would ignore arguments when passed a list of functions (:issue:`50863`)
- Bug in :meth:`DataFrameGroupBy.ohlc` ignoring ``as_index=False`` (:issue:`51413`)
- Bug in :meth:`DataFrameGroupBy.agg` after subsetting columns (e.g. ``.groupby(...)[["a", "b"]]``) would not include groupings in the result (:issue:`51186`)
-

Reshaping
Expand Down
10 changes: 2 additions & 8 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1320,21 +1320,15 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
agg = aggregate

def _iterate_slices(self) -> Iterable[Series]:
obj = self._selected_obj
obj = self._obj_with_exclusions
if self.axis == 1:
obj = obj.T

if isinstance(obj, Series) and obj.name not in self.exclusions:
if isinstance(obj, Series):
# Occurs when doing DataFrameGroupBy(...)["X"]
yield obj
else:
for label, values in obj.items():
if label in self.exclusions:
# Note: if we tried to just iterate over _obj_with_exclusions,
# we would break test_wrap_agg_out by yielding a column
# that is skipped here but not dropped from obj_with_exclusions
continue

yield values

def _aggregate_frame(self, func, *args, **kwargs) -> DataFrame:
Expand Down
17 changes: 15 additions & 2 deletions pandas/tests/groupby/aggregate/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ def func(ser):

with pytest.raises(TypeError, match="Test error message"):
grouped.aggregate(func)
result = grouped[[c for c in three_group if c != "C"]].aggregate(func)
exp_grouped = three_group.loc[:, three_group.columns != "C"]
result = grouped[["D", "E", "F"]].aggregate(func)
exp_grouped = three_group.loc[:, ["A", "B", "D", "E", "F"]]
expected = exp_grouped.groupby(["A", "B"]).aggregate(func)
tm.assert_frame_equal(result, expected)

Expand Down Expand Up @@ -1521,3 +1521,16 @@ def foo2(x, b=2, c=0):
[[8, 8], [9, 9], [10, 10]], index=Index([1, 2, 3]), columns=["foo1", "foo2"]
)
tm.assert_frame_equal(result, expected)


def test_agg_groupings_selection():
# GH#51186 - a selected grouping should be in the output of agg
df = DataFrame({"a": [1, 1, 2], "b": [3, 3, 4], "c": [5, 6, 7]})
gb = df.groupby(["a", "b"])
selected_gb = gb[["b", "c"]]
result = selected_gb.agg(lambda x: x.sum())
index = MultiIndex(
levels=[[1, 2], [3, 4]], codes=[[0, 1], [0, 1]], names=["a", "b"]
)
expected = DataFrame({"b": [6, 4], "c": [11, 7]}, index=index)
tm.assert_frame_equal(result, expected)