Skip to content

Commit eccdd48

Browse files
cyb70289pitrou
authored andcommitted
ARROW-10325: [C++][Compute] Refine aggregate kernel registration
Separate Mode and Variance/Stddev kernels registration from basic aggregation kernels. Closes apache#8523 from cyb70289/agg-refine Authored-by: Yibo Cai <yibo.cai@arm.com> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent a9988ee commit eccdd48

7 files changed

Lines changed: 59 additions & 43 deletions

File tree

cpp/src/arrow/compute/kernels/aggregate_basic.cc

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424

2525
namespace arrow {
2626
namespace compute {
27-
namespace aggregate {
27+
28+
namespace {
2829

2930
void AggregateConsume(KernelContext* ctx, const ExecBatch& batch) {
3031
checked_cast<ScalarAggregator*>(ctx->state())->Consume(ctx, batch);
@@ -38,6 +39,19 @@ void AggregateFinalize(KernelContext* ctx, Datum* out) {
3839
checked_cast<ScalarAggregator*>(ctx->state())->Finalize(ctx, out);
3940
}
4041

42+
} // namespace
43+
44+
void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
45+
ScalarAggregateFunction* func, SimdLevel::type simd_level) {
46+
ScalarAggregateKernel kernel(std::move(sig), init, AggregateConsume, AggregateMerge,
47+
AggregateFinalize);
48+
// Set the simd level
49+
kernel.simd_level = simd_level;
50+
DCHECK_OK(func->AddKernel(kernel));
51+
}
52+
53+
namespace aggregate {
54+
4155
// ----------------------------------------------------------------------
4256
// Count implementation
4357

@@ -137,15 +151,6 @@ std::unique_ptr<KernelState> MinMaxInit(KernelContext* ctx, const KernelInitArgs
137151
return visitor.Create();
138152
}
139153

140-
void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
141-
ScalarAggregateFunction* func, SimdLevel::type simd_level) {
142-
ScalarAggregateKernel kernel(std::move(sig), init, AggregateConsume, AggregateMerge,
143-
AggregateFinalize);
144-
// Set the simd level
145-
kernel.simd_level = simd_level;
146-
DCHECK_OK(func->AddKernel(kernel));
147-
}
148-
149154
void AddBasicAggKernels(KernelInit init,
150155
const std::vector<std::shared_ptr<DataType>>& types,
151156
std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
@@ -202,8 +207,8 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
202207

203208
// Takes any array input, outputs int64 scalar
204209
InputType any_array(ValueDescr::ARRAY);
205-
aggregate::AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())),
206-
aggregate::CountInit, func.get());
210+
AddAggKernel(KernelSignature::Make({any_array}, ValueDescr::Scalar(int64())),
211+
aggregate::CountInit, func.get());
207212
DCHECK_OK(registry->AddFunction(std::move(func)));
208213

209214
func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), &sum_doc);
@@ -263,10 +268,6 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
263268
#endif
264269

265270
DCHECK_OK(registry->AddFunction(std::move(func)));
266-
267-
DCHECK_OK(registry->AddFunction(aggregate::AddModeAggKernels()));
268-
DCHECK_OK(registry->AddFunction(aggregate::AddStddevAggKernels()));
269-
DCHECK_OK(registry->AddFunction(aggregate::AddVarianceAggKernels()));
270271
}
271272

272273
} // namespace internal

cpp/src/arrow/compute/kernels/aggregate_basic_internal.h

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,6 @@ namespace arrow {
2929
namespace compute {
3030
namespace aggregate {
3131

32-
struct ScalarAggregator : public KernelState {
33-
virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0;
34-
virtual void MergeFrom(KernelContext* ctx, KernelState&& src) = 0;
35-
virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
36-
};
37-
38-
void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
39-
ScalarAggregateFunction* func,
40-
SimdLevel::type simd_level = SimdLevel::NONE);
41-
4232
void AddBasicAggKernels(KernelInit init,
4333
const std::vector<std::shared_ptr<DataType>>& types,
4434
std::shared_ptr<DataType> out_ty, ScalarAggregateFunction* func,
@@ -58,10 +48,6 @@ void AddSumAvx512AggKernels(ScalarAggregateFunction* func);
5848
void AddMeanAvx512AggKernels(ScalarAggregateFunction* func);
5949
void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func);
6050

61-
std::shared_ptr<ScalarAggregateFunction> AddModeAggKernels();
62-
std::shared_ptr<ScalarAggregateFunction> AddStddevAggKernels();
63-
std::shared_ptr<ScalarAggregateFunction> AddVarianceAggKernels();
64-
6551
// ----------------------------------------------------------------------
6652
// Sum implementation
6753

cpp/src/arrow/compute/kernels/aggregate_internal.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,5 +47,15 @@ struct FindAccumulatorType<I, enable_if_floating_point<I>> {
4747
using Type = DoubleType;
4848
};
4949

50+
struct ScalarAggregator : public KernelState {
51+
virtual void Consume(KernelContext* ctx, const ExecBatch& batch) = 0;
52+
virtual void MergeFrom(KernelContext* ctx, KernelState&& src) = 0;
53+
virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
54+
};
55+
56+
void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
57+
ScalarAggregateFunction* func,
58+
SimdLevel::type simd_level = SimdLevel::NONE);
59+
5060
} // namespace compute
5161
} // namespace arrow

