Skip to content

Commit 7fb41bd

Browse files
Add initial support for castFLOAT4 and castFLOAT8 for varbinary
1 parent 71a3265 commit 7fb41bd

5 files changed

Lines changed: 102 additions & 17 deletions

File tree

cpp/src/gandiva/gdv_function_stubs.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "arrow/util/base64.h"
2626
#include "arrow/util/formatting.h"
2727
#include "arrow/util/utf8.h"
28+
#include "arrow/util/double_conversion.h"
2829
#include "arrow/util/value_parsing.h"
2930
#include "gandiva/engine.h"
3031
#include "gandiva/exported_funcs.h"
@@ -765,6 +766,29 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_
765766
*out_len = out_idx;
766767
return out;
767768
}
769+
770+
// Defines gdv_fn_cast<TYPE_NAME>_varbinary(context, in, in_len), which parses
// the bytes in [in, in + in_len) as the textual form of a floating-point number
// and returns the value converted to OUT_TYPE.
//
// Parameters of the generated function:
//   context - execution context handle, used only to report errors.
//   in      - start of the input bytes (not required to be NUL-terminated).
//   in_len  - number of input bytes.
//
// On failure (negative length, empty buffer, or input the converter rejects) an
// error message is set on the context and -1 is returned. Because -1 is also a
// legitimate parse result, callers must check the context for an error rather
// than rely on the return value.
//
// Only block comments are used inside the macro body: a line comment would
// swallow the line-continuation backslash and the lines that follow it.
#define CAST_FLOAT_VARBINARY(OUT_TYPE, TYPE_NAME)                                    \
  GANDIVA_EXPORT                                                                     \
  OUT_TYPE gdv_fn_cast##TYPE_NAME##_varbinary(gdv_int64 context, const char* in,     \
                                              int32_t in_len) {                      \
    using arrow::util::double_conversion::StringToDoubleConverter;                   \
    if (in_len < 0) {                                                                \
      gdv_fn_context_set_error_msg(context, "Buffer length can not be negative");    \
      return -1;                                                                     \
    }                                                                                \
    if (in_len == 0) {                                                               \
      gdv_fn_context_set_error_msg(context, "Buffer can't be empty");                \
      return -1;                                                                     \
    }                                                                                \
    const int flags = StringToDoubleConverter::ALLOW_HEX;                            \
    StringToDoubleConverter converter(flags, /*empty_string_value=*/-1,              \
                                      /*junk_string_value=*/-1, "inf", "NaN");       \
    /* The converter writes the consumed-character count through this pointer, */    \
    /* so it must reference real storage; a null pointer is not accepted.      */    \
    int processed = 0;                                                               \
    const double value = converter.StringToDouble(in, in_len, &processed);           \
    /* No trailing-junk/space flags are set, so a successful parse consumes the */   \
    /* whole buffer; anything else means the input was rejected.                */   \
    if (processed != in_len) {                                                       \
      gdv_fn_context_set_error_msg(context,                                          \
                                   "Failed to cast the buffer to a floating-point "  \
                                   "value");                                         \
      return -1;                                                                     \
    }                                                                                \
    return static_cast<OUT_TYPE>(value);                                             \
  }

CAST_FLOAT_VARBINARY(float, FLOAT4)
CAST_FLOAT_VARBINARY(double, FLOAT8)

#undef CAST_FLOAT_VARBINARY
768792
}
769793

