Skip to content

Commit 1accdc9

Browse files
fsaintjacquesbkietz
authored andcommitted
ARROW-7210: [C++][R] Allow Numeric <-> Temporal Scalar casts
The end goal of this PR is to minimally support filtering in R on temporal columns. - Refactor Scalar classes - Follow the same hierarchy as the type hierarchy. - Provide 2 constructor for each Scalar, one for constructing a Null value and one providing an explicit value. The is_valid flag is not required by caller anymore. - All scalar types provide a non failing null constructor, i.e MakeNullScalar will not fail. - Ensure that we can cast from a NumericScalar to a TemporalScalar - Add R unit test Closes apache#5921 from fsaintjacques/ARROW-7210-scalar-cast-time and squashes the following commits: 6487568 <François Saint-Jacques> Rebase 3db4b9c <François Saint-Jacques> Review 4b69655 <François Saint-Jacques> Try MSCV fix ea1eb9f <François Saint-Jacques> ARROW-7210: Allow Numeric <-> Temporal casts Authored-by: François Saint-Jacques <fsaintjacques@gmail.com> Signed-off-by: Benjamin Kietzman <bengilgit@gmail.com>
1 parent 68903ac commit 1accdc9

23 files changed

Lines changed: 600 additions & 401 deletions

cpp/src/arrow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ set(ARROW_SRCS
136136
util/string_builder.cc
137137
util/task_group.cc
138138
util/thread_pool.cc
139+
util/time.cc
139140
util/trie.cc
140141
util/uri.cc
141142
util/utf8.cc

cpp/src/arrow/compare.cc

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -830,33 +830,31 @@ class ScalarEqualsVisitor {
830830
return Status::OK();
831831
}
832832

833+
Status Visit(const BooleanScalar& left) {
834+
const auto& right = checked_cast<const BooleanScalar&>(right_);
835+
result_ = left.value == right.value;
836+
return Status::OK();
837+
}
838+
833839
template <typename T>
834-
typename std::enable_if<std::is_base_of<internal::PrimitiveScalar, T>::value,
835-
Status>::type
840+
typename std::enable_if<
841+
std::is_base_of<internal::PrimitiveScalar<typename T::TypeClass>, T>::value ||
842+
std::is_base_of<TemporalScalar<typename T::TypeClass>, T>::value,
843+
Status>::type
836844
Visit(const T& left_) {
837845
const auto& right = checked_cast<const T&>(right_);
838846
result_ = right.value == left_.value;
839847
return Status::OK();
840848
}
841849

842850
template <typename T>
843-
typename std::enable_if<std::is_base_of<BinaryScalar, T>::value, Status>::type Visit(
844-
const T& left_) {
845-
const auto& left = checked_cast<const BinaryScalar&>(left_);
851+
typename std::enable_if<std::is_base_of<BaseBinaryScalar, T>::value, Status>::type
852+
Visit(const T& left) {
846853
const auto& right = checked_cast<const BinaryScalar&>(right_);
847854
result_ = internal::SharedPtrEquals(left.value, right.value);
848855
return Status::OK();
849856
}
850857

851-
template <typename T>
852-
typename std::enable_if<std::is_base_of<LargeBinaryScalar, T>::value, Status>::type
853-
Visit(const T& left_) {
854-
const auto& left = checked_cast<const LargeBinaryScalar&>(left_);
855-
const auto& right = checked_cast<const LargeBinaryScalar&>(right_);
856-
result_ = internal::SharedPtrEquals(left.value, right.value);
857-
return Status::OK();
858-
}
859-
860858
Status Visit(const Decimal128Scalar& left) {
861859
const auto& right = checked_cast<const Decimal128Scalar&>(right_);
862860
result_ = left.value == right.value;

cpp/src/arrow/compute/kernels/aggregate_test.cc

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ static Datum NaiveSum(const Array& array) {
8383
auto result = NaiveSumPartial<ArrowType>(array);
8484
bool is_valid = result.second > 0;
8585

86-
return Datum(std::make_shared<SumScalarType>(result.first, is_valid));
86+
if (!is_valid) return Datum(std::make_shared<SumScalarType>());
87+
return Datum(std::make_shared<SumScalarType>(result.first));
8788
}
8889

8990
template <typename ArrowType>
@@ -115,11 +116,9 @@ TYPED_TEST(TestNumericSumKernel, SimpleSum) {
115116
using ScalarType = typename TypeTraits<SumType>::ScalarType;
116117
using T = typename TypeParam::c_type;
117118

118-
ValidateSum<TypeParam>(&this->ctx_, "[]",
119-
Datum(std::make_shared<ScalarType>(0, false)));
119+
ValidateSum<TypeParam>(&this->ctx_, "[]", Datum(std::make_shared<ScalarType>()));
120120

121-
ValidateSum<TypeParam>(&this->ctx_, "[null]",
122-
Datum(std::make_shared<ScalarType>(0, false)));
121+
ValidateSum<TypeParam>(&this->ctx_, "[null]", Datum(std::make_shared<ScalarType>()));
123122

124123
ValidateSum<TypeParam>(&this->ctx_, "[0, 1, 2, 3, 4, 5]",
125124
Datum(std::make_shared<ScalarType>(static_cast<T>(5 * 6 / 2))));
@@ -180,7 +179,8 @@ static Datum NaiveMean(const Array& array) {
180179
static_cast<double>(result.second ? result.second : 1UL);
181180
const bool is_valid = result.second > 0;
182181

183-
return Datum(std::make_shared<MeanScalarType>(mean, is_valid));
182+
if (!is_valid) return Datum(std::make_shared<MeanScalarType>());
183+
return Datum(std::make_shared<MeanScalarType>(mean));
184184
}
185185

186186
template <typename ArrowType>
@@ -210,11 +210,9 @@ TYPED_TEST_CASE(TestMeanKernelNumeric, NumericArrowTypes);
210210
TYPED_TEST(TestMeanKernelNumeric, SimpleMean) {
211211
using ScalarType = typename TypeTraits<DoubleType>::ScalarType;
212212

213-
ValidateMean<TypeParam>(&this->ctx_, "[]",
214-
Datum(std::make_shared<ScalarType>(0.0, false)));
213+
ValidateMean<TypeParam>(&this->ctx_, "[]", Datum(std::make_shared<ScalarType>()));
215214

216-
ValidateMean<TypeParam>(&this->ctx_, "[null]",
217-
Datum(std::make_shared<ScalarType>(0.0, false)));
215+
ValidateMean<TypeParam>(&this->ctx_, "[null]", Datum(std::make_shared<ScalarType>()));
218216

219217
ValidateMean<TypeParam>(&this->ctx_, "[1, null, 1]",
220218
Datum(std::make_shared<ScalarType>(1.0)));

cpp/src/arrow/compute/kernels/cast.cc

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
#include "arrow/util/logging.h"
3838
#include "arrow/util/macros.h"
3939
#include "arrow/util/parsing.h" // IWYU pragma: keep
40+
#include "arrow/util/time.h"
4041
#include "arrow/util/utf8.h"
4142
#include "arrow/visitor_inline.h"
4243

@@ -407,16 +408,17 @@ struct CastFunctor<
407408
// From one timestamp to another
408409

409410
template <typename in_type, typename out_type>
410-
void ShiftTime(FunctionContext* ctx, const CastOptions& options, const bool is_multiply,
411-
const int64_t factor, const ArrayData& input, ArrayData* output) {
411+
void ShiftTime(FunctionContext* ctx, const CastOptions& options,
412+
const util::DivideOrMultiply factor_op, const int64_t factor,
413+
const ArrayData& input, ArrayData* output) {
412414
const in_type* in_data = input.GetValues<in_type>(1);
413415
auto out_data = output->GetMutableValues<out_type>(1);
414416

415417
if (factor == 1) {
416418
for (int64_t i = 0; i < input.length; i++) {
417419
out_data[i] = static_cast<out_type>(in_data[i]);
418420
}
419-
} else if (is_multiply) {
421+
} else if (factor_op == util::MULTIPLY) {
420422
if (options.allow_time_overflow) {
421423
for (int64_t i = 0; i < input.length; i++) {
422424
out_data[i] = static_cast<out_type>(in_data[i] * factor);
@@ -488,18 +490,6 @@ void ShiftTime(FunctionContext* ctx, const CastOptions& options, const bool is_m
488490
}
489491
}
490492

491-
namespace {
492-
493-
// {is_multiply, factor}
494-
const std::pair<bool, int64_t> kTimeConversionTable[4][4] = {
495-
{{true, 1}, {true, 1000}, {true, 1000000}, {true, 1000000000L}}, // SECOND
496-
{{false, 1000}, {true, 1}, {true, 1000}, {true, 1000000}}, // MILLI
497-
{{false, 1000000}, {false, 1000}, {true, 1}, {true, 1000}}, // MICRO
498-
{{false, 1000000000L}, {false, 1000000}, {false, 1000}, {true, 1}}, // NANO
499-
};
500-
501-
} // namespace
502-
503493
// <TimestampType, TimestampType> and <DurationType, DurationType>
504494
template <typename O, typename I>
505495
struct CastFunctor<
@@ -517,10 +507,8 @@ struct CastFunctor<
517507
return;
518508
}
519509

520-
std::pair<bool, int64_t> conversion =
521-
kTimeConversionTable[static_cast<int>(in_type.unit())]
522-
[static_cast<int>(out_type.unit())];
523-
510+
auto conversion = util::kTimestampConversionTable[static_cast<int>(in_type.unit())]
511+
[static_cast<int>(out_type.unit())];
524512
ShiftTime<int64_t, int64_t>(ctx, options, conversion.first, conversion.second, input,
525513
output);
526514
}
@@ -540,7 +528,7 @@ struct CastFunctor<Date32Type, TimestampType> {
540528
};
541529

542530
const int64_t factor = kTimestampToDateFactors[static_cast<int>(in_type.unit())];
543-
ShiftTime<int64_t, int32_t>(ctx, options, false, factor, input, output);
531+
ShiftTime<int64_t, int32_t>(ctx, options, util::DIVIDE, factor, input, output);
544532
}
545533
};
546534

@@ -550,10 +538,8 @@ struct CastFunctor<Date64Type, TimestampType> {
550538
const ArrayData& input, ArrayData* output) {
551539
const auto& in_type = checked_cast<const TimestampType&>(*input.type);
552540

553-
std::pair<bool, int64_t> conversion =
554-
kTimeConversionTable[static_cast<int>(in_type.unit())]
555-
[static_cast<int>(TimeUnit::MILLI)];
556-
541+
auto conversion = util::kTimestampConversionTable[static_cast<int>(in_type.unit())]
542+
[static_cast<int>(TimeUnit::MILLI)];
557543
ShiftTime<int64_t, int64_t>(ctx, options, conversion.first, conversion.second, input,
558544
output);
559545
if (!ctx->status().ok()) {
@@ -611,9 +597,8 @@ struct CastFunctor<O, I, enable_if_t<is_time_type<I>::value && is_time_type<O>::
611597
return;
612598
}
613599

614-
std::pair<bool, int64_t> conversion =
615-
kTimeConversionTable[static_cast<int>(in_type.unit())]
616-
[static_cast<int>(out_type.unit())];
600+
auto conversion = util::kTimestampConversionTable[static_cast<int>(in_type.unit())]
601+
[static_cast<int>(out_type.unit())];
617602

618603
ShiftTime<in_t, out_t>(ctx, options, conversion.first, conversion.second, input,
619604
output);
@@ -627,15 +612,17 @@ template <>
627612
struct CastFunctor<Date64Type, Date32Type> {
628613
void operator()(FunctionContext* ctx, const CastOptions& options,
629614
const ArrayData& input, ArrayData* output) {
630-
ShiftTime<int32_t, int64_t>(ctx, options, true, kMillisecondsInDay, input, output);
615+
ShiftTime<int32_t, int64_t>(ctx, options, util::MULTIPLY, kMillisecondsInDay, input,
616+
output);
631617
}
632618
};
633619

634620
template <>
635621
struct CastFunctor<Date32Type, Date64Type> {
636622
void operator()(FunctionContext* ctx, const CastOptions& options,
637623
const ArrayData& input, ArrayData* output) {
638-
ShiftTime<int64_t, int32_t>(ctx, options, false, kMillisecondsInDay, input, output);
624+
ShiftTime<int64_t, int32_t>(ctx, options, util::DIVIDE, kMillisecondsInDay, input,
625+
output);
639626
}
640627
};
641628

cpp/src/arrow/compute/kernels/compare_test.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,8 @@ TYPED_TEST(TestNumericCompareKernel, SimpleCompareScalarArray) {
324324
TYPED_TEST(TestNumericCompareKernel, TestNullScalar) {
325325
/* Ensure that null scalar broadcast to all null results. */
326326
using ScalarType = typename TypeTraits<TypeParam>::ScalarType;
327-
using CType = typename TypeTraits<TypeParam>::CType;
328327

329-
Datum null(std::make_shared<ScalarType>(CType(0), false));
328+
Datum null(std::make_shared<ScalarType>());
330329
EXPECT_FALSE(null.scalar()->is_valid);
331330

332331
CompareOptions eq(CompareOperator::EQUAL);

cpp/src/arrow/compute/kernels/mean.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ struct MeanState {
4747
const double divisor = static_cast<double>(is_valid ? count : 1UL);
4848
const double mean = static_cast<double>(sum) / divisor;
4949

50-
return std::make_shared<ScalarType>(mean, is_valid);
50+
if (!is_valid) return std::make_shared<ScalarType>();
51+
return std::make_shared<ScalarType>(mean);
5152
}
5253

5354
static std::shared_ptr<DataType> out_type() {

cpp/src/arrow/compute/kernels/sum.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,11 @@ struct SumState {
4242
std::shared_ptr<Scalar> Finalize() const {
4343
using ScalarType = typename TypeTraits<SumType>::ScalarType;
4444

45-
auto boxed = std::make_shared<ScalarType>(this->sum);
4645
if (count == 0) {
47-
// TODO(wesm): Currently null, but fix this
48-
boxed->is_valid = false;
46+
return std::make_shared<ScalarType>();
4947
}
5048

51-
return std::move(boxed);
49+
return MakeScalar(sum);
5250
}
5351

5452
static std::shared_ptr<DataType> out_type() {

cpp/src/arrow/compute/test_util.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ struct DatumEqual<Type, enable_if_integer<Type>> {
9797
if (lhs.kind() == Datum::SCALAR) {
9898
auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
9999
auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
100-
ASSERT_EQ(left->is_valid, right->is_valid);
101-
ASSERT_EQ(left->type->id(), right->type->id());
102-
ASSERT_EQ(left->value, right->value);
100+
ASSERT_EQ(*left, *right);
103101
}
104102
}
105103
};

cpp/src/arrow/dataset/file_parquet.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,7 @@ static ExpressionPtr ColumnChunkStatisticsAsExpression(
298298

299299
// Optimize for corner case where all values are nulls
300300
if (statistics->num_values() == statistics->null_count()) {
301-
auto null_scalar = MakeNullScalar(field->type());
302-
if (null_scalar.ok()) {
303-
// MakeNullScalar can fail for some nested/repeated types.
304-
return scalar(true);
305-
}
306-
307-
return equal(field_expr, scalar(*null_scalar));
301+
return equal(field_expr, scalar(MakeNullScalar(field->type())));
308302
}
309303

310304
std::shared_ptr<Scalar> min, max;

cpp/src/arrow/dataset/filter.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ bool ComparisonExpression::Equals(const Expression& other) const {
704704

705705
bool ScalarExpression::Equals(const Expression& other) const {
706706
return other.type() == ExpressionType::SCALAR &&
707-
value_->Equals(checked_cast<const ScalarExpression&>(other).value_);
707+
value_->Equals(*checked_cast<const ScalarExpression&>(other).value_);
708708
}
709709

710710
bool FieldExpression::Equals(const Expression& other) const {
@@ -876,6 +876,14 @@ Result<std::shared_ptr<DataType>> CastExpression::Validate(const Schema& schema)
876876
ARROW_ASSIGN_OR_RAISE(to_type, like->Validate(schema));
877877
}
878878

879+
// Until expressions carry a shape, detect scalar and try to cast it. Works
880+
// if the operand is a scalar leaf.
881+
if (operand_->type() == ExpressionType::SCALAR) {
882+
auto scalar_expr = checked_pointer_cast<ScalarExpression>(operand_);
883+
ARROW_ASSIGN_OR_RAISE(std::ignore, scalar_expr->value()->CastTo(to_type));
884+
return to_type;
885+
}
886+
879887
std::unique_ptr<compute::UnaryKernel> kernel;
880888
RETURN_NOT_OK(GetCastFunction(*operand_type, to_type, options_, &kernel));
881889
return to_type;
@@ -1031,8 +1039,7 @@ std::shared_ptr<ExpressionEvaluator> ExpressionEvaluator::Null() {
10311039
Result<Datum> Evaluate(const Expression& expr, const RecordBatch& batch,
10321040
MemoryPool* pool) const override {
10331041
ARROW_ASSIGN_OR_RAISE(auto type, expr.Validate(*batch.schema()));
1034-
ARROW_ASSIGN_OR_RAISE(auto out, MakeNullScalar(type));
1035-
return Datum(std::move(out));
1042+
return Datum(MakeNullScalar(type));
10361043
}
10371044

10381045
Result<std::shared_ptr<RecordBatch>> Filter(const Datum& selection,
@@ -1205,7 +1212,7 @@ Result<std::shared_ptr<RecordBatch>> TreeEvaluator::Filter(
12051212
selection.kind(), " of type ", *selection.type());
12061213
}
12071214

1208-
if (BooleanScalar(true).Equals(selection.scalar())) {
1215+
if (BooleanScalar(true).Equals(*selection.scalar())) {
12091216
return batch;
12101217
}
12111218

0 commit comments

Comments
 (0)