Skip to content

Commit d575858

Browse files
bkietzlidavidm
andcommitted
ARROW-11797: [C++][Dataset] Provide batch stream Scanner methods
Closes apache#9589 from bkietz/11797-Provide-Scanner-methods-t Lead-authored-by: Benjamin Kietzman <bengilgit@gmail.com> Co-authored-by: David Li <li.davidm96@gmail.com> Signed-off-by: David Li <li.davidm96@gmail.com>
1 parent 02cdeab commit d575858

17 files changed

Lines changed: 675 additions & 170 deletions

File tree

cpp/src/arrow/dataset/file_base.cc

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,9 +460,25 @@ Status FileSystemDataset::Write(const FileSystemDatasetWriteOptions& write_optio
460460
//
461461
// NB: neither of these will have any impact whatsoever on the common case of writing
462462
// an in-memory table to disk.
463+
464+
#if defined(__GNUC__) || defined(__clang__)
465+
#pragma GCC diagnostic push
466+
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
467+
#elif defined(_MSC_VER)
468+
#pragma warning(push)
469+
#pragma warning(disable : 4996)
470+
#endif
471+
472+
// TODO: (ARROW-11782/ARROW-12288) Remove calls to Scan()
463473
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, scanner->Scan());
464474
ARROW_ASSIGN_OR_RAISE(ScanTaskVector scan_tasks, scan_task_it.ToVector());
465475

