Skip to content

Commit 50a7d15

Browse files
authored
ARROW-16695: [R][Python][C++] Extension types are not supported in joins (apache#13501)
This is to resolve [ARROW-16695](https://issues.apache.org/jira/browse/ARROW-16695). Authored-by: Rok <rok@mihevc.org> Signed-off-by: Rok <rok@mihevc.org>
1 parent 672431b commit 50a7d15

7 files changed

Lines changed: 246 additions & 12 deletions

File tree

cpp/src/arrow/compute/exec/hash_join_node.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ bool HashJoinSchema::IsTypeSupported(const DataType& type) {
4545
if (id == Type::DICTIONARY) {
4646
return IsTypeSupported(*checked_cast<const DictionaryType&>(type).value_type());
4747
}
48+
if (id == Type::EXTENSION) {
49+
return IsTypeSupported(*checked_cast<const ExtensionType&>(type).storage_type());
50+
}
4851
return is_fixed_width(id) || is_binary_like(id) || is_large_binary_like(id);
4952
}
5053

cpp/src/arrow/compute/exec/hash_join_node_test.cc

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "arrow/compute/exec/util.h"
2828
#include "arrow/compute/kernels/row_encoder.h"
2929
#include "arrow/compute/kernels/test_util.h"
30+
#include "arrow/testing/extension_type.h"
3031
#include "arrow/testing/gtest_util.h"
3132
#include "arrow/testing/matchers.h"
3233
#include "arrow/testing/random.h"
@@ -1801,6 +1802,114 @@ TEST(HashJoin, UnsupportedTypes) {
18011802
}
18021803
}
18031804

1805+
/// \brief Run an inner hash join of two inputs and compare it with `expected`.
///
/// Builds the plan source(left), source(right) -> hashjoin -> sink, joining the
/// "lkey" field of the left input with the "rkey" field of the right input, then
/// asserts that the collected rows and the join node's output schema both match
/// `expected`.
///
/// \param input_left batches and schema fed to the left source node
/// \param input_right batches and schema fed to the right source node
/// \param expected batches and schema the join output is compared against
void TestSimpleJoinHelper(BatchesWithSchema input_left, BatchesWithSchema input_right,
                          BatchesWithSchema expected) {
  ExecContext exec_ctx;
  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx));
  AsyncGenerator<util::optional<ExecBatch>> sink_gen;

  // Declare each node pointer where it is assigned so it is never left
  // uninitialized.
  ASSERT_OK_AND_ASSIGN(
      ExecNode * left_source,
      MakeExecNode("source", plan.get(), {},
                   SourceNodeOptions{input_left.schema, input_left.gen(/*parallel=*/false,
                                                                       /*slow=*/false)}));

  ASSERT_OK_AND_ASSIGN(ExecNode * right_source,
                       MakeExecNode("source", plan.get(), {},
                                    SourceNodeOptions{input_right.schema,
                                                      input_right.gen(/*parallel=*/false,
                                                                      /*slow=*/false)}));

  HashJoinNodeOptions join_opts{JoinType::INNER,
                                /*left_keys=*/{"lkey"},
                                /*right_keys=*/{"rkey"}, literal(true), "_l", "_r"};

  ASSERT_OK_AND_ASSIGN(
      auto hashjoin,
      MakeExecNode("hashjoin", plan.get(), {left_source, right_source}, join_opts));

  ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin},
                                                 SinkNodeOptions{&sink_gen}));

  ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen));

  // Pass the schemas without std::move: the node keeps owning its output schema,
  // and expected.schema is read again by AssertSchemaEqual below.
  ASSERT_OK_AND_ASSIGN(auto output_rows_test,
                       TableFromExecBatches(hashjoin->output_schema(), result));
  ASSERT_OK_AND_ASSIGN(auto expected_rows_test,
                       TableFromExecBatches(expected.schema, expected.batches));

  AssertTablesEqual(*output_rows_test, *expected_rows_test, /*same_chunk_layout=*/false,
                    /*flatten=*/true);
  AssertSchemaEqual(expected.schema, hashjoin->output_schema());
}
1849+
1850+
TEST(HashJoin, ExtensionTypesSwissJoin) {
  // Only simple column types are involved here, so swiss join will be used.
  auto left_keys = ArrayFromJSON(int32(), "[1, 2, 3, 4]");
  auto left_shared = ArrayFromJSON(int32(), "[4, 5, 6, 7]");
  auto right_keys = ArrayFromJSON(int32(), "[4, 3, 2, null, 1]");
  auto uuid_values = ExampleUuid();

  // Left input: join key, a plain payload column and an extension-typed payload.
  BatchesWithSchema input_left;
  input_left.schema = schema(
      {field("lkey", int32()), field("shared", int32()), field("ldistinct", uuid())});
  ASSERT_OK_AND_ASSIGN(ExecBatch left_batch,
                       ExecBatch::Make({left_keys, left_shared, uuid_values}));
  input_left.batches = {left_batch};

  // Right input: the join key only.
  BatchesWithSchema input_right;
  input_right.schema = schema({field("rkey", int32())});
  ASSERT_OK_AND_ASSIGN(ExecBatch right_batch, ExecBatch::Make({right_keys}));
  input_right.batches = {right_batch};

  // Expected output: all left columns followed by the right key, whose values
  // are the same as the left key values.
  BatchesWithSchema expected;
  expected.schema = schema({field("lkey", int32()), field("shared", int32()),
                            field("ldistinct", uuid()), field("rkey", int32())});
  ASSERT_OK_AND_ASSIGN(
      ExecBatch expected_batch,
      ExecBatch::Make({left_keys, left_shared, uuid_values, left_keys}));
  expected.batches = {expected_batch};

  TestSimpleJoinHelper(input_left, input_right, expected);
}
1878+
1879+
TEST(HashJoin, ExtensionTypesHashJoin) {
  // Swiss join doesn't support dictionaries, so adding a dictionary column
  // means HashJoin will be used.
  auto dictionary_type = dictionary(int64(), int8());
  auto left_keys = ArrayFromJSON(int32(), "[1, 2, 3, 4]");
  auto left_shared = ArrayFromJSON(int32(), "[4, 5, 6, 7]");
  auto right_keys = ArrayFromJSON(int32(), "[4, 3, 2, null, 1]");
  auto uuid_values = ExampleUuid();
  auto dictionary_values =
      DictArrayFromJSON(dictionary_type, R"([2, 0, 1, null])", R"([null, 0, 1])");

  // Left input: join key, plain payload, extension payload and dictionary payload.
  BatchesWithSchema input_left;
  input_left.schema =
      schema({field("lkey", int32()), field("shared", int32()),
              field("ldistinct", uuid()), field("dict_type", dictionary_type)});
  ASSERT_OK_AND_ASSIGN(
      ExecBatch left_batch,
      ExecBatch::Make({left_keys, left_shared, uuid_values, dictionary_values}));
  input_left.batches = {left_batch};

  // Right input: the join key only.
  BatchesWithSchema input_right;
  input_right.schema = schema({field("rkey", int32())});
  ASSERT_OK_AND_ASSIGN(ExecBatch right_batch, ExecBatch::Make({right_keys}));
  input_right.batches = {right_batch};

  // Expected output: all left columns followed by the right key, whose values
  // are the same as the left key values.
  BatchesWithSchema expected;
  expected.schema =
      schema({field("lkey", int32()), field("shared", int32()),
              field("ldistinct", uuid()), field("dict_type", dictionary_type),
              field("rkey", int32())});
  ASSERT_OK_AND_ASSIGN(ExecBatch expected_batch,
                       ExecBatch::Make({left_keys, left_shared, uuid_values,
                                        dictionary_values, left_keys}));
  expected.batches = {expected_batch};

  TestSimpleJoinHelper(input_left, input_right, expected);
}
1912+
18041913
TEST(HashJoin, CheckHashJoinNodeOptionsValidation) {
18051914
auto exec_ctx =
18061915
arrow::internal::make_unique<ExecContext>(default_memory_pool(), nullptr);

cpp/src/arrow/compute/exec/util_test.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "arrow/compute/exec/hash_join_node.h"
1919
#include "arrow/compute/exec/schema_util.h"
20+
#include "arrow/testing/extension_type.h"
2021
#include "arrow/testing/gtest_util.h"
2122
#include "arrow/testing/matchers.h"
2223

@@ -128,5 +129,60 @@ TEST(FieldMap, TwoKeyFields) {
128129
})));
129130
}
130131

