Skip to content

Commit b596f29

Browse files
committed
ARROW-15241: [C++] MakeArrayOfNull fails on extension types with a nested storage type
Closes apache#12066 from westonpace/bugfix/ARROW-15241--make-array-of-null-nested-storage

Authored-by: Weston Pace <weston.pace@gmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
1 parent 49093a1 commit b596f29

4 files changed

Lines changed: 62 additions & 6 deletions

File tree

cpp/src/arrow/array/array_test.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,13 +364,15 @@ TEST_F(TestArray, TestMakeArrayOfNull) {
364364
dense_union(union_fields1, union_type_codes),
365365
dense_union(union_fields2, union_type_codes),
366366
smallint(), // extension type
367+
list_extension_type(), // nested extension type
367368
// clang-format on
368369
};
369370

370371
for (int64_t length : {0, 1, 16, 133}) {
371372
for (auto type : types) {
372373
ARROW_SCOPED_TRACE("type = ", type->ToString());
373374
ASSERT_OK_AND_ASSIGN(auto array, MakeArrayOfNull(type, length));
375+
ASSERT_EQ(array->type(), type);
374376
ASSERT_OK(array->ValidateFull());
375377
ASSERT_EQ(array->length(), length);
376378
if (is_union(type->id())) {

cpp/src/arrow/array/util.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -456,19 +456,19 @@ class NullArrayFactory {
456456
template <typename T>
457457
enable_if_var_size_list<T, Status> Visit(const T& type) {
458458
out_->buffers.resize(2, buffer_);
459-
ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(0, /*length=*/0));
459+
ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(type, 0, /*length=*/0));
460460
return Status::OK();
461461
}
462462

463463
// Fixed-size list types: the single child is created with
// `length_ * type.list_size()` slots. The child type is looked up on `type`,
// the type being visited.
Status Visit(const FixedSizeListType& type) {
  const auto values_length = length_ * type.list_size();
  ARROW_ASSIGN_OR_RAISE(out_->child_data[0], CreateChild(type, 0, values_length));
  return Status::OK();
}
468468

469469
// Struct types: creates one child of `length_` slots per struct field.
//
// The field count is read from `type` (the struct being visited) rather than
// from the factory's `type_`. When this overload is reached through
// Visit(const ExtensionType&), `type_` is the extension type while `type` is
// its storage type; the children being created (and the size of `child_data`)
// are those of the storage type, so the loop bound must come from `type` too.
// For a plain struct type the two are the same object and nothing changes.
Status Visit(const StructType& type) {
  const int num_children = type.num_fields();
  for (int i = 0; i < num_children; ++i) {
    ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, length_));
  }
  return Status::OK();
}
@@ -498,7 +498,7 @@ class NullArrayFactory {
498498
child_length = 1;
499499
}
500500
for (int i = 0; i < type_->num_fields(); ++i) {
501-
ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(i, child_length));
501+
ARROW_ASSIGN_OR_RAISE(out_->child_data[i], CreateChild(type, i, child_length));
502502
}
503503
return Status::OK();
504504
}
@@ -511,6 +511,7 @@ class NullArrayFactory {
511511
}
512512

513513
// Extension types are populated by visiting their storage type. `child_data`
// is first resized to the storage type's field count so that the storage
// type's visitor can assign children by index.
Status Visit(const ExtensionType& type) {
  const auto& storage = type.storage_type();
  out_->child_data.resize(storage->num_fields());
  return VisitTypeInline(*storage, this);
}
@@ -519,8 +520,9 @@ class NullArrayFactory {
519520
return Status::NotImplemented("construction of all-null ", type);
520521
}
521522

522-
// Builds the ArrayData for field `i` of `type` with `length` slots, using a
// nested NullArrayFactory that shares this factory's pool and `buffer_`.
//
// The field is looked up on the explicitly passed `type` rather than on the
// member `type_`, so callers visiting an extension type's storage type get the
// storage type's children.
Result<std::shared_ptr<ArrayData>> CreateChild(const DataType& type, int i,
                                               int64_t length) {
  const auto& child_type = type.field(i)->type();
  NullArrayFactory child_factory(pool_, child_type, length);
  child_factory.buffer_ = buffer_;
  return child_factory.Create();
}

cpp/src/arrow/testing/extension_type.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ class ARROW_TESTING_EXPORT SmallintArray : public ExtensionArray {
5454
using ExtensionArray::ExtensionArray;
5555
};
5656

