Skip to content

Commit 2c405da

Browse files
lidavidm and pitrou committed
ARROW-13882: [C++] Improve min_max/hash_min_max type support
This adds support for non-nested types to both min_max and hash_min_max, except for MonthDayNanoInterval (which is not really sortable). hash_min_max additionally lacks support for binary/string-like types, as they require a different approach (I will file a follow-up).

Closes apache#11111 from lidavidm/arrow-13882

Lead-authored-by: David Li <li.davidm96@gmail.com>
Co-authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 1919c33 commit 2c405da

25 files changed

Lines changed: 791 additions & 354 deletions

cpp/src/arrow/array/array_binary_test.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,6 @@ namespace arrow {
4343

4444
using internal::checked_cast;
4545

46-
using StringTypes =
47-
::testing::Types<StringType, LargeStringType, BinaryType, LargeBinaryType>;
48-
49-
using UTF8Types = ::testing::Types<StringType, LargeStringType>;
50-
5146
// ----------------------------------------------------------------------
5247
// String / Binary tests
5348

@@ -329,7 +324,7 @@ class TestStringArray : public ::testing::Test {
329324
std::shared_ptr<ArrayType> strings_;
330325
};
331326

332-
TYPED_TEST_SUITE(TestStringArray, StringTypes);
327+
TYPED_TEST_SUITE(TestStringArray, BinaryArrowTypes);
333328

334329
TYPED_TEST(TestStringArray, TestArrayBasics) { this->TestArrayBasics(); }
335330

@@ -386,7 +381,7 @@ class TestUTF8Array : public ::testing::Test {
386381
}
387382
};
388383

389-
TYPED_TEST_SUITE(TestUTF8Array, UTF8Types);
384+
TYPED_TEST_SUITE(TestUTF8Array, StringArrowTypes);
390385

391386
TYPED_TEST(TestUTF8Array, TestValidateUTF8) { this->TestValidateUTF8(); }
392387

@@ -666,7 +661,7 @@ class TestStringBuilder : public TestBuilder {
666661
std::shared_ptr<ArrayType> result_;
667662
};
668663

669-
TYPED_TEST_SUITE(TestStringBuilder, StringTypes);
664+
TYPED_TEST_SUITE(TestStringBuilder, BinaryArrowTypes);
670665

671666
TYPED_TEST(TestStringBuilder, TestScalarAppend) { this->TestScalarAppend(); }
672667

@@ -896,7 +891,7 @@ class TestBinaryDataVisitor : public ::testing::Test {
896891
std::shared_ptr<DataType> type_;
897892
};
898893

899-
TYPED_TEST_SUITE(TestBinaryDataVisitor, StringTypes);
894+
TYPED_TEST_SUITE(TestBinaryDataVisitor, BinaryArrowTypes);
900895

901896
TYPED_TEST(TestBinaryDataVisitor, Basics) { this->TestBasics(); }
902897

cpp/src/arrow/array/concatenate_test.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,7 @@ class PrimitiveConcatenateTest : public ConcatenateTest {
133133
public:
134134
};
135135

136-
using PrimitiveTypes =
137-
::testing::Types<BooleanType, Int8Type, UInt8Type, Int16Type, UInt16Type, Int32Type,
138-
UInt32Type, Int64Type, UInt64Type, FloatType, DoubleType>;
139-
TYPED_TEST_SUITE(PrimitiveConcatenateTest, PrimitiveTypes);
136+
TYPED_TEST_SUITE(PrimitiveConcatenateTest, PrimitiveArrowTypes);
140137

