Skip to content

Commit 188efb7

Browse files
authored
ARROW-17067: [C++][Gandiva] Implement Substring_Index Function. (apache#13600)
Adding Substring_Index Function. Authored-by: SG011 <sahaj.gupta@dremio.com> Signed-off-by: Pindikura Ravindra <ravindra@dremio.com>
1 parent 0330353 commit 188efb7

5 files changed

Lines changed: 226 additions & 1 deletion

File tree

cpp/src/gandiva/function_registry_string.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -515,8 +515,11 @@ std::vector<NativeFunction> GetStringFunctionRegistry() {
515515

516516
NativeFunction("translate", {}, DataTypeVector{utf8(), utf8(), utf8()}, utf8(),
517517
kResultNullIfNull, "translate_utf8_utf8_utf8",
518-
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)};
518+
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors),
519519

520+
NativeFunction("substring_index", {}, DataTypeVector{utf8(), utf8(), int32()},
521+
utf8(), kResultNullIfNull, "gdv_fn_substring_index",
522+
NativeFunction::kNeedsContext | NativeFunction::kCanReturnErrors)};
520523
return string_fn_registry_;
521524
}
522525

cpp/src/gandiva/gdv_function_stubs.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,4 +346,9 @@ gdv_timestamp to_utc_timezone_timestamp(int64_t context, gdv_timestamp time_mili
346346
GANDIVA_EXPORT
347347
gdv_timestamp from_utc_timezone_timestamp(int64_t context, gdv_timestamp time_miliseconds,
348348
const char* timezone, int32_t length);
349+
350+
GANDIVA_EXPORT
351+
const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt_len,
352+
const char* pat, int32_t pat_len, int32_t cnt,
353+
int32_t* out_len);
349354
}

cpp/src/gandiva/gdv_function_stubs_test.cc

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,77 @@ TEST(TestGdvFnStubs, TestCastVARCHARFromDouble) {
453453
EXPECT_FALSE(ctx.has_error());
454454
}
455455