57+
// Array class for the "list-ext" test extension type; it adds nothing beyond
// inheriting ExtensionArray's constructors.
class ARROW_TESTING_EXPORT ListExtensionArray : public ExtensionArray {
58+
public:
59+
using ExtensionArray::ExtensionArray;
60+
};
61+
5762
class ARROW_TESTING_EXPORT SmallintType : public ExtensionType {
5863
public:
5964
SmallintType() : ExtensionType(int16()) {}
@@ -71,6 +76,23 @@ class ARROW_TESTING_EXPORT SmallintType : public ExtensionType {
7176
std::string Serialize() const override { return "smallint"; }
7277
};
7378

79+
// Test extension type named "list-ext" whose storage type is list(int32()),
// i.e. an extension type with a nested storage type.
class ARROW_TESTING_EXPORT ListExtensionType : public ExtensionType {
80+
public:
81+
// The storage type is fixed to list(int32()).
ListExtensionType() : ExtensionType(list(int32())) {}
82+
83+
std::string extension_name() const override { return "list-ext"; }
84+
85+
bool ExtensionEquals(const ExtensionType& other) const override;
86+
87+
std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
88+
89+
Result<std::shared_ptr<DataType>> Deserialize(
90+
std::shared_ptr<DataType> storage_type,
91+
const std::string& serialized) const override;
92+
93+
// The serialized form is the same string as the extension name.
std::string Serialize() const override { return "list-ext"; }
94+
};
95+
7496
class ARROW_TESTING_EXPORT DictExtensionType : public ExtensionType {
7597
public:
7698
DictExtensionType() : ExtensionType(dictionary(int8(), utf8())) {}
@@ -118,6 +140,9 @@ std::shared_ptr<DataType> uuid();
118140
ARROW_TESTING_EXPORT
119141
std::shared_ptr<DataType> smallint();
120142

143+
ARROW_TESTING_EXPORT
144+
std::shared_ptr<DataType> list_extension_type();
145+
121146
ARROW_TESTING_EXPORT
122147
std::shared_ptr<DataType> dict_extension_type();
123148

cpp/src/arrow/testing/gtest_util.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,29 @@ Result<std::shared_ptr<DataType>> SmallintType::Deserialize(
797797
return std::make_shared<SmallintType>();
798798
}
799799

800+
bool ListExtensionType::ExtensionEquals(const ExtensionType& other) const {
801+
return (other.extension_name() == this->extension_name());
802+
}
803+
804+
// Wraps `data` in a ListExtensionArray.
std::shared_ptr<Array> ListExtensionType::MakeArray(
805+
std::shared_ptr<ArrayData> data) const {
806+
// Sanity checks: `data` must carry an extension type named "list-ext".
DCHECK_EQ(data->type->id(), Type::EXTENSION);
807+
DCHECK_EQ("list-ext", static_cast<const ExtensionType&>(*data->type).extension_name());
808+
return std::make_shared<ListExtensionArray>(data);
809+
}
810+
811+
// Reconstructs a ListExtensionType from its serialized form.
//
// Returns Status::Invalid when `serialized` is not "list-ext" or, failing
// that check, when `storage_type` is not equal to list(int32()). The
// identifier is checked first, so it determines the error when both are wrong.
Result<std::shared_ptr<DataType>> ListExtensionType::Deserialize(
    std::shared_ptr<DataType> storage_type, const std::string& serialized) const {
  const bool identifier_matches = (serialized == "list-ext");
  if (!identifier_matches) {
    return Status::Invalid("Type identifier did not match: '", serialized, "'");
  }
  const auto expected_storage = list(int32());
  if (storage_type->Equals(*expected_storage)) {
    return std::make_shared<ListExtensionType>();
  }
  return Status::Invalid("Invalid storage type for ListExtensionType: ",
                         storage_type->ToString());
}
822+
800823
bool DictExtensionType::ExtensionEquals(const ExtensionType& other) const {
801824
return (other.extension_name() == this->extension_name());
802825
}
@@ -847,6 +870,10 @@ std::shared_ptr<DataType> uuid() { return std::make_shared<UuidType>(); }
847870

848871
// Returns a new SmallintType instance as a generic DataType pointer.
std::shared_ptr<DataType> smallint() {
  std::shared_ptr<DataType> type = std::make_shared<SmallintType>();
  return type;
}
849872

873+
std::shared_ptr<DataType> list_extension_type() {
874+
return std::make_shared<ListExtensionType>();
875+
}
876+
850877
std::shared_ptr<DataType> dict_extension_type() {
851878
return std::make_shared<DictExtensionType>();
852879
}

0 commit comments

Comments
 (0)