770794
namespace gandiva {
@@ -1020,6 +1044,22 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const {
10201044
engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_utf8", types->double_type(), args,
10211045
reinterpret_cast<void*>(gdv_fn_castFLOAT8_utf8));
10221046

1047+
// gdv_fn_castFLOAT4_varbinary
args = {types->i64_type(), // int64_t context_ptr
1048+
types->i8_ptr_type(), // const char* data
1049+
types->i32_type()}; // int32_t len
1050+
1051+
engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT4_varbinary", types->float_type(),
1052+
args,
1053+
reinterpret_cast<void*>(gdv_fn_castFLOAT4_varbinary));
1054+
1055+
// gdv_fn_castFLOAT8_varbinary
// Same argument list as the FLOAT4 variant; only the return type differs.
args = {types->i64_type(), // int64_t context_ptr
1056+
types->i8_ptr_type(), // const char* data
1057+
types->i32_type()}; // int32_t len
1058+
1059+
engine->AddGlobalMappingForFunc("gdv_fn_castFLOAT8_varbinary", types->double_type(),
1060+
args,
1061+
reinterpret_cast<void*>(gdv_fn_castFLOAT8_varbinary));
1062+
10231063
// gdv_fn_castVARCHAR_int32_int64
10241064
args = {types->i64_type(), // int64_t execution_context
10251065
types->i32_type(), // int32_t value

cpp/src/gandiva/gdv_function_stubs.h

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,8 @@ const char* gdv_fn_initcap_utf8(int64_t context, const char* data, int32_t data_
159159
int32_t* out_len);
160160

161161
GANDIVA_EXPORT
162-
int32_t gdv_fn_castINT_varbinary(int64_t context, const char* data, int32_t data_len);
162+
float gdv_fn_castFLOAT4_varbinary(gdv_int64 context, const char* in, int32_t in_len);
163163

164164
GANDIVA_EXPORT
165-
int64_t gdv_fn_castBIGINT_varbinary(int64_t context, const char* data, int32_t data_len);
166-
167-
GANDIVA_EXPORT
168-
float gdv_fn_castFLOAT4_varbinary(int64_t context, const char* data, int32_t data_len);
169-
170-
GANDIVA_EXPORT
171-
double gdv_fn_castFLOAT8_varbinary(int64_t context, const char* data, int32_t data_len);
165+
double gdv_fn_castFLOAT8_varbinary(gdv_int64 context, const char* in, int32_t in_len);
172166
}

cpp/src/gandiva/gdv_function_stubs_test.cc

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,4 +759,53 @@ TEST(TestGdvFnStubs, TestCastVarbinaryFloat8) {
759759
ctx.Reset();
760760
}
761761

762+
// Exercises gdv_fn_castFLOAT4_varbinary: first with inputs that are expected to
// convert without setting an error on the context, then with inputs that are
// expected to set a specific error message.
TEST(TestGdvFnStubs, TestCastFLOAT4Varbinary) {
  gandiva::ExecutionContext ctx;
  // The stub takes the execution context as a gdv_int64 handle, so keep the
  // handle in that exact type instead of round-tripping through uint64_t.
  const gdv_int64 ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);

  // NOTE(review): the expected values below look carried over from the integer
  // hex casts. Confirm they match the floating-point parser: "-FFF.3" has a
  // fractional part, and INT64_MAX is not exactly representable as a float.
  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-FFF.3", 6), -65523);
  EXPECT_FALSE(ctx.has_error());
  ctx.Reset();

  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "FFF3", 4), 65523);
  EXPECT_FALSE(ctx.has_error());
  ctx.Reset();

  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-7FFFFFFFFFFFFFFF", 17), INT64_MIN + 1);
  EXPECT_FALSE(ctx.has_error());
  ctx.Reset();

  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "7FFFFFFFFFFFFFFF", 16), INT64_MAX);
  EXPECT_FALSE(ctx.has_error());
  ctx.Reset();

  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "0", 1), 0);
  EXPECT_FALSE(ctx.has_error());
  ctx.Reset();

  EXPECT_EQ(gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-0", 2), 0);
  EXPECT_FALSE(ctx.has_error());
  ctx.Reset();

  // A zero-length buffer is rejected; the expected text is the message the
  // cast function sets for that case.
  gdv_fn_castFLOAT4_varbinary(ctx_ptr, "", 0);
  EXPECT_STREQ(ctx.get_error().c_str(), "Buffer can't be empty");
  ctx.Reset();

  // NOTE(review): none of the messages expected below appear in the
  // CAST_FLOAT_VARBINARY implementation; they read like the integer-cast
  // errors. Confirm the intended error text for the float casts.
  gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-", 1);
  EXPECT_STREQ(ctx.get_error().c_str(), "Can't cast hexadecimal with only a minus sign.");
  ctx.Reset();

  gdv_fn_castFLOAT4_varbinary(ctx_ptr, "8FFFFFFFFFFFFFFF", 16);
  EXPECT_STREQ(ctx.get_error().c_str(), "Integer overflow.");
  ctx.Reset();

  gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-8FFFFFFFFFFFFFFF", 17);
  EXPECT_STREQ(ctx.get_error().c_str(), "Integer overflow.");
  ctx.Reset();

  gdv_fn_castFLOAT4_varbinary(ctx_ptr, "-8FFFFFGF", 8);
  EXPECT_STREQ(ctx.get_error().c_str(), "The hexadecimal given has invalid characters.");
  ctx.Reset();
}
810+
762811
} // namespace gandiva

