1818#include " arrow/dataset/scanner.h"
1919
2020#include < algorithm>
21+ #include < condition_variable>
2122#include < memory>
2223#include < mutex>
24+ #include < sstream>
2325
26+ #include " arrow/array/array_primitive.h"
2427#include " arrow/compute/api_scalar.h"
28+ #include " arrow/compute/api_vector.h"
29+ #include " arrow/compute/cast.h"
2530#include " arrow/dataset/dataset.h"
2631#include " arrow/dataset/dataset_internal.h"
2732#include " arrow/dataset/scanner_internal.h"
@@ -132,37 +137,124 @@ Result<EnumeratedRecordBatchIterator> Scanner::AddPositioningToInOrderScan(
132137 EnumeratingIterator{std::make_shared<State>(std::move (scan), std::move (first))});
133138}
134139
135- Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches () {
136- // TODO(ARROW-11797) Provide a better implementation that does readahead. Also, add
137- // unit testing
138- ARROW_ASSIGN_OR_RAISE (auto scan_task_it, Scan ());
139- struct BatchIter {
140- explicit BatchIter (ScanTaskIterator scan_task_it)
141- : scan_task_it(std::move(scan_task_it)) {}
142-
143- Result<TaggedRecordBatch> Next () {
144- while (true ) {
145- if (current_task == nullptr ) {
146- ARROW_ASSIGN_OR_RAISE (current_task, scan_task_it.Next ());
147- if (IsIterationEnd<std::shared_ptr<ScanTask>>(current_task)) {
148- return IterationEnd<TaggedRecordBatch>();
149- }
150- ARROW_ASSIGN_OR_RAISE (batch_it, current_task->Execute ());
151- }
152- ARROW_ASSIGN_OR_RAISE (auto next, batch_it.Next ());
153- if (IsIterationEnd<std::shared_ptr<RecordBatch>>(next)) {
154- current_task = nullptr ;
155- } else {
156- return TaggedRecordBatch{next, current_task->fragment ()};
157- }
140+ struct ScanBatchesState : public std ::enable_shared_from_this<ScanBatchesState> {
141+ explicit ScanBatchesState (ScanTaskIterator scan_task_it,
142+ std::shared_ptr<TaskGroup> task_group_)
143+ : scan_tasks(std::move(scan_task_it)), task_group(std::move(task_group_)) {}
144+
145+ void ResizeBatches (size_t task_index) {
146+ if (task_batches.size () <= task_index) {
147+ task_batches.resize (task_index + 1 );
148+ task_drained.resize (task_index + 1 );
149+ }
150+ }
151+
152+ void Push (TaggedRecordBatch batch, size_t task_index) {
153+ {
154+ std::lock_guard<std::mutex> lock (mutex);
155+ ResizeBatches (task_index);
156+ task_batches[task_index].push_back (std::move (batch));
157+ }
158+ ready.notify_one ();
159+ }
160+
161+ Status Finish (size_t task_index) {
162+ {
163+ std::lock_guard<std::mutex> lock (mutex);
164+ ResizeBatches (task_index);
165+ task_drained[task_index] = true ;
166+ }
167+ ready.notify_one ();
168+ return Status::OK ();
169+ }
170+
171+ void PushScanTask () {
172+ if (no_more_tasks) return ;
173+ std::unique_lock<std::mutex> lock (mutex);
174+ auto maybe_task = scan_tasks.Next ();
175+ if (!maybe_task.ok ()) {
176+ no_more_tasks = true ;
177+ iteration_error = maybe_task.status ();
178+ return ;
179+ }
180+ auto scan_task = maybe_task.ValueOrDie ();
181+ if (IsIterationEnd (scan_task)) {
182+ no_more_tasks = true ;
183+ return ;
184+ }
185+ auto state = shared_from_this ();
186+ auto id = next_scan_task_id++;
187+ ResizeBatches (id);
188+
189+ lock.unlock ();
190+ task_group->Append ([state, id, scan_task]() {
191+ ARROW_ASSIGN_OR_RAISE (auto batch_it, scan_task->Execute ());
192+ for (auto maybe_batch : batch_it) {
193+ ARROW_ASSIGN_OR_RAISE (auto batch, maybe_batch);
194+ state->Push (TaggedRecordBatch{std::move (batch), scan_task->fragment ()}, id);
158195 }
196+ return state->Finish (id);
197+ });
198+ }
199+
200+ Result<TaggedRecordBatch> Pop () {
201+ std::unique_lock<std::mutex> lock (mutex);
202+ ready.wait (lock, [this , &lock] {
203+ while (pop_cursor < task_batches.size ()) {
204+ // queue for current scan task contains at least one batch, pop that
205+ if (!task_batches[pop_cursor].empty ()) return true ;
206+ // queue is empty but will be appended to eventually, wait for that
207+ if (!task_drained[pop_cursor]) return false ;
208+
209+ // Finished draining current scan task, enqueue a new one
210+ ++pop_cursor;
211+ // Must unlock since serial task group will execute synchronously
212+ lock.unlock ();
213+ PushScanTask ();
214+ lock.lock ();
215+ }
216+ DCHECK (no_more_tasks);
217+ // all scan tasks drained (or getting next task failed), terminate
218+ return true ;
219+ });
220+
221+ if (pop_cursor == task_batches.size ()) {
222+ // Don't report an error until we yield up everything we can first
223+ RETURN_NOT_OK (iteration_error);
224+ return IterationEnd<TaggedRecordBatch>();
159225 }
160226
161- ScanTaskIterator scan_task_it;
162- RecordBatchIterator batch_it;
163- std::shared_ptr<ScanTask> current_task;
164- };
165- return TaggedRecordBatchIterator (BatchIter (std::move (scan_task_it)));
227+ auto batch = std::move (task_batches[pop_cursor].front ());
228+ task_batches[pop_cursor].pop_front ();
229+ return batch;
230+ }
231+
232+ // / Protecting mutating accesses to batches
233+ std::mutex mutex;
234+ std::condition_variable ready;
235+ ScanTaskIterator scan_tasks;
236+ std::shared_ptr<TaskGroup> task_group;
237+ int next_scan_task_id = 0 ;
238+ bool no_more_tasks = false ;
239+ Status iteration_error;
240+ std::vector<std::deque<TaggedRecordBatch>> task_batches;
241+ std::vector<bool > task_drained;
242+ size_t pop_cursor = 0 ;
243+ };
244+
245+ Result<TaggedRecordBatchIterator> SyncScanner::ScanBatches () {
246+ ARROW_ASSIGN_OR_RAISE (auto scan_task_it, ScanInternal ());
247+ auto task_group = scan_options_->TaskGroup ();
248+ auto state = std::make_shared<ScanBatchesState>(std::move (scan_task_it), task_group);
249+ for (int i = 0 ; i < scan_options_->fragment_readahead ; i++) {
250+ state->PushScanTask ();
251+ }
252+ return MakeFunctionIterator ([task_group, state]() -> Result<TaggedRecordBatch> {
253+ ARROW_ASSIGN_OR_RAISE (auto batch, state->Pop ());
254+ if (!IsIterationEnd (batch)) return batch;
255+ RETURN_NOT_OK (task_group->Finish ());
256+ return IterationEnd<TaggedRecordBatch>();
257+ });
166258}
167259
168260Result<FragmentIterator> SyncScanner::GetFragments () {
@@ -176,7 +268,30 @@ Result<FragmentIterator> SyncScanner::GetFragments() {
176268 return GetFragmentsFromDatasets ({dataset_}, scan_options_->filter );
177269}
178270
179- Result<ScanTaskIterator> SyncScanner::Scan () {
271+ Result<ScanTaskIterator> SyncScanner::Scan () { return ScanInternal (); }
272+
273+ Status SyncScanner::Scan (std::function<Status (TaggedRecordBatch)> visitor) {
274+ ARROW_ASSIGN_OR_RAISE (auto scan_task_it, ScanInternal ());
275+
276+ auto task_group = scan_options_->TaskGroup ();
277+
278+ for (auto maybe_scan_task : scan_task_it) {
279+ ARROW_ASSIGN_OR_RAISE (auto scan_task, maybe_scan_task);
280+ task_group->Append ([scan_task, visitor] {
281+ ARROW_ASSIGN_OR_RAISE (auto batch_it, scan_task->Execute ());
282+ for (auto maybe_batch : batch_it) {
283+ ARROW_ASSIGN_OR_RAISE (auto batch, maybe_batch);
284+ RETURN_NOT_OK (
285+ visitor (TaggedRecordBatch{std::move (batch), scan_task->fragment ()}));
286+ }
287+ return Status::OK ();
288+ });
289+ }
290+
291+ return task_group->Finish ();
292+ }
293+
294+ Result<ScanTaskIterator> SyncScanner::ScanInternal () {
180295 // Transforms Iterator<Fragment> into a unified
181296 // Iterator<ScanTask>. The first Iterator::Next invocation is going to do
182297 // all the work of unwinding the chained iterators.
@@ -315,7 +430,7 @@ Result<std::shared_ptr<Table>> SyncScanner::ToTable() {
315430
316431Future<std::shared_ptr<Table>> SyncScanner::ToTableInternal (
317432 internal::Executor* cpu_executor) {
318- ARROW_ASSIGN_OR_RAISE (auto scan_task_it, Scan ());
433+ ARROW_ASSIGN_OR_RAISE (auto scan_task_it, ScanInternal ());
319434 auto task_group = scan_options_->TaskGroup ();
320435
321436 // / Wraps the state in a shared_ptr to ensure that failing ScanTasks don't
@@ -343,5 +458,94 @@ Future<std::shared_ptr<Table>> SyncScanner::ToTableInternal(
343458 FlattenRecordBatchVector (std::move (state->batches )));
344459}
345460
461+ Result<std::shared_ptr<Table>> Scanner::TakeRows (const Array& indices) {
462+ if (indices.null_count () != 0 ) {
463+ return Status::NotImplemented (" null take indices" );
464+ }
465+
466+ compute::ExecContext ctx (scan_options_->pool );
467+
468+ const Array* original_indices;
469+ // If we have to cast, this is the backing reference
470+ std::shared_ptr<Array> original_indices_ptr;
471+ if (indices.type_id () != Type::INT64 ) {
472+ ARROW_ASSIGN_OR_RAISE (
473+ original_indices_ptr,
474+ compute::Cast (indices, int64 (), compute::CastOptions::Safe (), &ctx));
475+ original_indices = original_indices_ptr.get ();
476+ } else {
477+ original_indices = &indices;
478+ }
479+
480+ std::shared_ptr<Array> unsort_indices;
481+ {
482+ ARROW_ASSIGN_OR_RAISE (
483+ auto sort_indices,
484+ compute::SortIndices (*original_indices, compute::SortOrder::Ascending, &ctx));
485+ ARROW_ASSIGN_OR_RAISE (original_indices_ptr,
486+ compute::Take (*original_indices, *sort_indices,
487+ compute::TakeOptions::Defaults (), &ctx));
488+ original_indices = original_indices_ptr.get ();
489+ ARROW_ASSIGN_OR_RAISE (
490+ unsort_indices,
491+ compute::SortIndices (*sort_indices, compute::SortOrder::Ascending, &ctx));
492+ }
493+
494+ RecordBatchVector out_batches;
495+
496+ auto raw_indices = static_cast <const Int64Array&>(*original_indices).raw_values ();
497+ int64_t offset = 0 , row_begin = 0 ;
498+
499+ ARROW_ASSIGN_OR_RAISE (auto batch_it, ScanBatches ());
500+ while (true ) {
501+ ARROW_ASSIGN_OR_RAISE (auto batch, batch_it.Next ());
502+ if (IsIterationEnd (batch)) break ;
503+ if (offset == original_indices->length ()) break ;
504+ DCHECK_LT (offset, original_indices->length ());
505+
506+ int64_t length = 0 ;
507+ while (offset + length < original_indices->length ()) {
508+ auto rel_index = raw_indices[offset + length] - row_begin;
509+ if (rel_index >= batch.record_batch ->num_rows ()) break ;
510+ ++length;
511+ }
512+ DCHECK_LE (offset + length, original_indices->length ());
513+ if (length == 0 ) {
514+ row_begin += batch.record_batch ->num_rows ();
515+ continue ;
516+ }
517+
518+ Datum rel_indices = original_indices->Slice (offset, length);
519+ ARROW_ASSIGN_OR_RAISE (rel_indices,
520+ compute::Subtract (rel_indices, Datum (row_begin),
521+ compute::ArithmeticOptions (), &ctx));
522+
523+ ARROW_ASSIGN_OR_RAISE (Datum out_batch,
524+ compute::Take (batch.record_batch , rel_indices,
525+ compute::TakeOptions::Defaults (), &ctx));
526+ out_batches.push_back (out_batch.record_batch ());
527+
528+ offset += length;
529+ row_begin += batch.record_batch ->num_rows ();
530+ }
531+
532+ if (offset < original_indices->length ()) {
533+ std::stringstream error;
534+ const int64_t max_values_shown = 3 ;
535+ const int64_t num_remaining = original_indices->length () - offset;
536+ for (int64_t i = 0 ; i < std::min<int64_t >(max_values_shown, num_remaining); i++) {
537+ if (i > 0 ) error << " , " ;
538+ error << static_cast <const Int64Array*>(original_indices)->Value (offset + i);
539+ }
540+ if (num_remaining > max_values_shown) error << " , ..." ;
541+ return Status::IndexError (" Some indices were out of bounds: " , error.str ());
542+ }
543+ ARROW_ASSIGN_OR_RAISE (Datum out, Table::FromRecordBatches (options ()->projected_schema ,
544+ std::move (out_batches)));
545+ ARROW_ASSIGN_OR_RAISE (
546+ out, compute::Take (out, unsort_indices, compute::TakeOptions::Defaults (), &ctx));
547+ return out.table ();
548+ }
549+
346550} // namespace dataset
347551} // namespace arrow
0 commit comments