Skip to content

Commit af400a8

Browse files
felipecrvpitrouwestonpace
authored
apacheGH-33566: [C++] Add support for nullary and n-ary aggregate functions (apache#15083)
- [x] Add ability to pass 0 or more than 1 target fields via the Aggregate API - [x] Add support for nullary `count` -- `count(*)` - [x] Add a n-ary aggregate function to test changes `*` `*` I implemented a `"covariant(y, x)"` aggregation function and used it to test the Aggregate API changes, but it's not present in this PR now that I intend to focus on passing the CI tests and get a final review * Closes: apache#33566 Lead-authored-by: Felipe Oliveira Carvalho <felipekde@gmail.com> Co-authored-by: Antoine Pitrou <pitrou@free.fr> Co-authored-by: Weston Pace <weston.pace@gmail.com> Signed-off-by: Weston Pace <weston.pace@gmail.com>
1 parent 17ea6fc commit af400a8

24 files changed

Lines changed: 743 additions & 275 deletions

cpp/src/arrow/compute/api_aggregate.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
#pragma once
2222

23+
#include <vector>
24+
2325
#include "arrow/compute/function.h"
2426
#include "arrow/datum.h"
2527
#include "arrow/result.h"
@@ -186,16 +188,38 @@ class ARROW_EXPORT IndexOptions : public FunctionOptions {
186188

187189
/// \brief Configure a grouped aggregation
188190
struct ARROW_EXPORT Aggregate {
191+
Aggregate() = default;
192+
193+
Aggregate(std::string function, std::shared_ptr<FunctionOptions> options,
194+
std::vector<FieldRef> target, std::string name = "")
195+
: function(std::move(function)),
196+
options(std::move(options)),
197+
target(std::move(target)),
198+
name(std::move(name)) {}
199+
200+
Aggregate(std::string function, std::shared_ptr<FunctionOptions> options,
201+
FieldRef target, std::string name = "")
202+
: Aggregate(std::move(function), std::move(options),
203+
std::vector<FieldRef>{std::move(target)}, std::move(name)) {}
204+
205+
Aggregate(std::string function, FieldRef target, std::string name)
206+
: Aggregate(std::move(function), /*options=*/NULLPTR,
207+
std::vector<FieldRef>{std::move(target)}, std::move(name)) {}
208+
209+
Aggregate(std::string function, std::string name)
210+
: Aggregate(std::move(function), /*options=*/NULLPTR,
211+
/*target=*/std::vector<FieldRef>{}, std::move(name)) {}
212+
189213
/// the name of the aggregation function
190214
std::string function;
191215

192216
/// options for the aggregation function
193217
std::shared_ptr<FunctionOptions> options;
194218

195-
// fields to which aggregations will be applied
196-
FieldRef target;
219+
/// zero or more fields to which aggregations will be applied
220+
std::vector<FieldRef> target;
197221

198-
// output field name for aggregations
222+
/// optional output field name for aggregations
199223
std::string name;
200224
};
201225

cpp/src/arrow/compute/exec.cc

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,21 @@ ExecBatch ExecBatch::Slice(int64_t offset, int64_t length) const {
147147
return out;
148148
}
149149

150-
Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
150+
namespace {
151+
152+
enum LengthInferenceError {
153+
kEmptyInput = -1,
154+
kInvalidValues = -2,
155+
};
156+
157+
/// \brief Infer the ExecBatch length from values.
158+
///
159+
/// \return the inferred length of the batch. If there are no values in the
160+
/// batch then kEmptyInput (-1) is returned. If the values in the batch have
161+
/// different lengths then kInvalidValues (-2) is returned.
162+
int64_t DoInferLength(const std::vector<Datum>& values) {
151163
if (values.empty()) {
152-
return Status::Invalid("Cannot infer ExecBatch length without at least one value");
164+
return kEmptyInput;
153165
}
154166

155167
int64_t length = -1;
@@ -164,13 +176,52 @@ Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values) {
164176
}
165177

166178
if (length != value.length()) {
179+
// all the arrays should have the same length
180+
return kInvalidValues;
181+
}
182+
}
183+
184+
return length == -1 ? 1 : length;
185+
}
186+
187+
} // namespace
188+
189+
Result<int64_t> ExecBatch::InferLength(const std::vector<Datum>& values) {
190+
const int64_t length = DoInferLength(values);
191+
switch (length) {
192+
case kInvalidValues:
167193
return Status::Invalid(
168194
"Arrays used to construct an ExecBatch must have equal length");
169-
}
195+
case kEmptyInput:
196+
return Status::Invalid("Cannot infer ExecBatch length without at least one value");
197+
default:
198+
break;
170199
}
200+
return {length};
201+
}
171202