132+
TEST(FieldMap, ExtensionTypeSwissJoin) {
  // Only simple column types are involved here, so swiss join will be used.
  auto left_schema = schema({field("i32", int32()), field("ext", uuid())});
  auto right_schema = schema({field("i32", int32())});

  HashJoinSchema join_schema;
  ASSERT_OK(join_schema.Init(JoinType::INNER, *left_schema, {"i32"}, *right_schema,
                             {"i32"}, literal(true), kLeftSuffix, kRightSuffix));

  // proj_maps[0] describes the left input, proj_maps[1] the right input.
  auto& left_maps = join_schema.proj_maps[0];
  auto& right_maps = join_schema.proj_maps[1];
  EXPECT_EQ(left_maps.num_cols(HashJoinProjection::INPUT), 2);
  EXPECT_EQ(left_maps.num_cols(HashJoinProjection::KEY), 1);
  EXPECT_EQ(right_maps.num_cols(HashJoinProjection::KEY), 1);
  EXPECT_EQ(left_maps.num_cols(HashJoinProjection::OUTPUT), 2);

  // The extension-typed field must come through the output schema unchanged.
  auto output_schema = join_schema.MakeOutputSchema(kLeftSuffix, kRightSuffix);
  EXPECT_THAT(*output_schema,
              Eq(Schema({field("i32.left", int32()), field("ext", uuid()),
                         field("i32.right", int32())})));

  auto input_to_output =
      left_maps.map(HashJoinProjection::INPUT, HashJoinProjection::OUTPUT);
  EXPECT_EQ(input_to_output.get(0), 0);
}
155+
156+
TEST(FieldMap, ExtensionTypeHashJoin) {
  // Swiss join doesn't support dictionaries, so the dictionary column on the
  // right side means HashJoin will be used.
  auto dictionary_type = dictionary(int64(), int8());
  auto left_schema = schema({field("i32", int32()), field("ext", uuid())});
  auto right_schema =
      schema({field("i32", int32()), field("dict_type", dictionary_type)});

  HashJoinSchema join_schema;
  ASSERT_OK(join_schema.Init(JoinType::INNER, *left_schema, {"i32"}, *right_schema,
                             {"i32"}, literal(true), kLeftSuffix, kRightSuffix));

  // proj_maps[0] describes the left input, proj_maps[1] the right input.
  auto& left_maps = join_schema.proj_maps[0];
  auto& right_maps = join_schema.proj_maps[1];
  EXPECT_EQ(left_maps.num_cols(HashJoinProjection::INPUT), 2);
  EXPECT_EQ(right_maps.num_cols(HashJoinProjection::INPUT), 2);
  EXPECT_EQ(left_maps.num_cols(HashJoinProjection::KEY), 1);
  EXPECT_EQ(right_maps.num_cols(HashJoinProjection::KEY), 1);
  EXPECT_EQ(left_maps.num_cols(HashJoinProjection::OUTPUT), 2);
  EXPECT_EQ(right_maps.num_cols(HashJoinProjection::OUTPUT), 2);

  // Both the extension-typed and the dictionary-typed fields must come through
  // the output schema unchanged.
  auto output_schema = join_schema.MakeOutputSchema(kLeftSuffix, kRightSuffix);
  EXPECT_THAT(*output_schema, Eq(Schema({
                                  field("i32.left", int32()),
                                  field("ext", uuid()),
                                  field("i32.right", int32()),
                                  field("dict_type", dictionary_type),
                              })));

  auto input_to_output =
      left_maps.map(HashJoinProjection::INPUT, HashJoinProjection::OUTPUT);
  EXPECT_EQ(input_to_output.get(0), 0);
}
186+
131187
} // namespace compute
132188
} // namespace arrow

