Skip to content

Commit 79ae4f6

Browse files
committed
ARROW-8732: [C++] Add basic cancellation API
In this model, a StopSource is instantiated by the consumer, which passes a corresponding StopToken to the producer API(s).

Closes apache#9528 from pitrou/ARROW-8732-cancel-v2

Authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 7301f68 commit 79ae4f6

30 files changed

Lines changed: 1491 additions & 134 deletions

cpp/src/arrow/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ set(ARROW_SRCS
181181
util/bitmap_builders.cc
182182
util/bitmap_ops.cc
183183
util/bpacking.cc
184+
util/cancel.cc
184185
util/compression.cc
185186
util/cpu_info.cc
186187
util/decimal.cc

cpp/src/arrow/csv/reader.cc

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,13 @@ class ReaderMixin {
313313
public:
314314
ReaderMixin(MemoryPool* pool, std::shared_ptr<io::InputStream> input,
315315
const ReadOptions& read_options, const ParseOptions& parse_options,
316-
const ConvertOptions& convert_options)
316+
const ConvertOptions& convert_options, StopToken stop_token)
317317
: pool_(pool),
318318
read_options_(read_options),
319319
parse_options_(parse_options),
320320
convert_options_(convert_options),
321-
input_(std::move(input)) {}
321+
input_(std::move(input)),
322+
stop_token_(std::move(stop_token)) {}
322323

323324
protected:
324325
// Read header and column names from buffer, create column builders
@@ -500,6 +501,7 @@ class ReaderMixin {
500501

501502
std::shared_ptr<io::InputStream> input_;
502503
std::shared_ptr<internal::TaskGroup> task_group_;
504+
StopToken stop_token_;
503505
};
504506

505507
/////////////////////////////////////////////////////////////////////////
@@ -697,7 +699,7 @@ class SerialStreamingReader : public BaseStreamingReader {
697699
ARROW_ASSIGN_OR_RAISE(auto rh_it,
698700
MakeReadaheadIterator(std::move(istream_it), block_queue_size));
699701
buffer_iterator_ = CSVBufferIterator::Make(std::move(rh_it));
700-
task_group_ = internal::TaskGroup::MakeSerial();
702+
task_group_ = internal::TaskGroup::MakeSerial(stop_token_);
701703

702704
// Read schema from first batch
703705
ARROW_ASSIGN_OR_RAISE(pending_batch_, ReadNext());
@@ -710,6 +712,10 @@ class SerialStreamingReader : public BaseStreamingReader {
710712
if (eof_) {
711713
return nullptr;
712714
}
715+
if (stop_token_.IsStopRequested()) {
716+
eof_ = true;
717+
return stop_token_.Poll();
718+
}
713719
if (!block_iterator_) {
714720
Status st = SetupReader();
715721
if (!st.ok()) {
@@ -790,7 +796,7 @@ class SerialTableReader : public BaseTableReader {
790796
}
791797

792798
Result<std::shared_ptr<Table>> Read() override {
793-
task_group_ = internal::TaskGroup::MakeSerial();
799+
task_group_ = internal::TaskGroup::MakeSerial(stop_token_);
794800

795801
// First block
796802
ARROW_ASSIGN_OR_RAISE(auto first_buffer, buffer_iterator_.Next());
@@ -804,6 +810,8 @@ class SerialTableReader : public BaseTableReader {
804810
MakeChunker(parse_options_),
805811
std::move(first_buffer));
806812
while (true) {
813+
RETURN_NOT_OK(stop_token_.Poll());
814+
807815
ARROW_ASSIGN_OR_RAISE(auto maybe_block, block_iterator.Next());
808816
if (maybe_block == IterationTraits<CSVBlock>::End()) {
809817
// EOF
@@ -833,9 +841,10 @@ class AsyncThreadedTableReader
833841
AsyncThreadedTableReader(MemoryPool* pool, std::shared_ptr<io::InputStream> input,
834842
const ReadOptions& read_options,
835843
const ParseOptions& parse_options,
836-
const ConvertOptions& convert_options, Executor* cpu_executor,
837-
Executor* io_executor)
838-
: BaseTableReader(pool, input, read_options, parse_options, convert_options),
844+
const ConvertOptions& convert_options, StopToken stop_token,
845+
Executor* cpu_executor, Executor* io_executor)
846+
: BaseTableReader(pool, input, read_options, parse_options, convert_options,
847+
std::move(stop_token)),
839848
cpu_executor_(cpu_executor),
840849
io_executor_(io_executor) {}
841850

@@ -870,7 +879,7 @@ class AsyncThreadedTableReader
870879
Result<std::shared_ptr<Table>> Read() override { return ReadAsync().result(); }
871880

872881
Future<std::shared_ptr<Table>> ReadAsync() override {
873-
task_group_ = internal::TaskGroup::MakeThreaded(cpu_executor_);
882+
task_group_ = internal::TaskGroup::MakeThreaded(cpu_executor_, stop_token_);
874883

875884
auto self = shared_from_this();
876885
return ProcessFirstBuffer().Then([self](std::shared_ptr<Buffer> first_buffer) {
@@ -939,17 +948,30 @@ Result<std::shared_ptr<TableReader>> MakeTableReader(
939948
if (read_options.use_threads) {
940949
auto cpu_executor = internal::GetCpuThreadPool();
941950
auto io_executor = io_context.executor();
942-
reader = std::make_shared<AsyncThreadedTableReader>(pool, input, read_options,
943-
parse_options, convert_options,
944-
cpu_executor, io_executor);
951+
reader = std::make_shared<AsyncThreadedTableReader>(
952+
pool, input, read_options, parse_options, convert_options,
953+
io_context.stop_token(), cpu_executor, io_executor);
945954
} else {
946-
reader = std::make_shared<SerialTableReader>(pool, input, read_options, parse_options,
947-
convert_options);
955+
reader =
956+
std::make_shared<SerialTableReader>(pool, input, read_options, parse_options,
957+
convert_options, io_context.stop_token());
948958
}
949959
RETURN_NOT_OK(reader->Init());
950960
return reader;
951961
}
952962

963+
Result<std::shared_ptr<StreamingReader>> MakeStreamingReader(
964+
io::IOContext io_context, std::shared_ptr<io::InputStream> input,
965+
const ReadOptions& read_options, const ParseOptions& parse_options,
966+
const ConvertOptions& convert_options) {
967+
std::shared_ptr<BaseStreamingReader> reader;
968+
reader = std::make_shared<SerialStreamingReader>(io_context.pool(), input, read_options,
969+
parse_options, convert_options,
970+
io_context.stop_token());
971+
RETURN_NOT_OK(reader->Init());
972+
return reader;
973+
}
974+
953975
} // namespace
954976

955977
/////////////////////////////////////////////////////////////////////////
@@ -975,13 +997,17 @@ Result<std::shared_ptr<StreamingReader>> StreamingReader::Make(
975997
MemoryPool* pool, std::shared_ptr<io::InputStream> input,
976998
const ReadOptions& read_options, const ParseOptions& parse_options,
977999
const ConvertOptions& convert_options) {
978-
std::shared_ptr<BaseStreamingReader> reader;
979-
reader = std::make_shared<SerialStreamingReader>(pool, input, read_options,
980-
parse_options, convert_options);
981-
RETURN_NOT_OK(reader->Init());
982-
return reader;
1000+
return MakeStreamingReader(io::IOContext(pool), std::move(input), read_options,
1001+
parse_options, convert_options);
9831002
}
9841003

985-
} // namespace csv
1004+
Result<std::shared_ptr<StreamingReader>> StreamingReader::Make(
1005+
io::IOContext io_context, std::shared_ptr<io::InputStream> input,
1006+
const ReadOptions& read_options, const ParseOptions& parse_options,
1007+
const ConvertOptions& convert_options) {
1008+
return MakeStreamingReader(io_context, std::move(input), read_options, parse_options,
1009+
convert_options);
1010+
}
9861011

1012+
} // namespace csv
9871013
} // namespace arrow

cpp/src/arrow/csv/reader.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class ARROW_EXPORT TableReader {
5252
const ParseOptions&,
5353
const ConvertOptions&);
5454

55-
ARROW_DEPRECATED("Use MemoryPool-less overload (the IOContext holds a pool already)")
55+
ARROW_DEPRECATED("Use MemoryPool-less variant (the IOContext holds a pool already)")
5656
static Result<std::shared_ptr<TableReader>> Make(
5757
MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input,
5858
const ReadOptions&, const ParseOptions&, const ConvertOptions&);
@@ -67,6 +67,11 @@ class ARROW_EXPORT StreamingReader : public RecordBatchReader {
6767
///
6868
/// Currently, the StreamingReader is always single-threaded (parallel
6969
/// readahead is not supported).
70+
static Result<std::shared_ptr<StreamingReader>> Make(
71+
io::IOContext io_context, std::shared_ptr<io::InputStream> input,
72+
const ReadOptions&, const ParseOptions&, const ConvertOptions&);
73+
74+
ARROW_DEPRECATED("Use IOContext-based overload")
7075
static Result<std::shared_ptr<StreamingReader>> Make(
7176
MemoryPool* pool, std::shared_ptr<io::InputStream> input, const ReadOptions&,
7277
const ParseOptions&, const ConvertOptions&);

cpp/src/arrow/dataset/file_csv.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ static inline Result<std::shared_ptr<csv::StreamingReader>> OpenReader(
119119
GetConvertOptions(format, scan_options, *first_block, pool));
120120
}
121121

122-
auto maybe_reader = csv::StreamingReader::Make(pool, std::move(input), reader_options,
123-
parse_options, convert_options);
122+
auto maybe_reader =
123+
csv::StreamingReader::Make(io::IOContext(pool), std::move(input), reader_options,
124+
parse_options, convert_options);
124125
if (!maybe_reader.ok()) {
125126
return maybe_reader.status().WithMessage("Could not open CSV input source '",
126127
source.path(), "': ", maybe_reader.status());

cpp/src/arrow/io/interfaces.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ namespace io {
4848

4949
static IOContext g_default_io_context{};
5050

51-
IOContext::IOContext(MemoryPool* pool) : IOContext(pool, internal::GetIOThreadPool()) {}
51+
IOContext::IOContext(MemoryPool* pool, StopToken stop_token)
52+
: IOContext(pool, internal::GetIOThreadPool(), std::move(stop_token)) {}
5253

5354
const IOContext& default_io_context() { return g_default_io_context; }
5455

cpp/src/arrow/io/interfaces.h

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
#include "arrow/io/type_fwd.h"
2626
#include "arrow/type_fwd.h"
27+
#include "arrow/util/cancel.h"
2728
#include "arrow/util/macros.h"
2829
#include "arrow/util/string_view.h"
2930
#include "arrow/util/type_fwd.h"
@@ -56,17 +57,28 @@ struct ReadRange {
5657
/// multiple sources and must distinguish tasks associated with this IOContext).
5758
struct ARROW_EXPORT IOContext {
5859
// No specified executor: will use a global IO thread pool
59-
IOContext() : IOContext(default_memory_pool()) {}
60+
IOContext() : IOContext(default_memory_pool(), StopToken::Unstoppable()) {}
6061

61-
// No specified executor: will use a global IO thread pool
62-
explicit IOContext(MemoryPool* pool);
62+
explicit IOContext(StopToken stop_token)
63+
: IOContext(default_memory_pool(), std::move(stop_token)) {}
64+
65+
explicit IOContext(MemoryPool* pool, StopToken stop_token = StopToken::Unstoppable());
6366

6467
explicit IOContext(MemoryPool* pool, ::arrow::internal::Executor* executor,
68+
StopToken stop_token = StopToken::Unstoppable(),
6569
int64_t external_id = -1)
66-
: pool_(pool), executor_(executor), external_id_(external_id) {}
70+
: pool_(pool),
71+
executor_(executor),
72+
external_id_(external_id),
73+
stop_token_(std::move(stop_token)) {}
6774

68-
explicit IOContext(::arrow::internal::Executor* executor, int64_t external_id = -1)
69-
: pool_(default_memory_pool()), executor_(executor), external_id_(external_id) {}
75+
explicit IOContext(::arrow::internal::Executor* executor,
76+
StopToken stop_token = StopToken::Unstoppable(),
77+
int64_t external_id = -1)
78+
: pool_(default_memory_pool()),
79+
executor_(executor),
80+
external_id_(external_id),
81+
stop_token_(std::move(stop_token)) {}
7082

7183
MemoryPool* pool() const { return pool_; }
7284

@@ -75,10 +87,13 @@ struct ARROW_EXPORT IOContext {
7587
// An application-specific ID, forwarded to executor task submissions
7688
int64_t external_id() const { return external_id_; }
7789

90+
StopToken stop_token() const { return stop_token_; }
91+
7892
private:
7993
MemoryPool* pool_;
8094
::arrow::internal::Executor* executor_;
8195
int64_t external_id_;
96+
StopToken stop_token_;
8297
};
8398

8499
struct ARROW_DEPRECATED("renamed to IOContext in 4.0.0") AsyncContext : public IOContext {

cpp/src/arrow/status.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ std::string Status::CodeAsString(StatusCode code) {
6868
case StatusCode::Invalid:
6969
type = "Invalid";
7070
break;
71+
case StatusCode::Cancelled:
72+
type = "Cancelled";
73+
break;
7174
case StatusCode::IOError:
7275
type = "IOError";
7376
break;

cpp/src/arrow/status.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ enum class StatusCode : char {
8383
IOError = 5,
8484
CapacityError = 6,
8585
IndexError = 7,
86+
Cancelled = 8,
8687
UnknownError = 9,
8788
NotImplemented = 10,
8889
SerializationError = 11,
@@ -204,6 +205,12 @@ class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : public util::EqualityComparable<
204205
return Status::FromArgs(StatusCode::Invalid, std::forward<Args>(args)...);
205206
}
206207

208+
/// Return an error status for a cancelled operation
209+
template <typename... Args>
210+
static Status Cancelled(Args&&... args) {
211+
return Status::FromArgs(StatusCode::Cancelled, std::forward<Args>(args)...);
212+
}
213+
207214
/// Return an error status when an index is out of bounds
208215
template <typename... Args>
209216
static Status IndexError(Args&&... args) {
@@ -263,6 +270,8 @@ class ARROW_MUST_USE_TYPE ARROW_EXPORT Status : public util::EqualityComparable<
263270
bool IsKeyError() const { return code() == StatusCode::KeyError; }
264271
/// Return true iff the status indicates invalid data.
265272
bool IsInvalid() const { return code() == StatusCode::Invalid; }
273+
/// Return true iff the status indicates a cancelled operation.
274+
bool IsCancelled() const { return code() == StatusCode::Cancelled; }
266275
/// Return true iff the status indicates an IO-related failure.
267276
bool IsIOError() const { return code() == StatusCode::IOError; }
268277
/// Return true iff the status indicates a container reaching capacity limits.

cpp/src/arrow/testing/gtest_util.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,24 @@ EnvVarGuard::~EnvVarGuard() {
539539
}
540540
}
541541

542+
struct SignalHandlerGuard::Impl {
543+
int signum_;
544+
internal::SignalHandler old_handler_;
545+
546+
Impl(int signum, const internal::SignalHandler& handler)
547+
: signum_(signum), old_handler_(*internal::SetSignalHandler(signum, handler)) {}
548+
549+
~Impl() { ARROW_EXPECT_OK(internal::SetSignalHandler(signum_, old_handler_)); }
550+
};
551+
552+
SignalHandlerGuard::SignalHandlerGuard(int signum, Callback cb)
553+
: SignalHandlerGuard(signum, internal::SignalHandler(cb)) {}
554+
555+
SignalHandlerGuard::SignalHandlerGuard(int signum, const internal::SignalHandler& handler)
556+
: impl_(new Impl{signum, handler}) {}
557+
558+
SignalHandlerGuard::~SignalHandlerGuard() = default;
559+
542560
namespace {
543561

544562
// Used to prevent compiler optimizing away side-effect-less statements
@@ -576,6 +594,13 @@ void SleepFor(double seconds) {
576594
std::chrono::nanoseconds(static_cast<int64_t>(seconds * 1e9)));
577595
}
578596

597+
void BusyWait(double seconds, std::function<bool()> predicate) {
598+
const double period = 0.001;
599+
for (int i = 0; !predicate() && i * period < seconds; ++i) {
600+
SleepFor(period);
601+
}
602+
}
603+
579604
///////////////////////////////////////////////////////////////////////////
580605
// Extension types
581606

cpp/src/arrow/testing/gtest_util.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include <cstdint>
2222
#include <cstdlib>
2323
#include <cstring>
24+
#include <functional>
2425
#include <memory>
2526
#include <string>
2627
#include <type_traits>
@@ -472,6 +473,10 @@ inline void BitmapFromVector(const std::vector<T>& is_valid,
472473
ARROW_TESTING_EXPORT
473474
void SleepFor(double seconds);
474475

476+
// Wait until predicate is true or timeout in seconds expires.
477+
ARROW_TESTING_EXPORT
478+
void BusyWait(double seconds, std::function<bool()> predicate);
479+
475480
template <typename T>
476481
std::vector<T> IteratorToVector(Iterator<T> iterator) {
477482
EXPECT_OK_AND_ASSIGN(auto out, iterator.ToVector());
@@ -504,6 +509,23 @@ class ARROW_TESTING_EXPORT EnvVarGuard {
504509
bool was_set_;
505510
};
506511

512+
namespace internal {
513+
class SignalHandler;
514+
}
515+
516+
class ARROW_TESTING_EXPORT SignalHandlerGuard {
517+
public:
518+
typedef void (*Callback)(int);
519+
520+
SignalHandlerGuard(int signum, Callback cb);
521+
SignalHandlerGuard(int signum, const internal::SignalHandler& handler);
522+
~SignalHandlerGuard();
523+
524+
protected:
525+
struct Impl;
526+
std::unique_ptr<Impl> impl_;
527+
};
528+
507529
#ifndef ARROW_LARGE_MEMORY_TESTS
508530
#define LARGE_MEMORY_TEST(name) DISABLED_##name
509531
#else

0 commit comments

Comments
 (0)