@@ -36,70 +36,91 @@ using internal::ToChars;
3636namespace compute {
3737namespace internal {
3838
39+ namespace {
40+
41+ std::vector<TypeHolder> ExtendWithGroupIdType (const std::vector<TypeHolder>& in_types) {
42+ std::vector<TypeHolder> aggr_in_types;
43+ aggr_in_types.reserve (in_types.size () + 1 );
44+ aggr_in_types = in_types;
45+ aggr_in_types.emplace_back (uint32 ());
46+ return aggr_in_types;
47+ }
48+
49+ Result<const HashAggregateKernel*> GetKernel (ExecContext* ctx, const Aggregate& aggregate,
50+ const std::vector<TypeHolder>& in_types) {
51+ const auto aggr_in_types = ExtendWithGroupIdType (in_types);
52+ ARROW_ASSIGN_OR_RAISE (auto function,
53+ ctx->func_registry ()->GetFunction (aggregate.function ));
54+ ARROW_ASSIGN_OR_RAISE (const Kernel* kernel, function->DispatchExact (aggr_in_types));
55+ return static_cast <const HashAggregateKernel*>(kernel);
56+ }
57+
58+ Result<std::unique_ptr<KernelState>> InitKernel (const HashAggregateKernel* kernel,
59+ ExecContext* ctx,
60+ const Aggregate& aggregate,
61+ const std::vector<TypeHolder>& in_types) {
62+ const auto aggr_in_types = ExtendWithGroupIdType (in_types);
63+
64+ KernelContext kernel_ctx{ctx};
65+ const auto * options =
66+ arrow::internal::checked_cast<const FunctionOptions*>(aggregate.options .get ());
67+ if (options == nullptr ) {
68+ // use known default options for the named function if possible
69+ auto maybe_function = ctx->func_registry ()->GetFunction (aggregate.function );
70+ if (maybe_function.ok ()) {
71+ options = maybe_function.ValueOrDie ()->default_options ();
72+ }
73+ }
74+
75+ ARROW_ASSIGN_OR_RAISE (
76+ auto state,
77+ kernel->init (&kernel_ctx, KernelInitArgs{kernel, aggr_in_types, options}));
78+ return std::move (state);
79+ }
80+
81+ } // namespace
82+
3983Result<std::vector<const HashAggregateKernel*>> GetKernels (
4084 ExecContext* ctx, const std::vector<Aggregate>& aggregates,
41- const std::vector<TypeHolder>& in_types) {
85+ const std::vector<std::vector< TypeHolder> >& in_types) {
4286 if (aggregates.size () != in_types.size ()) {
4387 return Status::Invalid (aggregates.size (), " aggregate functions were specified but " ,
4488 in_types.size (), " arguments were provided." );
4589 }
4690
4791 std::vector<const HashAggregateKernel*> kernels (in_types.size ());
48-
4992 for (size_t i = 0 ; i < aggregates.size (); ++i) {
50- ARROW_ASSIGN_OR_RAISE (auto function,
51- ctx->func_registry ()->GetFunction (aggregates[i].function ));
52- ARROW_ASSIGN_OR_RAISE (const Kernel* kernel,
53- function->DispatchExact ({in_types[i], uint32 ()}));
54- kernels[i] = static_cast <const HashAggregateKernel*>(kernel);
93+ ARROW_ASSIGN_OR_RAISE (kernels[i], GetKernel (ctx, aggregates[i], in_types[i]));
5594 }
5695 return kernels;
5796}
5897
5998Result<std::vector<std::unique_ptr<KernelState>>> InitKernels (
6099 const std::vector<const HashAggregateKernel*>& kernels, ExecContext* ctx,
61- const std::vector<Aggregate>& aggregates, const std::vector<TypeHolder>& in_types) {
100+ const std::vector<Aggregate>& aggregates,
101+ const std::vector<std::vector<TypeHolder>>& in_types) {
62102 std::vector<std::unique_ptr<KernelState>> states (kernels.size ());
63-
64103 for (size_t i = 0 ; i < aggregates.size (); ++i) {
65- const FunctionOptions* options =
66- arrow::internal::checked_cast<const FunctionOptions*>(
67- aggregates[i].options .get ());
68-
69- if (options == nullptr ) {
70- // use known default options for the named function if possible
71- auto maybe_function = ctx->func_registry ()->GetFunction (aggregates[i].function );
72- if (maybe_function.ok ()) {
73- options = maybe_function.ValueOrDie ()->default_options ();
74- }
75- }
76-
77- KernelContext kernel_ctx{ctx};
78104 ARROW_ASSIGN_OR_RAISE (states[i],
79- kernels[i]->init (&kernel_ctx, KernelInitArgs{kernels[i],
80- {
81- in_types[i],
82- uint32 (),
83- },
84- options}));
105+ InitKernel (kernels[i], ctx, aggregates[i], in_types[i]));
85106 }
86-
87107 return std::move (states);
88108}
89109
90110Result<FieldVector> ResolveKernels (
91111 const std::vector<Aggregate>& aggregates,
92112 const std::vector<const HashAggregateKernel*>& kernels,
93113 const std::vector<std::unique_ptr<KernelState>>& states, ExecContext* ctx,
94- const std::vector<TypeHolder>& types) {
114+ const std::vector<std::vector< TypeHolder> >& types) {
95115 FieldVector fields (types.size ());
96116
97117 for (size_t i = 0 ; i < kernels.size (); ++i) {
98118 KernelContext kernel_ctx{ctx};
99119 kernel_ctx.SetState (states[i].get ());
100120
101- ARROW_ASSIGN_OR_RAISE (auto type, kernels[i]->signature ->out_type ().Resolve (
102- &kernel_ctx, {types[i], uint32 ()}));
121+ const auto aggr_in_types = ExtendWithGroupIdType (types[i]);
122+ ARROW_ASSIGN_OR_RAISE (
123+ auto type, kernels[i]->signature ->out_type ().Resolve (&kernel_ctx, aggr_in_types));
103124 fields[i] = field (aggregates[i].function , type.GetSharedPtr ());
104125 }
105126 return fields;
@@ -121,27 +142,50 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
121142 ExecSpanIterator argument_iterator;
122143
123144 ExecBatch args_batch;
124- if (!arguments.empty ()) {
125- ARROW_ASSIGN_OR_RAISE (args_batch, ExecBatch::Make (arguments));
145+ Result<int64_t > inferred_length = ExecBatch::InferLength (arguments);
146+ if (!inferred_length.ok ()) {
147+ inferred_length = ExecBatch::InferLength (keys);
148+ }
149+ ARROW_ASSIGN_OR_RAISE (const int64_t length, std::move (inferred_length));
150+ if (!aggregates.empty ()) {
151+ ARROW_ASSIGN_OR_RAISE (args_batch, ExecBatch::Make (arguments, length));
126152
127153 // Construct and initialize HashAggregateKernels
128- auto argument_types = args_batch.GetTypes ();
154+ std::vector<std::vector<TypeHolder>> aggs_argument_types;
155+ aggs_argument_types.reserve (aggregates.size ());
156+ size_t i = 0 ;
157+ for (const auto & aggregate : aggregates) {
158+ auto & agg_types = aggs_argument_types.emplace_back ();
159+ const size_t num_needed = aggregate.target .size ();
160+ for (size_t j = 0 ; j < num_needed && i < arguments.size (); j++, i++) {
161+ agg_types.emplace_back (arguments[i].type ());
162+ }
163+ if (agg_types.size () != num_needed) {
164+ return Status::Invalid (" Not enough arguments specified to aggregate functions." );
165+ }
166+ }
167+ DCHECK_EQ (aggs_argument_types.size (), aggregates.size ());
168+ if (i != arguments.size ()) {
169+ return Status::Invalid (" Aggregate functions expect exactly " , i, " arguments, but " ,
170+ arguments.size (), " were specified." );
171+ }
129172
130- ARROW_ASSIGN_OR_RAISE (kernels, GetKernels (ctx, aggregates, argument_types ));
173+ ARROW_ASSIGN_OR_RAISE (kernels, GetKernels (ctx, aggregates, aggs_argument_types ));
131174
132175 states.resize (task_group->parallelism ());
133176 for (auto & state : states) {
134- ARROW_ASSIGN_OR_RAISE (state, InitKernels (kernels, ctx, aggregates, argument_types));
177+ ARROW_ASSIGN_OR_RAISE (state,
178+ InitKernels (kernels, ctx, aggregates, aggs_argument_types));
135179 }
136180
137- ARROW_ASSIGN_OR_RAISE (
138- out_fields, ResolveKernels (aggregates, kernels, states[ 0 ], ctx, argument_types ));
181+ ARROW_ASSIGN_OR_RAISE (out_fields, ResolveKernels (aggregates, kernels, states[ 0 ], ctx,
182+ aggs_argument_types ));
139183
140184 RETURN_NOT_OK (argument_iterator.Init (args_batch, ctx->exec_chunksize ()));
141185 }
142186
143187 // Construct Groupers
144- ARROW_ASSIGN_OR_RAISE (ExecBatch keys_batch, ExecBatch::Make (keys));
188+ ARROW_ASSIGN_OR_RAISE (ExecBatch keys_batch, ExecBatch::Make (keys, length ));
145189 auto key_types = keys_batch.GetTypes ();
146190
147191 std::vector<std::unique_ptr<Grouper>> groupers (task_group->parallelism ());
@@ -164,6 +208,10 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
164208 ExecSpan key_batch, argument_batch;
165209 while ((arguments.empty () || argument_iterator.Next (&argument_batch)) &&
166210 key_iterator.Next (&key_batch)) {
211+ if (arguments.empty ()) {
212+ // A value-less argument_batch should still have a valid length
213+ argument_batch.length = key_batch.length ;
214+ }
167215 if (key_batch.length == 0 ) continue ;
168216
169217 task_group->Append ([&, key_batch, argument_batch] {
@@ -181,13 +229,23 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
181229 ARROW_ASSIGN_OR_RAISE (Datum id_batch, grouper->Consume (key_batch));
182230
183231 // consume group ids with HashAggregateKernels
184- for (size_t i = 0 ; i < kernels.size (); ++i) {
232+ for (size_t k = 0 , arg_idx = 0 ; k < kernels.size (); ++k) {
233+ const auto * kernel = kernels[k];
185234 KernelContext batch_ctx{ctx};
186- batch_ctx.SetState (states[thread_index][i].get ());
187- ExecSpan kernel_batch ({argument_batch[i], *id_batch.array ()},
188- argument_batch.length );
189- RETURN_NOT_OK (kernels[i]->resize (&batch_ctx, grouper->num_groups ()));
190- RETURN_NOT_OK (kernels[i]->consume (&batch_ctx, kernel_batch));
235+ batch_ctx.SetState (states[thread_index][k].get ());
236+
237+ const size_t kernel_num_args = kernel->signature ->in_types ().size ();
238+ DCHECK_GT (kernel_num_args, 0 );
239+
240+ std::vector<ExecValue> kernel_args;
241+ for (size_t i = 0 ; i + 1 < kernel_num_args; i++, arg_idx++) {
242+ kernel_args.push_back (argument_batch[arg_idx]);
243+ }
244+ kernel_args.emplace_back (*id_batch.array ());
245+
246+ ExecSpan kernel_batch (std::move (kernel_args), argument_batch.length );
247+ RETURN_NOT_OK (kernel->resize (&batch_ctx, grouper->num_groups ()));
248+ RETURN_NOT_OK (kernel->consume (&batch_ctx, kernel_batch));
191249 }
192250
193251 return Status::OK ();
@@ -215,7 +273,7 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
215273 }
216274
217275 // Finalize output
218- ArrayDataVector out_data (arguments .size () + keys.size ());
276+ ArrayDataVector out_data (kernels .size () + keys.size ());
219277 auto it = out_data.begin ();
220278
221279 for (size_t idx = 0 ; idx < kernels.size (); ++idx) {
@@ -231,8 +289,8 @@ Result<Datum> GroupBy(const std::vector<Datum>& arguments, const std::vector<Dat
231289 *it++ = key.array ();
232290 }
233291
234- int64_t length = out_data[0 ]->length ;
235- return ArrayData::Make (struct_ (std::move (out_fields)), length ,
292+ const int64_t out_length = out_data[0 ]->length ;
293+ return ArrayData::Make (struct_ (std::move (out_fields)), out_length ,
236294 {/* null_bitmap=*/ nullptr }, std::move (out_data),
237295 /* null_count=*/ 0 );
238296}
0 commit comments