Skip to content

Commit 77f099f

Browse files
LouisCltpitrouwjones127
authored
ARROW-17825: [C++] Allow the possibility to write several tables in ORCFileWriter (apache#14219)
I had the need to write an ORC file little by little, so as to not consume too much memory. Following [this](apache#14211) discussion, it appeared that the API did not seemed to prevent doing that, but that the internal implementation was not reusing the writer accordingly. This PR makes the needed changes to reuse the "writer_" correctly. I do not think that the preceding behaviour was correct, as calling several time the "Write" method would lead to incorrect ORC files. Lead-authored-by: LouisClt <louis1110@hotmail.fr> Co-authored-by: Antoine Pitrou <pitrou@free.fr> Co-authored-by: Will Jones <willjones127@gmail.com> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 4d8c21b commit 77f099f

3 files changed

Lines changed: 149 additions & 10 deletions

File tree

cpp/src/arrow/adapters/orc/adapter.cc

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -727,12 +727,23 @@ class ORCFileWriter::Impl {
727727
}
728728

729729
Status Write(const Table& table) {
730-
ARROW_ASSIGN_OR_RAISE(auto orc_schema, GetOrcType(*(table.schema())));
731-
ARROW_ASSIGN_OR_RAISE(auto orc_options, MakeOrcWriterOptions(write_options_));
730+
if (!writer_.get()) {
731+
ARROW_ASSIGN_OR_RAISE(orc_schema_, GetOrcType(*(table.schema())));
732+
ARROW_ASSIGN_OR_RAISE(auto orc_options, MakeOrcWriterOptions(write_options_));
733+
arrow_schema_ = table.schema();
734+
ORC_CATCH_NOT_OK(
735+
writer_ = liborc::createWriter(*orc_schema_, out_stream_.get(), orc_options))
736+
} else {
737+
bool schemas_matching = table.schema()->Equals(arrow_schema_, false);
738+
if (!schemas_matching) {
739+
return Status::TypeError(
740+
"The schema of the RecordBatch does not match"
741+
" the initial schema. All exported RecordBatches/Tables"
742+
" must have the same schema.\nInitial:\n",
743+
*arrow_schema_, "\nCurrent:\n", *table.schema());
744+
}
745+
}
732746
auto batch_size = static_cast<uint64_t>(write_options_.batch_size);
733-
ORC_CATCH_NOT_OK(
734-
writer_ = liborc::createWriter(*orc_schema, out_stream_.get(), orc_options))
735-
736747
int64_t num_rows = table.num_rows();
737748
const int num_cols = table.num_columns();
738749
std::vector<int64_t> arrow_index_offset(num_cols, 0);
@@ -744,7 +755,7 @@ class ORCFileWriter::Impl {
744755
while (num_rows > 0) {
745756
for (int i = 0; i < num_cols; i++) {
746757
RETURN_NOT_OK(adapters::orc::WriteBatch(
747-
*(table.column(i)), batch_size, &(arrow_chunk_offset[i]),
758+
*table.column(i), batch_size, &(arrow_chunk_offset[i]),
748759
&(arrow_index_offset[i]), (root->fields)[i]));
749760
}
750761
root->numElements = (root->fields)[0]->numElements;
@@ -765,7 +776,9 @@ class ORCFileWriter::Impl {
765776
private:
766777
std::unique_ptr<liborc::Writer> writer_;
767778
std::unique_ptr<liborc::OutputStream> out_stream_;
779+
std::shared_ptr<Schema> arrow_schema_;
768780
WriteOptions write_options_;
781+
ORC_UNIQUE_PTR<liborc::Type> orc_schema_;
769782
};
770783

771784
ORCFileWriter::~ORCFileWriter() {}
@@ -783,6 +796,11 @@ Result<std::unique_ptr<ORCFileWriter>> ORCFileWriter::Open(
783796

784797
Status ORCFileWriter::Write(const Table& table) { return impl_->Write(table); }
785798

799+
Status ORCFileWriter::Write(const RecordBatch& record_batch) {
800+
auto table = Table::Make(record_batch.schema(), record_batch.columns());
801+
return impl_->Write(*table);
802+
}
803+
786804
Status ORCFileWriter::Close() { return impl_->Close(); }
787805

788806
} // namespace orc

cpp/src/arrow/adapters/orc/adapter.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,24 @@ class ARROW_EXPORT ORCFileWriter {
272272
io::OutputStream* output_stream,
273273
const WriteOptions& write_options = WriteOptions());
274274

275-
/// \brief Write a table
275+
/// \brief Write a table. This can be called multiple times.
276276
///
277-
/// \param[in] table the Arrow table from which data is extracted
277+
/// Tables passed in subsequent calls must match the schema of the table that was
278+
/// written first.
279+
///
280+
/// \param[in] table the Arrow table from which data is extracted.
278281
/// \return Status
279282
Status Write(const Table& table);
280283

284+
/// \brief Write a RecordBatch. This can be called multiple times.
285+
///
286+
/// RecordBatches passed in subsequent calls must match the schema of the
287+
/// RecordBatch that was written first.
288+
///
289+
/// \param[in] record_batch the Arrow RecordBatch from which data is extracted.
290+
/// \return Status
291+
Status Write(const RecordBatch& record_batch);
292+
281293
/// \brief Close an ORC writer (orc::Writer)
282294
///
283295
/// \return Status

cpp/src/arrow/adapters/orc/adapter_test.cc

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ std::shared_ptr<Table> GenerateRandomTable(const std::shared_ptr<Schema>& schema
224224
return Table::Make(schema, cv);
225225
}
226226

227-
void AssertTableWriteReadEqual(const std::shared_ptr<Table>& input_table,
227+
void AssertTableWriteReadEqual(const std::vector<std::shared_ptr<Table>>& input_tables,
228228
const std::shared_ptr<Table>& expected_output_table,
229229
const int64_t max_size = kDefaultSmallMemStreamSize,
230230
std::vector<int>* opt_selected_read_indices = nullptr) {
@@ -241,7 +241,9 @@ void AssertTableWriteReadEqual(const std::shared_ptr<Table>& input_table,
241241
write_options.row_index_stride = 5000;
242242
EXPECT_OK_AND_ASSIGN(auto writer, adapters::orc::ORCFileWriter::Open(
243243
buffer_output_stream.get(), write_options));
244-
ARROW_EXPECT_OK(writer->Write(*input_table));
244+
for (const auto& input_table : input_tables) {
245+
ARROW_EXPECT_OK(writer->Write(*input_table));
246+
}
245247
ARROW_EXPECT_OK(writer->Close());
246248
EXPECT_OK_AND_ASSIGN(auto buffer, buffer_output_stream->Finish());
247249
std::shared_ptr<io::RandomAccessFile> in_stream(new io::BufferReader(buffer));
@@ -259,6 +261,48 @@ void AssertTableWriteReadEqual(const std::shared_ptr<Table>& input_table,
259261
AssertTablesEqual(*expected_output_table, *actual_output_table, false, false);
260262
}
261263

264+
void AssertBatchWriteReadEqual(
265+
const std::vector<std::shared_ptr<RecordBatch>>& input_batches,
266+
const std::shared_ptr<Table>& expected_output_table,
267+
const int64_t max_size = kDefaultSmallMemStreamSize) {
268+
EXPECT_OK_AND_ASSIGN(auto buffer_output_stream,
269+
io::BufferOutputStream::Create(max_size));
270+
auto write_options = adapters::orc::WriteOptions();
271+
#ifdef ARROW_WITH_SNAPPY
272+
write_options.compression = Compression::SNAPPY;
273+
#else
274+
write_options.compression = Compression::UNCOMPRESSED;
275+
#endif
276+
write_options.file_version = adapters::orc::FileVersion(0, 11);
277+
write_options.compression_block_size = 32768;
278+
write_options.row_index_stride = 5000;
279+
EXPECT_OK_AND_ASSIGN(auto writer, adapters::orc::ORCFileWriter::Open(
280+
buffer_output_stream.get(), write_options));
281+
for (auto& input_batch : input_batches) {
282+
ARROW_EXPECT_OK(writer->Write(*input_batch));
283+
}
284+
ARROW_EXPECT_OK(writer->Close());
285+
EXPECT_OK_AND_ASSIGN(auto buffer, buffer_output_stream->Finish());
286+
std::shared_ptr<io::RandomAccessFile> in_stream(new io::BufferReader(buffer));
287+
EXPECT_OK_AND_ASSIGN(
288+
auto reader, adapters::orc::ORCFileReader::Open(in_stream, default_memory_pool()));
289+
ASSERT_EQ(reader->GetFileVersion(), write_options.file_version);
290+
ASSERT_EQ(reader->GetCompression(), write_options.compression);
291+
ASSERT_EQ(reader->GetCompressionSize(), write_options.compression_block_size);
292+
ASSERT_EQ(reader->GetRowIndexStride(), write_options.row_index_stride);
293+
EXPECT_OK_AND_ASSIGN(auto actual_output_table, reader->Read());
294+
AssertTablesEqual(*expected_output_table, *actual_output_table, false, false);
295+
}
296+
297+
void AssertTableWriteReadEqual(const std::shared_ptr<Table>& input_table,
298+
const std::shared_ptr<Table>& expected_output_table,
299+
const int64_t max_size = kDefaultSmallMemStreamSize,
300+
std::vector<int>* opt_selected_read_indices = nullptr) {
301+
std::vector<std::shared_ptr<Table>> input_tables;
302+
input_tables.push_back(input_table);
303+
AssertTableWriteReadEqual(input_tables, expected_output_table, max_size,
304+
opt_selected_read_indices);
305+
}
262306
void AssertArrayWriteReadEqual(const std::shared_ptr<Array>& input_array,
263307
const std::shared_ptr<Array>& expected_output_array,
264308
const int64_t max_size = kDefaultSmallMemStreamSize) {
@@ -767,4 +811,69 @@ TEST_F(TestORCWriterSingleArray, WriteListOfMap) {
767811
AssertArrayWriteReadEqual(array, array, kDefaultSmallMemStreamSize * 10);
768812
}
769813

814+
class TestORCWriterMultipleWrite : public ::testing::Test {
815+
public:
816+
TestORCWriterMultipleWrite() : rand(kRandomSeed) {}
817+
818+
protected:
819+
random::RandomArrayGenerator rand;
820+
};
821+
822+
TEST_F(TestORCWriterMultipleWrite, MultipleWritesIntField) {
823+
const int64_t num_rows = 1234;
824+
const int num_writes = 5;
825+
std::shared_ptr<Schema> input_schema = schema({field("col0", int32())});
826+
ArrayVector vect;
827+
std::vector<std::shared_ptr<Table>> input_tables;
828+
for (int i = 0; i < num_writes; i++) {
829+
auto array_int = rand.ArrayOf(int32(), num_rows, 0);
830+
vect.push_back(array_int);
831+
auto input_chunked_array = std::make_shared<ChunkedArray>(array_int);
832+
input_tables.emplace_back(Table::Make(input_schema, {input_chunked_array}));
833+
}
834+
auto expected_output_chunked_array = std::make_shared<ChunkedArray>(vect);
835+
std::shared_ptr<Table> expected_output_table =
836+
Table::Make(input_schema, {expected_output_chunked_array});
837+
AssertTableWriteReadEqual(input_tables, expected_output_table,
838+
kDefaultSmallMemStreamSize * 100);
839+
}
840+
841+
TEST_F(TestORCWriterMultipleWrite, MultipleWritesIncoherentSchema) {
842+
const int64_t num_rows = 1234;
843+
auto array_int = rand.ArrayOf(int32(), num_rows, 0);
844+
std::shared_ptr<Schema> input_schema = schema({field("col0", array_int->type())});
845+
auto array_int2 = rand.ArrayOf(int64(), num_rows, 0);
846+
std::shared_ptr<Schema> input_schema2 = schema({field("col0", array_int2->type())});
847+
848+
std::shared_ptr<Table> input_table = Table::Make(input_schema, {array_int});
849+
std::shared_ptr<Table> input_table2 = Table::Make(input_schema2, {array_int2});
850+
EXPECT_OK_AND_ASSIGN(auto buffer_output_stream,
851+
io::BufferOutputStream::Create(kDefaultSmallMemStreamSize));
852+
auto write_options = adapters::orc::WriteOptions();
853+
EXPECT_OK_AND_ASSIGN(auto writer, adapters::orc::ORCFileWriter::Open(
854+
buffer_output_stream.get(), write_options));
855+
ARROW_EXPECT_OK(writer->Write(*input_table));
856+
857+
// This should not pass
858+
ASSERT_RAISES(TypeError, writer->Write(*input_table2));
859+
860+
ARROW_EXPECT_OK(writer->Close());
861+
}
862+
TEST_F(TestORCWriterMultipleWrite, MultipleWritesIntFieldRecordBatch) {
863+
const int64_t num_rows = 1234;
864+
const int num_writes = 5;
865+
std::shared_ptr<Schema> input_schema = schema({field("col0", int32())});
866+
ArrayVector vect;
867+
std::vector<std::shared_ptr<RecordBatch>> input_batches;
868+
for (int i = 0; i < num_writes; i++) {
869+
auto array_int = rand.ArrayOf(int32(), num_rows, 0);
870+
vect.push_back(array_int);
871+
input_batches.emplace_back(RecordBatch::Make(input_schema, num_rows, {array_int}));
872+
}
873+
auto expected_output_chunked_array = std::make_shared<ChunkedArray>(vect);
874+
std::shared_ptr<Table> expected_output_table =
875+
Table::Make(input_schema, {expected_output_chunked_array});
876+
AssertBatchWriteReadEqual(input_batches, expected_output_table,
877+
kDefaultSmallMemStreamSize * 100);
878+
}
770879
} // namespace arrow

0 commit comments

Comments
 (0)