476+
#if defined(__GNUC__) || defined(__clang__)
477+
#pragma GCC diagnostic pop
478+
#elif defined(_MSC_VER)
479+
#pragma warning(pop)
480+
#endif
481+
466482
WriteState state(write_options);
467483
auto res = internal::RunSynchronously<arrow::detail::Empty>(
468484
[&](internal::Executor* cpu_executor) -> Future<> {

cpp/src/arrow/dataset/file_csv_test.cc

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -270,17 +270,11 @@ N/A,bar
270270
ASSERT_OK(builder.Project({"str"}));
271271
ASSERT_OK_AND_ASSIGN(auto scanner, builder.Finish());
272272

273-
ASSERT_OK_AND_ASSIGN(auto scan_task_it, scanner->Scan());
274-
for (auto maybe_scan_task : scan_task_it) {
275-
ASSERT_OK_AND_ASSIGN(auto scan_task, maybe_scan_task);
276-
ASSERT_OK_AND_ASSIGN(auto batch_it, scan_task->Execute());
277-
for (auto maybe_batch : batch_it) {
278-
ASSERT_OK_AND_ASSIGN(auto batch, maybe_batch);
279-
// Run through the scan checking for errors to ensure that "f64" is read with the
280-
// specified type and does not revert to the inferred type (if it reverts to
281-
// inferring float64 then evaluation of the comparison expression should break)
282-
}
283-
}
273+
ASSERT_OK_AND_ASSIGN(auto batch_it, scanner->ScanBatches());
274+
// Run through the scan checking for errors to ensure that "f64" is read with the
275+
// specified type and does not revert to the inferred type (if it reverts to
276+
// inferring float64 then evaluation of the comparison expression should break)
277+
ASSERT_OK(batch_it.Visit([](TaggedRecordBatch) { return Status::OK(); }));
284278
}
285279

286280
INSTANTIATE_TEST_SUITE_P(TestUncompressedCsv, TestCsvFileFormat,

cpp/src/arrow/dataset/scanner.cc

Lines changed: 234 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@
1818
#include "arrow/dataset/scanner.h"
1919

2020
#include <algorithm>
21+
#include <condition_variable>
2122
#include <memory>
2223
#include <mutex>
24+
#include <sstream>
2325

26+
#include "arrow/array/array_primitive.h"
2427
#include "arrow/compute/api_scalar.h"
28+
#include "arrow/compute/api_vector.h"
29+
#include "arrow/compute/cast.h"
2530
#include "arrow/dataset/dataset.h"
2631
#include "arrow/dataset/dataset_internal.h"
2732
#include "arrow/dataset/scanner_internal.h"
@@ -132,37 +137,124 @@ Result<EnumeratedRecordBatchIterator> Scanner::AddPositioningToInOrderScan(
132137
EnumeratingIterator{std::make_shared<State>(std::move(scan), std::move(first))});
133138
}
134139

135-
Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches() {
136-
// TODO(ARROW-11797) Provide a better implementation that does readahead. Also, add
137-
// unit testing
138-
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, Scan());
139-
struct BatchIter {
140-
explicit BatchIter(ScanTaskIterator scan_task_it)
141-
: scan_task_it(std::move(scan_task_it)) {}
142-
143-
Result<TaggedRecordBatch> Next() {
144-
while (true) {
145-
if (current_task == nullptr) {
146-
ARROW_ASSIGN_OR_RAISE(current_task, scan_task_it.Next());
147-
if (IsIterationEnd<std::shared_ptr<ScanTask>>(current_task)) {
148-
return IterationEnd<TaggedRecordBatch>();
149-
}
150-
ARROW_ASSIGN_OR_RAISE(batch_it, current_task->Execute());
151-
}
152-
ARROW_ASSIGN_OR_RAISE(auto next, batch_it.Next());
153-
if (IsIterationEnd<std::shared_ptr<RecordBatch>>(next)) {
154-
current_task = nullptr;
155-
} else {
156-
return TaggedRecordBatch{next, current_task->fragment()};
157-
}
140+
struct ScanBatchesState : public std::enable_shared_from_this<ScanBatchesState> {
141+
explicit ScanBatchesState(ScanTaskIterator scan_task_it,
142+
std::shared_ptr<TaskGroup> task_group_)
143+
: scan_tasks(std::move(scan_task_it)), task_group(std::move(task_group_)) {}
144+
145+
void ResizeBatches(size_t task_index) {
146+
if (task_batches.size() <= task_index) {
147+
task_batches.resize(task_index + 1);
148+
task_drained.resize(task_index + 1);
149+
}
150+
}
151+
152+
void Push(TaggedRecordBatch batch, size_t task_index) {
153+
{
154+
std::lock_guard<std::mutex> lock(mutex);
155+
ResizeBatches(task_index);
156+
task_batches[task_index].push_back(std::move(batch));
157+
}
158+
ready.notify_one();
159+
}
160+
161+
Status Finish(size_t task_index) {
162+
{
163+
std::lock_guard<std::mutex> lock(mutex);
164+
ResizeBatches(task_index);
165+
task_drained[task_index] = true;
166+
}
167+
ready.notify_one();
168+
return Status::OK();
169+
}
170+
171+
void PushScanTask() {
172+
if (no_more_tasks) return;
173+
std::unique_lock<std::mutex> lock(mutex);
174+
auto maybe_task = scan_tasks.Next();
175+
if (!maybe_task.ok()) {
176+
no_more_tasks = true;
177+
iteration_error = maybe_task.status();
178+
return;
179+
}
180+
auto scan_task = maybe_task.ValueOrDie();
181+
if (IsIterationEnd(scan_task)) {
182+
no_more_tasks = true;
183+
return;
184+
}
185+
auto state = shared_from_this();
186+
auto id = next_scan_task_id++;
187+
ResizeBatches(id);
188+
189+
lock.unlock();
190+
task_group->Append([state, id, scan_task]() {
191+
ARROW_ASSIGN_OR_RAISE(auto batch_it, scan_task->Execute());
192+
for (auto maybe_batch : batch_it) {
193+
ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
194+
state->Push(TaggedRecordBatch{std::move(batch), scan_task->fragment()}, id);
158195
}
196+
return state->Finish(id);
197+
});
198+
}
199+
200+
Result<TaggedRecordBatch> Pop() {
201+
std::unique_lock<std::mutex> lock(mutex);
202+
ready.wait(lock, [this, &lock] {
203+
while (pop_cursor < task_batches.size()) {
204+
// queue for current scan task contains at least one batch, pop that
205+
if (!task_batches[pop_cursor].empty()) return true;
206+
// queue is empty but will be appended to eventually, wait for that
207+
if (!task_drained[pop_cursor]) return false;
208+
209+
// Finished draining current scan task, enqueue a new one
210+
++pop_cursor;
211+
// Must unlock since serial task group will execute synchronously
212+
lock.unlock();
213+
PushScanTask();
214+
lock.lock();
215+
}
216+
DCHECK(no_more_tasks);
217+
// all scan tasks drained (or getting next task failed), terminate
218+
return true;
219+
});
220+
221+
if (pop_cursor == task_batches.size()) {
222+
// Don't report an error until we yield up everything we can first
223+
RETURN_NOT_OK(iteration_error);
224+
return IterationEnd<TaggedRecordBatch>();
159225
}
160226

161-
ScanTaskIterator scan_task_it;
162-
RecordBatchIterator batch_it;
163-
std::shared_ptr<ScanTask> current_task;
164-
};
165-
return TaggedRecordBatchIterator(BatchIter(std::move(scan_task_it)));
227+
auto batch = std::move(task_batches[pop_cursor].front());
228+
task_batches[pop_cursor].pop_front();
229+
return batch;
230+
}
231+
232+
/// Protecting mutating accesses to batches
233+
std::mutex mutex;
234+
std::condition_variable ready;
235+
ScanTaskIterator scan_tasks;
236+
std::shared_ptr<TaskGroup> task_group;
237+
int next_scan_task_id = 0;
238+
bool no_more_tasks = false;
239+
Status iteration_error;
240+
std::vector<std::deque<TaggedRecordBatch>> task_batches;
241+
std::vector<bool> task_drained;
242+
size_t pop_cursor = 0;
243+
};
244+
245+
/// \brief Scan the dataset as a stream of record batches, each tagged with the
/// fragment it was read from.
///
/// Up to `fragment_readahead` scan tasks are started eagerly on the scan's task group;
/// the returned iterator then pops their batches from the shared ScanBatchesState and
/// finishes the task group once the stream is exhausted.
Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches() {
  ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());
  auto task_group = scan_options_->TaskGroup();
  auto state = std::make_shared<ScanBatchesState>(std::move(scan_task_it), task_group);

  // Further scan tasks are only enqueued after an earlier one has been drained, so at
  // least one must be started here. Otherwise a fragment_readahead of zero (or less)
  // would make the very first Pop() report the end of the stream without ever
  // reading any data.
  state->PushScanTask();
  for (int i = 1; i < scan_options_->fragment_readahead; i++) {
    state->PushScanTask();
  }

  return MakeFunctionIterator([task_group, state]() -> Result<TaggedRecordBatch> {
    ARROW_ASSIGN_OR_RAISE(auto batch, state->Pop());
    if (!IsIterationEnd(batch)) return batch;
    // Stream exhausted: surface any error still held by the task group
    RETURN_NOT_OK(task_group->Finish());
    return IterationEnd<TaggedRecordBatch>();
  });
}
167259

168260
Result<FragmentIterator> SyncScanner::GetFragments() {
@@ -176,7 +268,30 @@ Result<FragmentIterator> SyncScanner::GetFragments() {
176268
return GetFragmentsFromDatasets({dataset_}, scan_options_->filter);
177269
}
178270

179-
Result<ScanTaskIterator> SyncScanner::Scan() {
271+
// Public scan task entry point: a thin forwarder to ScanInternal(), which holds the
// actual implementation.
Result<ScanTaskIterator> SyncScanner::Scan() {
  return ScanInternal();
}
272+
273+
/// Run every scan task on the scan's task group, handing each batch a task produces
/// (tagged with the task's fragment) to `visitor`. The first failure of a task, a
/// batch or the visitor is returned.
Status SyncScanner::Scan(std::function<Status(TaggedRecordBatch)> visitor) {
  ARROW_ASSIGN_OR_RAISE(auto tasks, ScanInternal());
  auto group = scan_options_->TaskGroup();

  for (auto maybe_task : tasks) {
    ARROW_ASSIGN_OR_RAISE(auto task, maybe_task);

    // One unit of work per scan task: execute it and visit all of its batches
    auto visit_task = [task, visitor]() -> Status {
      ARROW_ASSIGN_OR_RAISE(auto batches, task->Execute());
      for (auto maybe_batch : batches) {
        ARROW_ASSIGN_OR_RAISE(auto batch, maybe_batch);
        TaggedRecordBatch tagged{std::move(batch), task->fragment()};
        RETURN_NOT_OK(visitor(std::move(tagged)));
      }
      return Status::OK();
    };
    group->Append(std::move(visit_task));
  }

  // Wait for all appended tasks and report their combined status
  return group->Finish();
}
293+
294+
Result<ScanTaskIterator> SyncScanner::ScanInternal() {
180295
// Transforms Iterator<Fragment> into a unified
181296
// Iterator<ScanTask>. The first Iterator::Next invocation is going to do
182297
// all the work of unwinding the chained iterators.
@@ -315,7 +430,7 @@ Result<std::shared_ptr<Table>> SyncScanner::ToTable() {
315430

316431
Future<std::shared_ptr<Table>> SyncScanner::ToTableInternal(
317432
internal::Executor* cpu_executor) {
318-
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, Scan());
433+
ARROW_ASSIGN_OR_RAISE(auto scan_task_it, ScanInternal());
319434
auto task_group = scan_options_->TaskGroup();
320435

321436
/// Wraps the state in a shared_ptr to ensure that failing ScanTasks don't
@@ -343,5 +458,94 @@ Future<std::shared_ptr<Table>> SyncScanner::ToTableInternal(
343458
FlattenRecordBatchVector(std::move(state->batches)));
344459
}
345460

461+
/// \brief Synchronously load the rows at the given positions of the scan.
///
/// The indices are cast to int64 if necessary and sorted, so that the batches yielded
/// by ScanBatches() can be consumed in a single pass; the selected rows are finally
/// reordered to match the order of `indices`.
///
/// \param indices row positions to load; must not contain nulls
/// \return a Table with the projected schema holding one row per index, NotImplemented
///         if `indices` contains nulls, or IndexError if the scan ran out of batches
///         before every index was resolved
Result<std::shared_ptr<Table>> Scanner::TakeRows(const Array& indices) {
  if (indices.null_count() != 0) {
    return Status::NotImplemented("null take indices");
  }

  compute::ExecContext ctx(scan_options_->pool);

  const Array* original_indices;
  // If we have to cast, this is the backing reference
  std::shared_ptr<Array> original_indices_ptr;
  if (indices.type_id() != Type::INT64) {
    ARROW_ASSIGN_OR_RAISE(
        original_indices_ptr,
        compute::Cast(indices, int64(), compute::CastOptions::Safe(), &ctx));
    original_indices = original_indices_ptr.get();
  } else {
    original_indices = &indices;
  }

  // Sort the indices and remember the permutation which restores the requested order
  std::shared_ptr<Array> unsort_indices;
  {
    ARROW_ASSIGN_OR_RAISE(
        auto sort_indices,
        compute::SortIndices(*original_indices, compute::SortOrder::Ascending, &ctx));
    ARROW_ASSIGN_OR_RAISE(original_indices_ptr,
                          compute::Take(*original_indices, *sort_indices,
                                        compute::TakeOptions::Defaults(), &ctx));
    original_indices = original_indices_ptr.get();
    ARROW_ASSIGN_OR_RAISE(
        unsort_indices,
        compute::SortIndices(*sort_indices, compute::SortOrder::Ascending, &ctx));
  }

  RecordBatchVector out_batches;

  auto raw_indices = static_cast<const Int64Array&>(*original_indices).raw_values();
  // offset: number of (sorted) indices already resolved
  // row_begin: position within the scan of the current batch's first row
  int64_t offset = 0, row_begin = 0;

  ARROW_ASSIGN_OR_RAISE(auto batch_it, ScanBatches());
  // Check for remaining indices *before* pulling a batch, so that no batch beyond the
  // one holding the last requested row is consumed.
  while (offset < original_indices->length()) {
    ARROW_ASSIGN_OR_RAISE(auto batch, batch_it.Next());
    if (IsIterationEnd(batch)) break;
    const int64_t batch_rows = batch.record_batch->num_rows();

    // Count the indices which fall before the end of this batch
    int64_t length = 0;
    while (offset + length < original_indices->length()) {
      auto rel_index = raw_indices[offset + length] - row_begin;
      if (rel_index >= batch_rows) break;
      ++length;
    }
    DCHECK_LE(offset + length, original_indices->length());

    if (length > 0) {
      // Rebase the indices so they are relative to the start of this batch
      Datum rel_indices = original_indices->Slice(offset, length);
      ARROW_ASSIGN_OR_RAISE(rel_indices,
                            compute::Subtract(rel_indices, Datum(row_begin),
                                              compute::ArithmeticOptions(), &ctx));

      ARROW_ASSIGN_OR_RAISE(Datum out_batch,
                            compute::Take(batch.record_batch, rel_indices,
                                          compute::TakeOptions::Defaults(), &ctx));
      out_batches.push_back(out_batch.record_batch());

      offset += length;
    }
    row_begin += batch_rows;
  }

  if (offset < original_indices->length()) {
    // The scan ended before every index was resolved: report the first few leftovers
    std::stringstream error;
    const int64_t max_values_shown = 3;
    const int64_t num_remaining = original_indices->length() - offset;
    for (int64_t i = 0; i < std::min<int64_t>(max_values_shown, num_remaining); i++) {
      if (i > 0) error << ", ";
      error << raw_indices[offset + i];
    }
    if (num_remaining > max_values_shown) error << ", ...";
    return Status::IndexError("Some indices were out of bounds: ", error.str());
  }

  ARROW_ASSIGN_OR_RAISE(Datum out, Table::FromRecordBatches(options()->projected_schema,
                                                            std::move(out_batches)));
  // Undo the sort so rows come back in the order they were requested
  ARROW_ASSIGN_OR_RAISE(
      out, compute::Take(out, unsort_indices, compute::TakeOptions::Defaults(), &ctx));
  return out.table();
}
549+
346550
} // namespace dataset
347551
} // namespace arrow

cpp/src/arrow/dataset/scanner.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,14 @@ class ARROW_DS_EXPORT Scanner {
255255
/// in a concurrent fashion and outlive the iterator.
256256
///
257257
/// Note: Not supported by the async scanner
258-
/// TODO(ARROW-11797) Deprecate Scan()
258+
/// Planned for removal from the public API in ARROW-11782.
259+
ARROW_DEPRECATED("Deprecated in 4.0.0 for removal in 5.0.0. Use ScanBatches().")
259260
virtual Result<ScanTaskIterator> Scan();
261+
262+
/// \brief Apply a visitor to each RecordBatch as it is scanned. If multiple threads
263+
/// are used (via use_threads), the visitor will be invoked from those threads and is
264+
/// responsible for any synchronization.
265+
virtual Status Scan(std::function<Status(TaggedRecordBatch)> visitor) = 0;
260266
/// \brief Convert a Scanner into a Table.
261267
///
262268
/// Use this convenience utility with care. This will serially materialize the
@@ -279,6 +285,10 @@ class ARROW_DS_EXPORT Scanner {
279285
/// To make up for the out-of-order iteration each batch is further tagged with
280286
/// positional information.
281287
virtual Result<EnumeratedRecordBatchIterator> ScanBatchesUnordered();
288+
/// \brief A convenience to synchronously load the given rows by index.
289+
///
290+
/// Will only consume as many batches as needed from ScanBatches().
291+
virtual Result<std::shared_ptr<Table>> TakeRows(const Array& indices);
282292

283293
/// \brief Get the options for this scan.
284294
const std::shared_ptr<ScanOptions>& options() const { return scan_options_; }
@@ -306,12 +316,15 @@ class ARROW_DS_EXPORT SyncScanner : public Scanner {
306316

307317
Result<ScanTaskIterator> Scan() override;
308318

319+
Status Scan(std::function<Status(TaggedRecordBatch)> visitor) override;
320+
309321
Result<std::shared_ptr<Table>> ToTable() override;
310322

311323
protected:
312324
/// \brief GetFragments returns an iterator over all Fragments in this scan.
313325
Result<FragmentIterator> GetFragments();
314326
Future<std::shared_ptr<Table>> ToTableInternal(internal::Executor* cpu_executor);
327+
Result<ScanTaskIterator> ScanInternal();
315328

316329
std::shared_ptr<Dataset> dataset_;
317330
// TODO(ARROW-8065) remove fragment_ after a Dataset is constructible from fragments

0 commit comments

Comments
 (0)