Skip to content

Commit ad1fc6c

Browse files
fsaintjacquesnealrichardson
authored andcommitted
ARROW-6952: [C++][Dataset] Implement predicate pushdown with ParqueFileFragment
The proposed change can be divided in 3 parts: - Implement the `StatisticsAsScalars(Statistics& stats, Scalar* min, Scalar* max)` function to convert `parquet::Statistic`s min and max as `arrow::Scalar`s. - Implement the `RowGroupStatisticsAsExpression(RowGroupMetadata& meta, Expression* out)` function to represents the RowGroup's statistics as an expression of conjunction, e.g. `(a_min <= a AND a <= a_max) AND (b_min <= b AND b <= b_max) AND ...` - Modifies ParquetScanTaskIterator to skip RowGroups by checking the expression derived from the metadata with the filter expression. Closes apache#5765 from fsaintjacques/ARROW-6952-dataset-parquet-predicate-pushdown and squashes the following commits: 5fd0efd <François Saint-Jacques> Add unit test and fix issues 04e4a45 <François Saint-Jacques> Review comments 0d14b8f <François Saint-Jacques> Expose SchemaManifest publicly 235a0d6 <François Saint-Jacques> ARROW-6952: Implement predicate pushdown with ParquetFileFragment Authored-by: François Saint-Jacques <fsaintjacques@gmail.com> Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
1 parent 21ad7ac commit ad1fc6c

15 files changed

Lines changed: 570 additions & 159 deletions

cpp/src/arrow/dataset/file_parquet.cc

Lines changed: 111 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,21 @@
2828
#include "arrow/util/range.h"
2929
#include "arrow/util/stl.h"
3030
#include "parquet/arrow/reader.h"
31+
#include "parquet/arrow/schema.h"
3132
#include "parquet/file_reader.h"
33+
#include "parquet/statistics.h"
3234