456+
TEST(TestGdvFnStubs, TestSubstringIndex) {
457+
gandiva::ExecutionContext ctx;
458+
uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);
459+
gdv_int32 out_len = 0;
460+
461+
const char* out_str =
462+
gdv_fn_substring_index(ctx_ptr, "Abc.DE.fGh", 10, ".", 1, 2, &out_len);
463+
EXPECT_EQ(std::string(out_str, out_len), "Abc.DE");
464+
EXPECT_FALSE(ctx.has_error());
465+
466+
out_str = gdv_fn_substring_index(ctx_ptr, "Abc.DE.fGh", 10, ".", 1, -2, &out_len);
467+
EXPECT_EQ(std::string(out_str, out_len), "fGh");
468+
EXPECT_FALSE(ctx.has_error());
469+
470+
out_str = gdv_fn_substring_index(ctx_ptr, "S;DCGS;JO!L", 11, ";", 1, 1, &out_len);
471+
EXPECT_EQ(std::string(out_str, out_len), "S");
472+
EXPECT_FALSE(ctx.has_error());
473+
474+
out_str = gdv_fn_substring_index(ctx_ptr, "S;DCGS;JO!L", 11, ";", 1, -1, &out_len);
475+
EXPECT_EQ(std::string(out_str, out_len), "DCGS;JO!L");
476+
EXPECT_FALSE(ctx.has_error());
477+
478+
out_str = gdv_fn_substring_index(ctx_ptr, "www.mysql.com", 13, "Q", 1, 1, &out_len);
479+
EXPECT_EQ(std::string(out_str, out_len), "www.mysql.com");
480+
EXPECT_FALSE(ctx.has_error());
481+
482+
out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "||", 2, 2, &out_len);
483+
EXPECT_EQ(std::string(out_str, out_len), "www||mysql");
484+
EXPECT_FALSE(ctx.has_error());
485+
486+
out_str = gdv_fn_substring_index(ctx_ptr, "", 0, ".", 1, 1, &out_len);
487+
EXPECT_EQ(std::string(out_str, out_len).size(), 0);
488+
EXPECT_FALSE(ctx.has_error());
489+
490+
out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "", 0, 1, &out_len);
491+
EXPECT_EQ(std::string(out_str, out_len).size(), 0);
492+
EXPECT_FALSE(ctx.has_error());
493+
494+
out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "||", 2, 0, &out_len);
495+
EXPECT_EQ(std::string(out_str, out_len).size(), 0);
496+
EXPECT_FALSE(ctx.has_error());
497+
498+
out_str = gdv_fn_substring_index(ctx_ptr, "www||mysql||com", 15, "||", 2, -2, &out_len);
499+
EXPECT_EQ(std::string(out_str, out_len), "com");
500+
EXPECT_FALSE(ctx.has_error());
501+
502+
out_str = gdv_fn_substring_index(ctx_ptr, "MÜNCHEN", 8, "Ü", 2, 1, &out_len);
503+
EXPECT_EQ(std::string(out_str, out_len), "M");
504+
EXPECT_FALSE(ctx.has_error());
505+
506+
out_str = gdv_fn_substring_index(ctx_ptr, "MÜNCHEN", 8, "Ü", 2, -1, &out_len);
507+
EXPECT_EQ(std::string(out_str, out_len), "NCHEN");
508+
EXPECT_FALSE(ctx.has_error());
509+
510+
out_str = gdv_fn_substring_index(ctx_ptr, "citroën", 8, "ë", 2, -1, &out_len);
511+
EXPECT_EQ(std::string(out_str, out_len), "n");
512+
EXPECT_FALSE(ctx.has_error());
513+
514+
out_str = gdv_fn_substring_index(ctx_ptr, "citroën", 8, "ë", 2, 1, &out_len);
515+
EXPECT_EQ(std::string(out_str, out_len), "citro");
516+
EXPECT_FALSE(ctx.has_error());
517+
518+
out_str = gdv_fn_substring_index(ctx_ptr, "路学\\L", 8, "\\", 1, 1, &out_len);
519+
EXPECT_EQ(std::string(out_str, out_len), "路学");
520+
EXPECT_FALSE(ctx.has_error());
521+
522+
out_str = gdv_fn_substring_index(ctx_ptr, "路学\\L", 8, "\\", 1, -1, &out_len);
523+
EXPECT_EQ(std::string(out_str, out_len), "L");
524+
EXPECT_FALSE(ctx.has_error());
525+
}
526+
456527
TEST(TestGdvFnStubs, TestUpper) {
457528
gandiva::ExecutionContext ctx;
458529
uint64_t ctx_ptr = reinterpret_cast<gdv_int64>(&ctx);

cpp/src/gandiva/gdv_string_function_stubs.cc

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,94 @@ const char* gdv_fn_upper_utf8(int64_t context, const char* data, int32_t data_le
336336
return out;
337337
}
338338

339+
// Substring_index
340+
GDV_FORCE_INLINE
341+
const char* gdv_fn_substring_index(int64_t context, const char* txt, int32_t txt_len,
342+
const char* pat, int32_t pat_len, int32_t cnt,
343+
int32_t* out_len) {
344+
if (txt_len == 0 || pat_len == 0 || cnt == 0) {
345+
*out_len = 0;
346+
return "";
347+
}
348+
349+
char* out = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, txt_len));
350+
if (out == nullptr) {
351+
gdv_fn_context_set_error_msg(context, "Could not allocate memory for output string");
352+
*out_len = 0;
353+
return "";
354+
}
355+
356+
std::vector<int> lps(pat_len);
357+
int len = 0;
358+
359+
lps[0] = 0; // lps[0] is always 0
360+
361+
// the loop calculates lps[i] for i = 1 to M-1
362+
int i = 1;
363+
while (i < pat_len) {
364+
if (pat[i] == pat[len]) {
365+
len++;
366+
lps[i] = len;
367+
i++;
368+
} else {
369+
// (pat[i] != pat[len])
370+
// This is tricky. Consider the example.
371+
// AAACAAAA and i = 7. The idea is similar
372+
// to search step.
373+
if (len != 0) {
374+
len = lps[len - 1];
375+
376+
// Also, note that we do not increment
377+
// i here
378+
} else {
379+
// if (len == 0)
380+
lps[i] = 0;
381+
i++;
382+
}
383+
}
384+
}
385+
386+
std::vector<int> occ;
387+
388+
i = 0; // index for txt[]
389+
int j = 0; // index for pat[]
390+
while (i < txt_len) {
391+
if (pat[j] == txt[i]) {
392+
j++;
393+
i++;
394+
}
395+
396+
if (j == pat_len) {
397+
occ.push_back(i - j);
398+
j = lps[j - 1];
399+
} else if (i < txt_len && pat[j] != txt[i]) {
400+
// mismatch after j matches
401+
// Do not match lps[0..lps[j-1]] characters,
402+
// they will match anyway
403+
if (j != 0)
404+
j = lps[j - 1];
405+
else
406+
i = i + 1;
407+
}
408+
}
409+
410+
if (static_cast<int32_t>(abs(cnt)) <= static_cast<int32_t>(occ.size()) && cnt > 0) {
411+
memcpy(out, txt, occ[cnt - 1]);
412+
*out_len = occ[cnt - 1];
413+
return out;
414+
} else if (static_cast<int32_t>(abs(cnt)) <= static_cast<int32_t>(occ.size()) &&
415+
cnt < 0) {
416+
int32_t temp = static_cast<int32_t>(abs(cnt));
417+
memcpy(out, txt + occ[temp - 1] + pat_len, txt_len - occ[temp - 1] - pat_len);
418+
*out_len = txt_len - occ[temp - 1] - pat_len;
419+
return out;
420+
} else {
421+
*out_len = txt_len;
422+
memcpy(out, txt, txt_len);
423+
return out;
424+
}
425+
}
426+
339427
// Any codepoint, except the ones for lowercase letters, uppercase letters,
340428
// titlecase letters, decimal digits and letter numbers categories will be
341429
// considered as word separators.
@@ -855,6 +943,21 @@ void ExportedStringFunctions::AddMappings(Engine* engine) const {
855943
types->i8_ptr_type() /*return_type*/, args,
856944
reinterpret_cast<void*>(gdv_fn_upper_utf8));
857945