cpp/src/gandiva/precompiled/string_ops_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,7 +1695,7 @@ TEST(TestStringOps, TestCastINTVarbinary) {
16951695
EXPECT_FALSE(ctx.has_error());
16961696
ctx.Reset();
16971697

1698-
EXPECT_EQ(castINT_varbinary(ctx_ptr, "-7FFFFFFF", 9), INT32_MIN+1);
1698+
EXPECT_EQ(castINT_varbinary(ctx_ptr, "-7FFFFFFF", 9), INT32_MIN + 1);
16991699
EXPECT_FALSE(ctx.has_error());
17001700
ctx.Reset();
17011701

@@ -1744,7 +1744,7 @@ TEST(TestStringOps, TestCastBIGINTVarbinary) {
17441744
EXPECT_FALSE(ctx.has_error());
17451745
ctx.Reset();
17461746

1747-
EXPECT_EQ(castBIGINT_varbinary(ctx_ptr, "-7FFFFFFFFFFFFFFF", 17), INT64_MIN+1);
1747+
EXPECT_EQ(castBIGINT_varbinary(ctx_ptr, "-7FFFFFFFFFFFFFFF", 17), INT64_MIN + 1);
17481748
EXPECT_FALSE(ctx.has_error());
17491749
ctx.Reset();
17501750

cpp/src/gandiva/tests/projector_test.cc

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,22 +1146,24 @@ TEST_F(TestProjector, TestCastVarbinaryFunction) {
11461146

11471147
std::shared_ptr<Projector> projector;
11481148

1149-
// {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8}
1150-
auto status = Projector::Make(
1151-
schema, {cast_expr_float4, cast_expr_float8, cast_expr_int4, cast_expr_int8},
1152-
TestConfiguration(), &projector);
1149+
// {cast_expr_int4, cast_expr_int8}
1150+
auto status = Projector::Make(schema, {cast_expr_int4, cast_expr_int8},
1151+
TestConfiguration(), &projector);
11531152
EXPECT_TRUE(status.ok());
11541153

11551154
// Create a row-batch with some sample data
11561155
int num_records = 4;
11571156

11581157
// Last validity is false and the cast functions throw error when input is empty. Should
11591158
// not be evaluated due to addition of NativeFunction::kCanReturnErrors
1160-
auto array0 = MakeArrowArrayBinary({"25", "-7FFFFFFF", "7FFFFFFF", "4"}, {true, true, true, false});
1159+
auto array0 = MakeArrowArrayBinary({"25", "-7FFFFFFF", "7FFFFFFF", "4"},
1160+
{true, true, true, false});
11611161
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array0});
11621162

1163-
auto out_int4 = MakeArrowArrayInt32({37, INT32_MIN+1, INT32_MAX, 0}, {true, true, true, false});
1164-
auto out_int8 = MakeArrowArrayInt64({37, INT32_MIN+1, INT32_MAX, 0}, {true, true, true, false});
1163+
auto out_int4 =
1164+
MakeArrowArrayInt32({37, INT32_MIN + 1, INT32_MAX, 0}, {true, true, true, false});
1165+
auto out_int8 =
1166+
MakeArrowArrayInt64({37, INT32_MIN + 1, INT32_MAX, 0}, {true, true, true, false});
11651167

11661168
arrow::ArrayVector outputs;
11671169

0 commit comments

Comments
 (0)