172-
if (length == -1) {
173-
length = 1;
203+
Result<ExecBatch> ExecBatch::Make(std::vector<Datum> values, int64_t length) {
204+
// Infer the length again and/or validate the given length.
205+
const int64_t inferred_length = DoInferLength(values);
206+
switch (inferred_length) {
207+
case kEmptyInput:
208+
if (length < 0) {
209+
return Status::Invalid(
210+
"Cannot infer ExecBatch length without at least one value");
211+
}
212+
break;
213+
214+
case kInvalidValues:
215+
return Status::Invalid(
216+
"Arrays used to construct an ExecBatch must have equal length");
217+
218+
default:
219+
if (length < 0) {
220+
length = inferred_length;
221+
} else if (length != inferred_length) {
222+
return Status::Invalid("Length used to construct an ExecBatch is invalid");
223+
}
224+
break;
174225
}
175226

176227
return ExecBatch(std::move(values), length);

cpp/src/arrow/compute/exec.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <cstdint>
2525
#include <limits>
2626
#include <memory>
27+
#include <optional>
2728
#include <string>
2829
#include <utility>
2930
#include <vector>
@@ -174,7 +175,10 @@ struct ARROW_EXPORT ExecBatch {
174175

175176
explicit ExecBatch(const RecordBatch& batch);
176177

177-
static Result<ExecBatch> Make(std::vector<Datum> values);
178+
/// \brief Infer the ExecBatch length from values.
179+
static Result<int64_t> InferLength(const std::vector<Datum>& values);
180+
181+
static Result<ExecBatch> Make(std::vector<Datum> values, int64_t length = -1);
178182

179183
Result<std::shared_ptr<RecordBatch>> ToRecordBatch(
180184
std::shared_ptr<Schema> schema, MemoryPool* pool = default_memory_pool()) const;

cpp/src/arrow/compute/exec/aggregate.cc

Lines changed: 107 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -36,70 +36,91 @@ using internal::ToChars;
3636
namespace compute {
3737
namespace internal {
3838

39+
namespace {
40+
41+
std::vector<TypeHolder> ExtendWithGroupIdType(const std::vector<TypeHolder>& in_types) {
42+
std::vector<TypeHolder> aggr_in_types;
43+
aggr_in_types.reserve(in_types.size() + 1);
44+
aggr_in_types = in_types;
45+
aggr_in_types.emplace_back(uint32());
46+
return aggr_in_types;
47+
}
48+
49+
Result<const HashAggregateKernel*> GetKernel(ExecContext* ctx, const Aggregate& aggregate,
50+
const std::vector<TypeHolder>& in_types) {
51+
const auto aggr_in_types = ExtendWithGroupIdType(in_types);
52+
ARROW_ASSIGN_OR_RAISE(auto function,
53+
ctx->func_registry()->GetFunction(aggregate.function));
54+
ARROW_ASSIGN_OR_RAISE(const Kernel* kernel, function->DispatchExact(aggr_in_types));
55+
return static_cast<const HashAggregateKernel*>(kernel);
56+
}
57+
58+
Result<std::unique_ptr<KernelState>> InitKernel(const HashAggregateKernel* kernel,
59+
ExecContext* ctx,
60+
const Aggregate& aggregate,
61+
const std::vector<TypeHolder>& in_types) {
62+
const auto aggr_in_types = ExtendWithGroupIdType(in_types);
63+
64+
KernelContext kernel_ctx{ctx};
65+
const auto* options =
66+
arrow::internal::checked_cast<const FunctionOptions*>(aggregate.options.get());
67+
if (options == nullptr) {
68+
// use known default options for the named function if possible
69+
auto maybe_function = ctx->func_registry()->GetFunction(aggregate.function);
70+
if (maybe_function.ok()) {
71+
options = maybe_function.ValueOrDie()->default_options();
72+
}
73+
}
74+
75+
ARROW_ASSIGN_OR_RAISE(
76+
auto state,
77+
kernel->init(&kernel_ctx, KernelInitArgs{kernel, aggr_in_types, options}));
78+
return std::move(state);
79+
}
80+
81+
} // namespace
82+
3983
Result<std::vector<const HashAggregateKernel*>> GetKernels(
4084
ExecContext* ctx, const std::vector<Aggregate>& aggregates,
41-
const std::vector<TypeHolder>& in_types) {
85+
const std::vector<std::vector<TypeHolder>>& in_types) {
4286
if (aggregates.size() != in_types.size()) {
4387
return Status::Invalid(aggregates.size(), " aggregate functions were specified but ",
4488
in_types.size(), " arguments were provided.");
4589
}
4690

4791
std::vector<const HashAggregateKernel*> kernels(in_types.size());
48-
4992
for (size_t i = 0; i < aggregates.size(); ++i) {
50-
ARROW_ASSIGN_OR_RAISE(auto function,
51-
ctx->func_registry()->GetFunction(aggregates[i].function));
52-
ARROW_ASSIGN_OR_RAISE(const Kernel* kernel,
53-
function->DispatchExact({in_types[i], uint32()}));
54-
kernels[i] = static_cast<const HashAggregateKernel*>(kernel);
93+
ARROW_ASSIGN_OR_RAISE(kernels[i], GetKernel(ctx, aggregates[i], in_types[i]));
5594
}
5695
return kernels;
5796
}
5897

5998
Result<std::vector<std::unique_ptr<KernelState>>> InitKernels(
6099
const std::vector<const HashAggregateKernel*>& kernels, ExecContext* ctx,
61-
const std::vector<Aggregate>& aggregates, const std::vector<TypeHolder>& in_types) {
100+
const std::vector<Aggregate>& aggregates,
101+
const std::vector<std::vector<TypeHolder>>& in_types) {
62102
std::vector<std::unique_ptr<KernelState>> states(kernels.size());
63-
64103
for (size_t i = 0; i < aggregates.size(); ++i) {
65-
const FunctionOptions* options =
66-
arrow::internal::checked_cast<const FunctionOptions*>(
67-
aggregates[i].options.get());
68-
69-
if (options == nullptr) {
70-
// use known default options for the named function if possible
71-
auto maybe_function = ctx->func_registry()->GetFunction(aggregates[i].function);
72-
if (maybe_function.ok()) {
73-
options = maybe_function.ValueOrDie()->default_options();
74-
}
75-
}
76-
77-
KernelContext kernel_ctx{ctx};
78104
ARROW_ASSIGN_OR_RAISE(states[i],
79-
kernels[i]->init(&kernel_ctx, KernelInitArgs{kernels[i],
80-
{
81-
in_types[i],
82-
uint32(),
83-
},
84-
options}));
105+
InitKernel(kernels[i], ctx, aggregates[i], in_types[i]));
85106
}
86-
87107
return std::move(states);
88108
}
89109

90110
Result<FieldVector> ResolveKernels(
91111
const std::vector<Aggregate>& aggregates,
92112
const std::vector<const HashAggregateKernel*>& kernels,
93113
const std::vector<std::unique_ptr<KernelState>>& states, ExecContext* ctx,
94-
const std::vector<TypeHolder>& types) {
114+
const std::vector<std::vector<TypeHolder>>& types) {
95115
FieldVector fields(types.size());
96116

97117
for (size_t i = 0; i < kernels.size(); ++i) {
98118
KernelContext kernel_ctx{ctx};
99119
kernel_ctx.SetState(states[i].get());
100120

101-
ARROW_ASSIGN_OR_RAISE(auto type, kernels[i]->signature->out_type().Resolve(
102-
&kernel_ctx, {types[i], uint32()}));
121+
const auto aggr_in_types = ExtendWithGroupIdType(types[i]);
122+
ARROW_ASSIGN_OR_RAISE(
123+
auto type, kernels[i]->signature->out_type().Resolve(&kernel_ctx, aggr_in_types));
103124
fields[i] = field(aggregates[i].function, type.GetSharedPtr());
104125
}
105126
return fields;
@@ -121,27 +142,50 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
121142
ExecSpanIterator argument_iterator;
122143

123144
ExecBatch args_batch;
124-
if (!arguments.empty()) {
125-
ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments));
145+
Result<int64_t> inferred_length = ExecBatch::InferLength(arguments);
146+
if (!inferred_length.ok()) {
147+
inferred_length = ExecBatch::InferLength(keys);
148+
}
149+
ARROW_ASSIGN_OR_RAISE(const int64_t length, std::move(inferred_length));
150+
if (!aggregates.empty()) {
151+
ARROW_ASSIGN_OR_RAISE(args_batch, ExecBatch::Make(arguments, length));
126152

127153
// Construct and initialize HashAggregateKernels
128-
auto argument_types = args_batch.GetTypes();
154+
std::vector<std::vector<TypeHolder>> aggs_argument_types;
155+
aggs_argument_types.reserve(aggregates.size());
156+
size_t i = 0;
157+
for (const auto& aggregate : aggregates) {
158+
auto& agg_types = aggs_argument_types.emplace_back();
159+
const size_t num_needed = aggregate.target.size();
160+
for (size_t j = 0; j < num_needed && i < arguments.size(); j++, i++) {
161+
agg_types.emplace_back(arguments[i].type());
162+
}
163+
if (agg_types.size() != num_needed) {
164+
return Status::Invalid("Not enough arguments specified to aggregate functions.");
165+
}
166+
}
167+
DCHECK_EQ(aggs_argument_types.size(), aggregates.size());
168+
if (i != arguments.size()) {
169+
return Status::Invalid("Aggregate functions expect exactly ", i, " arguments, but ",
170+
arguments.size(), " were specified.");
171+
}
129172

130-
ARROW_ASSIGN_OR_RAISE(kernels, GetKernels(ctx, aggregates, argument_types));
173+
ARROW_ASSIGN_OR_RAISE(kernels, GetKernels(ctx, aggregates, aggs_argument_types));
131174

132175
states.resize(task_group->parallelism());
133176
for (auto& state : states) {
134-
ARROW_ASSIGN_OR_RAISE(state, InitKernels(kernels, ctx, aggregates, argument_types));
177+
ARROW_ASSIGN_OR_RAISE(state,
178+
InitKernels(kernels, ctx, aggregates, aggs_argument_types));
135179
}
136180

137-
ARROW_ASSIGN_OR_RAISE(
138-
out_fields, ResolveKernels(aggregates, kernels, states[0], ctx, argument_types));
181+
ARROW_ASSIGN_OR_RAISE(out_fields, ResolveKernels(aggregates, kernels, states[0], ctx,
182+
aggs_argument_types));
139183

140184
RETURN_NOT_OK(argument_iterator.Init(args_batch, ctx->exec_chunksize()));
141185
}
142186

143187
// Construct Groupers
144-
ARROW_ASSIGN_OR_RAISE(ExecBatch keys_batch, ExecBatch::Make(keys));
188+
ARROW_ASSIGN_OR_RAISE(ExecBatch keys_batch, ExecBatch::Make(keys, length));
145189
auto key_types = keys_batch.GetTypes();
146190

147191
std::vector<std::unique_ptr<Grouper>> groupers(task_group->parallelism());
@@ -164,6 +208,10 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
164208
ExecSpan key_batch, argument_batch;
165209
while ((arguments.empty() || argument_iterator.Next(&argument_batch)) &&
166210
key_iterator.Next(&key_batch)) {
211+
if (arguments.empty()) {
212+
// A value-less argument_batch should still have a valid length
213+
argument_batch.length = key_batch.length;
214+
}
167215
if (key_batch.length == 0) continue;
168216

169217
task_group->Append([&, key_batch, argument_batch] {
@@ -181,13 +229,23 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
181229
ARROW_ASSIGN_OR_RAISE(Datum id_batch, grouper->Consume(key_batch));
182230

183231
// consume group ids with HashAggregateKernels
184-
for (size_t i = 0; i < kernels.size(); ++i) {
232+
for (size_t k = 0, arg_idx = 0; k < kernels.size(); ++k) {
233+
const auto* kernel = kernels[k];
185234
KernelContext batch_ctx{ctx};
186-
batch_ctx.SetState(states[thread_index][i].get());
187-
ExecSpan kernel_batch({argument_batch[i], *id_batch.array()},
188-
argument_batch.length);
189-
RETURN_NOT_OK(kernels[i]->resize(&batch_ctx, grouper->num_groups()));
190-
RETURN_NOT_OK(kernels[i]->consume(&batch_ctx, kernel_batch));
235+
batch_ctx.SetState(states[thread_index][k].get());
236+
237+
const size_t kernel_num_args = kernel->signature->in_types().size();
238+
DCHECK_GT(kernel_num_args, 0);
239+
240+
std::vector<ExecValue> kernel_args;
241+
for (size_t i = 0; i + 1 < kernel_num_args; i++, arg_idx++) {
242+
kernel_args.push_back(argument_batch[arg_idx]);
243+
}
244+
kernel_args.emplace_back(*id_batch.array());
245+
246+
ExecSpan kernel_batch(std::move(kernel_args), argument_batch.length);
247+
RETURN_NOT_OK(kernel->resize(&batch_ctx, grouper->num_groups()));
248+
RETURN_NOT_OK(kernel->consume(&batch_ctx, kernel_batch));
191249
}
192250

193251
return Status::OK();
@@ -215,7 +273,7 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
215273
}
216274

217275
// Finalize output
218-
ArrayDataVector out_data(arguments.size() + keys.size());
276+
ArrayDataVector out_data(kernels.size() + keys.size());
219277
auto it = out_data.begin();
220278

221279
for (size_t idx = 0; idx < kernels.size(); ++idx) {
@@ -231,8 +289,8 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
231289
*it++ = key.array();
232290
}
233291

234-
int64_t length = out_data[0]->length;
235-
return ArrayData::Make(struct_(std::move(out_fields)), length,
292+
const int64_t out_length = out_data[0]->length;
293+
return ArrayData::Make(struct_(std::move(out_fields)), out_length,
236294
{/*null_bitmap=*/nullptr}, std::move(out_data),
237295
/*null_count=*/0);
238296
}

0 commit comments

Comments
 (0)