946+
// gdv_fn_substring_index
947+
args = {
948+
types->i64_type(), // context
949+
types->i8_ptr_type(), // txt
950+
types->i32_type(), // txt_len
951+
types->i8_ptr_type(), // pat
952+
types->i32_type(), // pat_len
953+
types->i32_type(), // cnt
954+
types->i32_ptr_type(), // out_len
955+
};
956+
957+
engine->AddGlobalMappingForFunc("gdv_fn_substring_index",
958+
types->i8_ptr_type() /*return_type*/, args,
959+
reinterpret_cast<void*>(gdv_fn_substring_index));
960+
858961
// gdv_fn_initcap_utf8
859962
args = {
860963
types->i64_type(), // context

cpp/src/gandiva/tests/projector_test.cc

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2992,6 +2992,49 @@ TEST_F(TestProjector, TestUCase) {
29922992
EXPECT_ARROW_ARRAY_EQUALS(out_1, outputs.at(0));
29932993
}
29942994

2995+
TEST_F(TestProjector, TestSubstringIndex) {
2996+
auto field1 = field("f1", arrow::utf8());
2997+
auto field2 = field("f2", arrow::utf8());
2998+
auto field3 = field("f3", arrow::int32());
2999+
auto schema = arrow::schema({field1, field2, field3});
3000+
3001+
// output fields
3002+
auto substring_index = field("substring", arrow::utf8());
3003+
3004+
// Build expression
3005+
auto substring_expr = TreeExprBuilder::MakeExpression(
3006+
"substring_index", {field1, field2, field3}, substring_index);
3007+
3008+
std::shared_ptr<Projector> projector;
3009+
3010+
auto status =
3011+
Projector::Make(schema, {substring_expr}, TestConfiguration(), &projector);
3012+
3013+
EXPECT_TRUE(status.ok());
3014+
3015+
// Create a row-batch with some sample data
3016+
int num_records = 3;
3017+
3018+
auto array1 = MakeArrowArrayUtf8({"www||mysql||com", "www||mysql||com", "S;DCGS;JO!L"},
3019+
{true, true, true});
3020+
3021+
auto array2 = MakeArrowArrayUtf8({"||", "||", ";"}, {true, true, true});
3022+
3023+
auto array3 = MakeArrowArrayInt32({2, -2, -1}, {true, true, true});
3024+
3025+
auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array1, array2, array3});
3026+
3027+
auto out_1 = MakeArrowArrayUtf8({"www||mysql", "com", "DCGS;JO!L"}, {true, true, true});
3028+
3029+
arrow::ArrayVector outputs;
3030+
3031+
// Evaluate expression
3032+
status = projector->Evaluate(*in_batch, pool_, &outputs);
3033+
EXPECT_TRUE(status.ok());
3034+
3035+
EXPECT_ARROW_ARRAY_EQUALS(out_1, outputs.at(0));
3036+
}
3037+
29953038
TEST_F(TestProjector, TestLCase) {
29963039
auto field0 = field("f0", arrow::utf8());
29973040
auto schema = arrow::schema({field0});

0 commit comments

Comments
 (0)