cpp/src/arrow/compute/kernels/row_encoder.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,20 @@ Result<std::shared_ptr<ArrayData>> DictionaryKeyEncoder::Decode(uint8_t** encode
257257
void RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx) {
258258
ctx_ = ctx;
259259
encoders_.resize(column_types.size());
260+
extension_types_.resize(column_types.size());
260261

261262
for (size_t i = 0; i < column_types.size(); ++i) {
262-
const TypeHolder& type = column_types[i];
263+
const bool is_extension = column_types[i].id() == Type::EXTENSION;
264+
const TypeHolder& type = is_extension
265+
? arrow::internal::checked_pointer_cast<ExtensionType>(
266+
column_types[i].GetSharedPtr())
267+
->storage_type()
268+
: column_types[i];
269+
270+
if (is_extension) {
271+
extension_types_[i] = arrow::internal::checked_pointer_cast<ExtensionType>(
272+
column_types[i].GetSharedPtr());
273+
}
263274
if (type.id() == Type::BOOL) {
264275
encoders_[i] = std::make_shared<BooleanKeyEncoder>();
265276
continue;
@@ -354,9 +365,16 @@ Result<ExecBatch> RowEncoder::Decode(int64_t num_rows, const int32_t* row_ids) {
354365
out.values.resize(encoders_.size());
355366
for (size_t i = 0; i < encoders_.size(); ++i) {
356367
ARROW_ASSIGN_OR_RAISE(
357-
out.values[i],
368+
auto column_array_data,
358369
encoders_[i]->Decode(buf_ptrs.data(), static_cast<int32_t>(num_rows),
359370
ctx_->memory_pool()));
371+
372+
if (extension_types_[i] != nullptr) {
373+
ARROW_ASSIGN_OR_RAISE(out.values[i], ::arrow::internal::GetArrayView(
374+
column_array_data, extension_types_[i]))
375+
} else {
376+
out.values[i] = column_array_data;
377+
}
360378
}
361379

362380
return out;

cpp/src/arrow/compute/kernels/row_encoder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ class ARROW_EXPORT RowEncoder {
280280
std::vector<int32_t> offsets_;
281281
std::vector<uint8_t> bytes_;
282282
std::vector<uint8_t> encoded_nulls_;
283+
std::vector<std::shared_ptr<ExtensionType>> extension_types_;
283284
};
284285

285286
} // namespace internal

cpp/src/arrow/compute/light_array.cc

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,31 +109,37 @@ KeyColumnArray KeyColumnArray::Slice(int64_t offset, int64_t length) const {
109109

110110
/// \brief Build the KeyColumnMetadata that describes columns of the given type.
///
/// Extension types are described through their storage type. Dictionary, boolean,
/// other fixed-width, binary-like, large-binary-like and null types are supported.
///
/// \param type the column data type; must not be null
/// \return the metadata, or Status::TypeError when the (storage) type is not one
///         of the supported kinds
Result<KeyColumnMetadata> ColumnMetadataFromDataType(
    const std::shared_ptr<DataType>& type) {
  // Resolve extension types to their storage type. Hold the result in an owning
  // pointer rather than a reference bound to a conditional temporary, and keep
  // unwrapping so that an extension type whose storage is itself an extension
  // type is resolved too.
  std::shared_ptr<DataType> typ = type;
  while (typ->id() == Type::EXTENSION) {
    typ = arrow::internal::checked_cast<const ExtensionType&>(*typ).storage_type();
  }

  if (typ->id() == Type::DICTIONARY) {
    auto bit_width =
        arrow::internal::checked_cast<const FixedWidthType&>(*typ).bit_width();
    ARROW_DCHECK(bit_width % 8 == 0);
    return KeyColumnMetadata(true, bit_width / 8);
  }
  if (typ->id() == Type::BOOL) {
    // Booleans are narrower than one byte, so a byte width of 0 is reported.
    return KeyColumnMetadata(true, 0);
  }
  if (is_fixed_width(typ->id())) {
    return KeyColumnMetadata(
        true, arrow::internal::checked_cast<const FixedWidthType&>(*typ).bit_width() / 8);
  }
  // NOTE(review): the sizes below presumably describe the width of the offsets
  // used by each binary flavour -- confirm against KeyColumnMetadata.
  if (is_binary_like(typ->id())) {
    return KeyColumnMetadata(false, sizeof(uint32_t));
  }
  if (is_large_binary_like(typ->id())) {
    return KeyColumnMetadata(false, sizeof(uint64_t));
  }
  if (typ->id() == Type::NA) {
    return KeyColumnMetadata(true, 0, true);
  }
  // Caller attempted to create a KeyColumnArray from an invalid type
  return Status::TypeError("Unsupported column data type ", typ->name(),
                           " used with KeyColumnMetadata");
}
139145

python/pyarrow/tests/test_exec_plan.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919
import pyarrow as pa
2020
import pyarrow.compute as pc
21+
from .test_extension_type import IntegerType
2122

2223
try:
2324
import pyarrow.dataset as ds
@@ -280,3 +281,43 @@ def test_complex_filter_table():
280281
"a": [2, 4, 6], # second six must be omitted because 6*10 != 61
281282
"b": [20, 40, 60]
282283
})
284+
285+
286+
def test_join_extension_array_column():
    """Left outer joins keep extension-typed columns intact.

    Covers extension arrays both as payload columns and as join keys, with and
    without a dictionary column on the right-hand table.
    """
    ext_array = pa.ExtensionArray.from_storage(
        IntegerType(), pa.array([1, 2, 3], type=pa.int64()))
    dict_array = pa.DictionaryArray.from_arrays(
        pa.array([0, 2, 1]), pa.array(['a', 'b', 'c']))

    t1 = pa.table({
        "colA": [1, 2, 6],
        "colB": ext_array,
        "colVals": ext_array,
    })
    t2 = pa.table({
        "colA": [99, 2, 1],
        "colC": ext_array,
    })
    # Same as t2 plus a dictionary column.
    t3 = pa.table({
        "colA": [99, 2, 1],
        "colC": ext_array,
        "colD": dict_array,
    })

    # (right table, left key, right key, left column checked after the join)
    cases = [
        (t2, "colA", "colA", "colVals"),
        (t2, "colB", "colC", "colB"),
        (t3, "colA", "colA", "colVals"),
        (t3, "colB", "colC", "colB"),
    ]
    for right_table, left_key, right_key, checked_column in cases:
        result = ep._perform_join(
            "left outer", t1, [left_key], right_table, [right_key])
        assert result[checked_column] == pa.chunked_array(ext_array)

0 commit comments

Comments
 (0)