3335
namespace arrow {
3436
namespace dataset {
3537

36-
/// \brief A ScanTask backed by a parquet file and a subset of RowGroups.
38+
/// \brief A ScanTask backed by a parquet file and a RowGroup within a parquet file.
3739
class ParquetScanTask : public ScanTask {
3840
public:
39-
ParquetScanTask(std::vector<int> row_groups, std::vector<int> columns_projection,
41+
ParquetScanTask(int row_group, std::vector<int> columns_projection,
4042
std::shared_ptr<parquet::arrow::FileReader> reader,
4143
std::shared_ptr<ScanOptions> options,
4244
std::shared_ptr<ScanContext> context)
43-
: row_groups_(std::move(row_groups)),
45+
: row_group_(row_group),
4446
columns_projection_(std::move(columns_projection)),
4547
reader_(reader),
4648
options_(std::move(options)),
@@ -54,7 +56,7 @@ class ParquetScanTask : public ScanTask {
5456
// Thus the memory incurred by the RecordBatchReader is allocated when
5557
// Scan is called.
5658
std::unique_ptr<RecordBatchReader> record_batch_reader;
57-
auto status = reader_->GetRecordBatchReader(row_groups_, columns_projection_,
59+
auto status = reader_->GetRecordBatchReader({row_group_}, columns_projection_,
5860
&record_batch_reader);
5961
// Propagate the previous error as an error iterator.
6062
if (!status.ok()) {
@@ -65,8 +67,7 @@ class ParquetScanTask : public ScanTask {
6567
}
6668

6769
private:
68-
// Subset of RowGroups and columns bound to this task.
69-
std::vector<int> row_groups_;
70+
int row_group_;
7071
std::vector<int> columns_projection_;
7172
// The ScanTask _must_ hold a reference to reader_ because there's no
7273
// guarantee the producing ParquetScanTaskIterator is still alive. This is a
@@ -77,35 +78,52 @@ class ParquetScanTask : public ScanTask {
7778
std::shared_ptr<ScanContext> context_;
7879
};
7980

80-
constexpr int64_t kDefaultRowCountPerPartition = 1U << 16;
81-
82-
// A class that clusters RowGroups of a Parquet file until the cluster has a specified
83-
// total row count. This doesn't guarantee exact row counts; it may exceed the target.
84-
class ParquetRowGroupPartitioner {
81+
// Skip RowGroups with a filter and metadata
82+
class RowGroupSkipper {
8583
public:
86-
ParquetRowGroupPartitioner(std::shared_ptr<parquet::FileMetaData> metadata,
87-
int64_t row_count = kDefaultRowCountPerPartition)
88-
: metadata_(std::move(metadata)), row_count_(row_count), row_group_idx_(0) {
84+
static constexpr int kIterationDone = -1;
85+
86+
RowGroupSkipper(std::shared_ptr<parquet::FileMetaData> metadata,
87+
std::shared_ptr<Expression> filter)
88+
: metadata_(std::move(metadata)), filter_(filter), row_group_idx_(0) {
8989
num_row_groups_ = metadata_->num_row_groups();
9090
}
9191

92-
std::vector<int> Next() {
93-
int64_t partition_size = 0;
94-
std::vector<int> partitions;
92+
int Next() {
93+
while (row_group_idx_ < num_row_groups_) {
94+
const auto row_group_idx = row_group_idx_++;
95+
const auto row_group = metadata_->RowGroup(row_group_idx);
9596

96-
while (row_group_idx_ < num_row_groups_ && partition_size < row_count_) {
97-
partition_size += metadata_->RowGroup(row_group_idx_)->num_rows();
98-
partitions.push_back(row_group_idx_++);
97+
const auto num_rows = row_group->num_rows();
98+
if (CanSkip(*row_group)) {
99+
rows_skipped_ += num_rows;
100+
continue;
101+
}
102+
103+
return row_group_idx;
99104
}
100105

101-
return partitions;
106+
return kIterationDone;
102107
}
103108

104109
private:
110+
bool CanSkip(const parquet::RowGroupMetaData& metadata) const {
111+
auto maybe_stats_expr = RowGroupStatisticsAsExpression(metadata);
112+
// Errors with statistics are ignored and post-filtering will apply.
113+
if (!maybe_stats_expr.ok()) {
114+
return false;
115+
}
116+
117+
auto stats_expr = maybe_stats_expr.ValueOrDie();
118+
auto expr = filter_->Assume(stats_expr);
119+
return (expr->IsNull() || expr->Equals(false));
120+
}
121+
105122
std::shared_ptr<parquet::FileMetaData> metadata_;
106-
int64_t row_count_;
123+
std::shared_ptr<Expression> filter_;
107124
int row_group_idx_;
108125
int num_row_groups_;
126+
int64_t rows_skipped_;
109127
};
110128

111129
class ParquetScanTaskIterator {
@@ -130,16 +148,16 @@ class ParquetScanTaskIterator {
130148
}
131149

132150
Status Next(std::unique_ptr<ScanTask>* task) {
133-
auto partition = partitioner_.Next();
151+
auto row_group = skipper_.Next();
134152

135153
// Iteration is done.
136-
if (partition.size() == 0) {
154+
if (row_group == RowGroupSkipper::kIterationDone) {
137155
task->reset(nullptr);
138156
return Status::OK();
139157
}
140158

141-
task->reset(new ParquetScanTask(std::move(partition), columns_projection_, reader_,
142-
options_, context_));
159+
task->reset(
160+
new ParquetScanTask(row_group, columns_projection_, reader_, options_, context_));
143161

144162
return Status::OK();
145163
}
@@ -163,13 +181,13 @@ class ParquetScanTaskIterator {
163181
: options_(std::move(options)),
164182
context_(std::move(context)),
165183
columns_projection_(columns_projection),
166-
partitioner_(std::move(metadata)),
184+
skipper_(std::move(metadata), options_->filter),
167185
reader_(std::move(reader)) {}
168186

169187
std::shared_ptr<ScanOptions> options_;
170188
std::shared_ptr<ScanContext> context_;
171189
std::vector<int> columns_projection_;
172-
ParquetRowGroupPartitioner partitioner_;
190+
RowGroupSkipper skipper_;
173191
std::shared_ptr<parquet::arrow::FileReader> reader_;
174192
};
175193

@@ -220,5 +238,70 @@ Status ParquetFileFormat::OpenReader(
220238
return Status::OK();
221239
}
222240

241+
using parquet::arrow::SchemaField;
242+
using parquet::arrow::StatisticsAsScalars;
243+
244+
static std::shared_ptr<Expression> ColumnChunkStatisticsAsExpression(
245+
const SchemaField& schema_field, const parquet::RowGroupMetaData& metadata) {
246+
// For the remaining of this function, failure to extract/parse statistics
247+
// are ignored by returning the `true` scalar. The goal is two fold. First
248+
// avoid that an optimization break the computation. Second, allow the
249+
// following columns to maybe succeed in extracting column statistics.
250+
251+
// For now, only leaf (primitive) types are supported.
252+
if (!schema_field.is_leaf()) {
253+
return scalar(true);
254+
}
255+
256+
auto column_metadata = metadata.ColumnChunk(schema_field.column_index);
257+
auto field = schema_field.field;
258+
auto field_expr = field_ref(field->name());
259+
260+
// In case of missing statistics, return nothing.
261+
if (!column_metadata->is_stats_set()) {
262+
return scalar(true);
263+
}
264+
265+
auto statistics = column_metadata->statistics();
266+
if (statistics == nullptr) {
267+
return scalar(true);
268+
}
269+
270+
// Optimize for corner case where all values are nulls
271+
if (statistics->num_values() == statistics->null_count()) {
272+
std::shared_ptr<Scalar> null_scalar;
273+
if (!MakeNullScalar(field->type(), &null_scalar).ok()) {
274+
// MakeNullScalar can fail for some nested/repeated types.
275+
return scalar(true);
276+
}
277+
278+
return equal(field_expr, scalar(null_scalar));
279+
}
280+
281+
std::shared_ptr<Scalar> min, max;
282+
if (!StatisticsAsScalars(*statistics, &min, &max).ok()) {
283+
return scalar(true);
284+
}
285+
286+
return and_(greater_equal(field_expr, scalar(min)),
287+
less_equal(field_expr, scalar(max)));
288+
}
289+
290+
using parquet::arrow::SchemaManifest;
291+
292+
Result<std::shared_ptr<Expression>> RowGroupStatisticsAsExpression(
293+
const parquet::RowGroupMetaData& metadata) {
294+
SchemaManifest manifest;
295+
RETURN_NOT_OK(SchemaManifest::Make(
296+
metadata.schema(), nullptr, parquet::default_arrow_reader_properties(), &manifest));
297+
298+
std::vector<std::shared_ptr<Expression>> expressions;
299+
for (const auto& schema_field : manifest.schema_fields) {
300+
expressions.emplace_back(ColumnChunkStatisticsAsExpression(schema_field, metadata));
301+
}
302+
303+
return expressions.empty() ? scalar(true) : and_(expressions);
304+
}
305+
223306
} // namespace dataset
224307
} // namespace arrow

cpp/src/arrow/dataset/file_parquet.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
namespace parquet {
2828
class ParquetFileReader;
29+
class RowGroupMetaData;
2930
} // namespace parquet
3031

3132
namespace arrow {
@@ -75,5 +76,8 @@ class ARROW_DS_EXPORT ParquetFragment : public FileBasedDataFragment {
7576
bool splittable() const override { return true; }
7677
};
7778

79+
Result<std::shared_ptr<Expression>> RowGroupStatisticsAsExpression(
80+
const parquet::RowGroupMetaData& metadata);
81+
7882
} // namespace dataset
7983
} // namespace arrow

cpp/src/arrow/dataset/file_parquet_test.cc

Lines changed: 94 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ class ArrowParquetWriterMixin : public ::testing::Test {
104104
std::shared_ptr<Buffer> out;
105105
auto sink = CreateOutputStream(pool);
106106

107-
ARROW_EXPECT_OK(WriteRecordBatchReader(reader, pool, sink));
108-
ARROW_EXPECT_OK(sink->Finish(&out));
107+
ABORT_NOT_OK(WriteRecordBatchReader(reader, pool, sink));
108+
ABORT_NOT_OK(sink->Finish(&out));
109109

110110
return out;
111111
}
@@ -145,12 +145,9 @@ class ParquetBufferFixtureMixin : public ArrowParquetWriterMixin {
145145
};
146146

147147
class TestParquetFileFormat : public ParquetBufferFixtureMixin {
148-
public:
149-
TestParquetFileFormat() : ctx_(std::make_shared<ScanContext>()) {}
150-
151148
protected:
152149
std::shared_ptr<ScanOptions> opts_ = ScanOptions::Defaults();
153-
std::shared_ptr<ScanContext> ctx_;
150+
std::shared_ptr<ScanContext> ctx_ = std::make_shared<ScanContext>();
154151
};
155152

156153
TEST_F(TestParquetFileFormat, ScanRecordBatchReader) {
@@ -199,5 +196,96 @@ TEST_F(TestParquetFileFormat, Inspect) {
199196
EXPECT_EQ(*actual, *schema_);
200197
}
201198

199+
void CountRowsInScan(ScanTaskIterator& it, int64_t expected_rows,
200+
int64_t expected_batches) {
201+
int64_t actual_rows = 0;
202+
int64_t actual_batches = 0;
203+
204+
for (auto maybe_scan_task : it) {
205+
ASSERT_OK_AND_ASSIGN(auto scan_task, std::move(maybe_scan_task));
206+
for (auto maybe_record_batch : scan_task->Scan()) {
207+
ASSERT_OK_AND_ASSIGN(auto record_batch, std::move(maybe_record_batch));
208+
actual_rows += record_batch->num_rows();
209+
actual_batches++;
210+
}
211+
}
212+
213+
EXPECT_EQ(actual_rows, expected_rows);
214+
EXPECT_EQ(actual_batches, expected_batches);
215+
}
216+
217+
class TestParquetFileFormatPushDown : public TestParquetFileFormat {
218+
public:
219+
void CountRowsAndBatchesInScan(DataFragment& fragment, int64_t expected_rows,
220+
int64_t expected_batches) {
221+
int64_t actual_rows = 0;
222+
int64_t actual_batches = 0;
223+
224+
ScanTaskIterator it;
225+
ASSERT_OK(fragment.Scan(ctx_, &it));
226+
for (auto maybe_scan_task : it) {
227+
ASSERT_OK_AND_ASSIGN(auto scan_task, std::move(maybe_scan_task));
228+
for (auto maybe_record_batch : scan_task->Scan()) {
229+
ASSERT_OK_AND_ASSIGN(auto record_batch, std::move(maybe_record_batch));
230+
actual_rows += record_batch->num_rows();
231+
actual_batches++;
232+
}
233+
}
234+
235+
EXPECT_EQ(actual_rows, expected_rows);
236+
EXPECT_EQ(actual_batches, expected_batches);
237+
}
238+
};
239+
240+
TEST_F(TestParquetFileFormatPushDown, Basic) {
241+
// Given a number `n`, the arithmetic dataset creates n RecordBatches where
242+
// each RecordBatch is keyed by a unique integer in [1, n]. Let `rb_i` denote
243+
// the record batch keyed by `i`. `rb_i` is composed of `i` rows where all
244+
// values are a variant of `i`, e.g. {"i64": i, "u8": i, ... }.
245+
//
246+
// Thus the ArithmeticDataset(n) has n RecordBatches and the total number of
247+
// rows is n(n+1)/2.
248+
//
249+
// This test uses the DataFragment directly, and so no post-filtering is
250+
// applied via ScanOptions' evaluator. Thus, counting the number of returned
251+
// rows and returned row groups is a good enough proxy to check if pushdown
252+
// predicate is working.
253+
constexpr int64_t kNumRowGroups = 16;
254+
constexpr int64_t kTotalNumRows = kNumRowGroups * (kNumRowGroups + 1) / 2;
255+
256+
auto reader = ArithmeticDatasetFixture::GetRecordBatchReader(kNumRowGroups);
257+
auto source = GetFileSource(reader.get());
258+
auto fragment = std::make_shared<ParquetFragment>(*source, opts_);
259+
260+
opts_->filter = scalar(true);
261+
CountRowsAndBatchesInScan(*fragment, kTotalNumRows, kNumRowGroups);
262+
263+
for (int64_t i = 1; i <= kNumRowGroups; i++) {
264+
opts_->filter = ("i64"_ == int64_t(i)).Copy();
265+
CountRowsAndBatchesInScan(*fragment, i, 1);
266+
}
267+
268+
/* Out of bound filters should skip all RowGroups. */
269+
opts_->filter = scalar(false);
270+
CountRowsAndBatchesInScan(*fragment, 0, 0);
271+
opts_->filter = ("i64"_ == int64_t(kNumRowGroups + 1)).Copy();
272+
CountRowsAndBatchesInScan(*fragment, 0, 0);
273+
opts_->filter = ("i64"_ == int64_t(-1)).Copy();
274+
CountRowsAndBatchesInScan(*fragment, 0, 0);
275+
// No rows match 1 and 2.
276+
opts_->filter = ("i64"_ == int64_t(1) and "u8"_ == uint8_t(2)).Copy();
277+
CountRowsAndBatchesInScan(*fragment, 0, 0);
278+
279+
opts_->filter = ("i64"_ == int64_t(2) or "i64"_ == int64_t(4)).Copy();
280+
CountRowsAndBatchesInScan(*fragment, 2 + 4, 2);
281+
282+
opts_->filter = ("i64"_ < int64_t(6)).Copy();
283+
CountRowsAndBatchesInScan(*fragment, 5 * (5 + 1) / 2, 5);
284+
285+
opts_->filter = ("i64"_ >= int64_t(6)).Copy();
286+
CountRowsAndBatchesInScan(*fragment, kTotalNumRows - (5 * (5 + 1) / 2),
287+
kNumRowGroups - 5);
288+
}
289+
202290
} // namespace dataset
203291
} // namespace arrow

cpp/src/arrow/dataset/filter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,10 @@ class ARROW_DS_EXPORT Expression {
179179
return Copy();
180180
}
181181

182+
std::shared_ptr<Expression> Assume(const std::shared_ptr<Expression>& given) const {
183+
return Assume(*given);
184+
}
185+
182186
/// returns a debug string representing this expression
183187
virtual std::string ToString() const = 0;
184188

cpp/src/arrow/dataset/filter_test.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,13 @@ class ExpressionsTest : public ::testing::Test {
6262
auto simplified = expr.Assume(given);
6363
ASSERT_EQ(E{simplified}, E{expected})
6464
<< " simplification of: " << expr.ToString() << std::endl
65-
<< " given: " << given.ToString() << std::endl;
65+
<< " given: " << given.ToString() << std::endl
66+
<< " expected: " << expected.ToString() << std::endl;
67+
}
68+
69+
void AssertSimplifiesTo(const Expression& expr, const Expression& given,
70+
const std::shared_ptr<Expression>& expected) {
71+
AssertSimplifiesTo(expr, given, *expected);
6672
}
6773

6874
std::shared_ptr<Schema> schema_ =

0 commit comments

Comments
 (0)