141138
TYPED_TEST(PrimitiveConcatenateTest, Primitives) {
142139
this->Check([this](int64_t size, double null_probability, std::shared_ptr<Array>* out) {

cpp/src/arrow/compute/kernels/aggregate_basic.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -607,10 +607,7 @@ void AddMinMaxKernels(KernelInit init,
607607
const std::vector<std::shared_ptr<DataType>>& types,
608608
ScalarAggregateFunction* func, SimdLevel::type simd_level) {
609609
for (const auto& ty : types) {
610-
// any[T] -> scalar[struct<min: T, max: T>]
611-
auto out_ty = struct_({field("min", ty), field("max", ty)});
612-
auto sig = KernelSignature::Make({InputType(ty->id())}, ValueDescr::Scalar(out_ty));
613-
AddAggKernel(std::move(sig), init, func, simd_level);
610+
AddMinMaxKernel(init, ty, func, simd_level);
614611
}
615612
}
616613

@@ -764,8 +761,12 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
764761

765762
func = std::make_shared<ScalarAggregateFunction>(
766763
"min_max", Arity::Unary(), &min_max_doc, &default_scalar_aggregate_options);
767-
aggregate::AddMinMaxKernels(aggregate::MinMaxInit, {boolean()}, func.get());
764+
aggregate::AddMinMaxKernels(aggregate::MinMaxInit, {null(), boolean()}, func.get());
768765
aggregate::AddMinMaxKernels(aggregate::MinMaxInit, NumericTypes(), func.get());
766+
aggregate::AddMinMaxKernels(aggregate::MinMaxInit, TemporalTypes(), func.get());
767+
aggregate::AddMinMaxKernels(aggregate::MinMaxInit, BaseBinaryTypes(), func.get());
768+
aggregate::AddMinMaxKernel(aggregate::MinMaxInit, Type::FIXED_SIZE_BINARY, func.get());
769+
aggregate::AddMinMaxKernel(aggregate::MinMaxInit, Type::INTERVAL_MONTHS, func.get());
769770
aggregate::AddMinMaxKernel(aggregate::MinMaxInit, Type::DECIMAL128, func.get());
770771
aggregate::AddMinMaxKernel(aggregate::MinMaxInit, Type::DECIMAL256, func.get());
771772
// Add the SIMD variants for min max

cpp/src/arrow/compute/kernels/aggregate_basic_avx2.cc

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,30 +55,32 @@ Result<std::unique_ptr<KernelState>> MeanInitAvx2(KernelContext* ctx,
5555

5656
Result<std::unique_ptr<KernelState>> MinMaxInitAvx2(KernelContext* ctx,
5757
const KernelInitArgs& args) {
58+
ARROW_ASSIGN_OR_RAISE(auto out_type,
59+
args.kernel->signature->out_type().Resolve(ctx, args.inputs));
5860
MinMaxInitState<SimdLevel::AVX2> visitor(
59-
ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(),
61+
ctx, *args.inputs[0].type, std::move(out_type.type),
6062
static_cast<const ScalarAggregateOptions&>(*args.options));
6163
return visitor.Create();
6264
}
6365

6466
void AddSumAvx2AggKernels(ScalarAggregateFunction* func) {
65-
AddBasicAggKernels(SumInitAvx2, internal::SignedIntTypes(), int64(), func,
66-
SimdLevel::AVX2);
67-
AddBasicAggKernels(SumInitAvx2, internal::UnsignedIntTypes(), uint64(), func,
68-
SimdLevel::AVX2);
69-
AddBasicAggKernels(SumInitAvx2, internal::FloatingPointTypes(), float64(), func,
70-
SimdLevel::AVX2);
67+
AddBasicAggKernels(SumInitAvx2, SignedIntTypes(), int64(), func, SimdLevel::AVX2);
68+
AddBasicAggKernels(SumInitAvx2, UnsignedIntTypes(), uint64(), func, SimdLevel::AVX2);
69+
AddBasicAggKernels(SumInitAvx2, FloatingPointTypes(), float64(), func, SimdLevel::AVX2);
7170
}
7271

7372
void AddMeanAvx2AggKernels(ScalarAggregateFunction* func) {
74-
AddBasicAggKernels(MeanInitAvx2, internal::NumericTypes(), float64(), func,
75-
SimdLevel::AVX2);
73+
AddBasicAggKernels(MeanInitAvx2, NumericTypes(), float64(), func, SimdLevel::AVX2);
7674
}
7775

7876
void AddMinMaxAvx2AggKernels(ScalarAggregateFunction* func) {
7977
// Enable int types for AVX2 variants.
8078
// No auto vectorize for float/double as it use fmin/fmax which has NaN handling.
81-
AddMinMaxKernels(MinMaxInitAvx2, internal::IntTypes(), func, SimdLevel::AVX2);
79+
AddMinMaxKernels(MinMaxInitAvx2, IntTypes(), func, SimdLevel::AVX2);
80+
AddMinMaxKernels(MinMaxInitAvx2, TemporalTypes(), func, SimdLevel::AVX2);
81+
AddMinMaxKernels(MinMaxInitAvx2, BaseBinaryTypes(), func, SimdLevel::AVX2);
82+
AddMinMaxKernel(MinMaxInitAvx2, Type::FIXED_SIZE_BINARY, func, SimdLevel::AVX2);
83+
AddMinMaxKernel(MinMaxInitAvx2, Type::INTERVAL_MONTHS, func, SimdLevel::AVX2);
8284
}
8385

8486
} // namespace aggregate

cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,30 +55,35 @@ Result<std::unique_ptr<KernelState>> MeanInitAvx512(KernelContext* ctx,
5555

5656
Result<std::unique_ptr<KernelState>> MinMaxInitAvx512(KernelContext* ctx,
5757
const KernelInitArgs& args) {
58+
ARROW_ASSIGN_OR_RAISE(auto out_type,
59+
args.kernel->signature->out_type().Resolve(ctx, args.inputs));
5860
MinMaxInitState<SimdLevel::AVX512> visitor(
59-
ctx, *args.inputs[0].type, args.kernel->signature->out_type().type(),
61+
ctx, *args.inputs[0].type, std::move(out_type.type),
6062
static_cast<const ScalarAggregateOptions&>(*args.options));
6163
return visitor.Create();
6264
}
6365

6466
void AddSumAvx512AggKernels(ScalarAggregateFunction* func) {
65-
AddBasicAggKernels(SumInitAvx512, internal::SignedIntTypes(), int64(), func,
67+
AddBasicAggKernels(SumInitAvx512, SignedIntTypes(), int64(), func, SimdLevel::AVX512);
68+
AddBasicAggKernels(SumInitAvx512, UnsignedIntTypes(), uint64(), func,
6669
SimdLevel::AVX512);
67-
AddBasicAggKernels(SumInitAvx512, internal::UnsignedIntTypes(), uint64(), func,
68-
SimdLevel::AVX512);
69-
AddBasicAggKernels(SumInitAvx512, internal::FloatingPointTypes(), float64(), func,
70+
AddBasicAggKernels(SumInitAvx512, FloatingPointTypes(), float64(), func,
7071
SimdLevel::AVX512);
7172
}
7273

7374
void AddMeanAvx512AggKernels(ScalarAggregateFunction* func) {
74-
aggregate::AddBasicAggKernels(MeanInitAvx512, internal::NumericTypes(), float64(), func,
75+
aggregate::AddBasicAggKernels(MeanInitAvx512, NumericTypes(), float64(), func,
7576
SimdLevel::AVX512);
7677
}
7778

7879
/// Register the min_max kernels that are initialised by MinMaxInitAvx512 on `func`.
///
/// \param func the scalar aggregate function that receives the kernels
void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func) {
  // Enable 32/64 int types for avx512 variants, no advantage on 8/16 int.
  AddMinMaxKernels(MinMaxInitAvx512, {int32(), uint32(), int64(), uint64()}, func,
                   SimdLevel::AVX512);
  AddMinMaxKernels(MinMaxInitAvx512, TemporalTypes(), func, SimdLevel::AVX512);
  // Every kernel registered here uses MinMaxInitAvx512, so each one is tagged
  // SimdLevel::AVX512. The binary and fixed-size-binary registrations used to be
  // tagged SimdLevel::AVX2, which did not match their init function.
  AddMinMaxKernels(MinMaxInitAvx512, BaseBinaryTypes(), func, SimdLevel::AVX512);
  AddMinMaxKernel(MinMaxInitAvx512, Type::FIXED_SIZE_BINARY, func, SimdLevel::AVX512);
  AddMinMaxKernel(MinMaxInitAvx512, Type::INTERVAL_MONTHS, func, SimdLevel::AVX512);
}
8388

8489
} // namespace aggregate

0 commit comments

Comments
 (0)