cpp/src/arrow/compute/kernels/aggregate_mode.cc

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
#include <cmath>
1919
#include <unordered_map>
2020

21-
#include "arrow/compute/kernels/aggregate_basic_internal.h"
21+
#include "arrow/compute/api_aggregate.h"
22+
#include "arrow/compute/kernels/aggregate_internal.h"
23+
#include "arrow/compute/kernels/common.h"
2224

2325
namespace arrow {
2426
namespace compute {
25-
namespace aggregate {
27+
namespace internal {
2628

2729
namespace {
2830

@@ -277,16 +279,20 @@ const FunctionDoc mode_doc{
277279
"null is returned."),
278280
{"array"}};
279281

280-
} // namespace
281-
282282
std::shared_ptr<ScalarAggregateFunction> AddModeAggKernels() {
283283
auto func =
284284
std::make_shared<ScalarAggregateFunction>("mode", Arity::Unary(), &mode_doc);
285285
AddModeKernels(ModeInit, {boolean()}, func.get());
286-
AddModeKernels(ModeInit, internal::NumericTypes(), func.get());
286+
AddModeKernels(ModeInit, NumericTypes(), func.get());
287287
return func;
288288
}
289289

290-
} // namespace aggregate
290+
} // namespace
291+
292+
void RegisterScalarAggregateMode(FunctionRegistry* registry) {
293+
DCHECK_OK(registry->AddFunction(AddModeAggKernels()));
294+
}
295+
296+
} // namespace internal
291297
} // namespace compute
292298
} // namespace arrow

cpp/src/arrow/compute/kernels/aggregate_var_std.cc

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,16 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
#include "arrow/compute/kernels/aggregate_basic_internal.h"
18+
#include <cmath>
19+
20+
#include "arrow/compute/api_aggregate.h"
21+
#include "arrow/compute/kernels/aggregate_internal.h"
22+
#include "arrow/compute/kernels/common.h"
1923
#include "arrow/util/int128_internal.h"
2024

2125
namespace arrow {
2226
namespace compute {
23-
namespace aggregate {
27+
namespace internal {
2428

2529
namespace {
2630

@@ -252,24 +256,29 @@ const FunctionDoc variance_doc{
252256
{"array"},
253257
"VarianceOptions"};
254258

255-
} // namespace
256-
257259
std::shared_ptr<ScalarAggregateFunction> AddStddevAggKernels() {
258260
static auto default_std_options = VarianceOptions::Defaults();
259261
auto func = std::make_shared<ScalarAggregateFunction>(
260262
"stddev", Arity::Unary(), &stddev_doc, &default_std_options);
261-
AddVarStdKernels(StddevInit, internal::NumericTypes(), func.get());
263+
AddVarStdKernels(StddevInit, NumericTypes(), func.get());
262264
return func;
263265
}
264266

265267
std::shared_ptr<ScalarAggregateFunction> AddVarianceAggKernels() {
266268
static auto default_var_options = VarianceOptions::Defaults();
267269
auto func = std::make_shared<ScalarAggregateFunction>(
268270
"variance", Arity::Unary(), &variance_doc, &default_var_options);
269-
AddVarStdKernels(VarianceInit, internal::NumericTypes(), func.get());
271+
AddVarStdKernels(VarianceInit, NumericTypes(), func.get());
270272
return func;
271273
}
272274

273-
} // namespace aggregate
275+
} // namespace
276+
277+
void RegisterScalarAggregateVariance(FunctionRegistry* registry) {
278+
DCHECK_OK(registry->AddFunction(AddVarianceAggKernels()));
279+
DCHECK_OK(registry->AddFunction(AddStddevAggKernels()));
280+
}
281+
282+
} // namespace internal
274283
} // namespace compute
275284
} // namespace arrow

cpp/src/arrow/compute/registry.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ static std::unique_ptr<FunctionRegistry> CreateBuiltInRegistry() {
128128

129129
// Aggregate functions
130130
RegisterScalarAggregateBasic(registry.get());
131+
RegisterScalarAggregateMode(registry.get());
132+
RegisterScalarAggregateVariance(registry.get());
131133

132134
// Vector functions
133135
RegisterVectorHash(registry.get());

cpp/src/arrow/compute/registry_internal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ void RegisterVectorSort(FunctionRegistry* registry);
4343

4444
// Aggregate functions
4545
void RegisterScalarAggregateBasic(FunctionRegistry* registry);
46+
void RegisterScalarAggregateMode(FunctionRegistry* registry);
47+
void RegisterScalarAggregateVariance(FunctionRegistry* registry);
4648

4749
} // namespace internal
4850
} // namespace compute

0 commit comments

Comments
 (0)