// union_node.cc (forked from apache/arrow)
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <mutex>
#include "arrow/compute/api.h"
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"
#include "arrow/util/logging.h"
#include "arrow/util/string.h"
#include "arrow/util/thread_pool.h"
#include "arrow/util/tracing_internal.h"
namespace arrow {
using internal::checked_cast;
using internal::ToChars;
namespace compute {
namespace {
/// \brief Build one label per input node.
///
/// The label for the input at position `i` is the string "input_<i>_label".
/// Only the number of inputs is used; the nodes themselves are not inspected.
///
/// \param inputs the input nodes to label
/// \return a vector with exactly `inputs.size()` labels, in input order
std::vector<std::string> GetInputLabels(const ExecNode::NodeVector& inputs) {
  const size_t num_inputs = inputs.size();
  std::vector<std::string> result;
  result.reserve(num_inputs);
  for (size_t index = 0; index < num_inputs; ++index) {
    result.emplace_back("input_" + std::to_string(index) + "_label");
  }
  return result;
}
} // namespace
/// \brief Exec node that funnels the batches of several inputs into one output.
///
/// Each batch received from any input is passed on to `output_` as-is. The
/// node reports itself finished, with the summed batch count of all inputs,
/// once `input_count_` signals completion.
class UnionNode : public ExecNode, public TracedNode<UnionNode> {
 public:
  /// \brief Construct the node.
  ///
  /// The output schema is taken from `inputs[0]`, so `inputs` must not be
  /// empty (`Make` enforces this before constructing).
  UnionNode(ExecPlan* plan, std::vector<ExecNode*> inputs)
      : ExecNode(plan, inputs, GetInputLabels(inputs),
                 /*output_schema=*/inputs[0]->output_schema()) {
    const int num_inputs = static_cast<int>(inputs.size());
    const bool already_complete = input_count_.SetTotal(num_inputs);
    ARROW_DCHECK(!already_complete);
  }

  const char* kind_name() const override { return "UnionNode"; }

  /// \brief Factory registered under the name "union".
  ///
  /// Returns `Status::Invalid` when there are no inputs or when any input's
  /// output schema differs from that of the first input. `options` is unused.
  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
                                const ExecNodeOptions& options) {
    // Any number of inputs is accepted here; the non-empty requirement is
    // checked separately just below.
    const int expected_num_inputs = static_cast<int>(inputs.size());
    RETURN_NOT_OK(
        ValidateExecNodeInputs(plan, inputs, expected_num_inputs, "UnionNode"));
    if (inputs.empty()) {
      return Status::Invalid("Constructing a `UnionNode` with inputs size less than 1");
    }

    const auto first_schema = inputs.front()->output_schema();
    for (ExecNode* candidate : inputs) {
      const auto& candidate_schema = candidate->output_schema();
      if (candidate_schema->Equals(first_schema)) {
        continue;
      }
      return Status::Invalid(
          "UnionNode input schemas must all match, first schema was: ",
          first_schema->ToString(), " got schema: ", candidate_schema->ToString());
    }
    return plan->EmplaceNode<UnionNode>(plan, std::move(inputs));
  }

  /// \brief Forward a batch from one of the inputs to the output.
  Status InputReceived(ExecNode* input, ExecBatch batch) override {
    NoteInputReceived(batch);
    ARROW_DCHECK(IsKnownInput(input));
    return output_->InputReceived(this, std::move(batch));
  }

  /// \brief Record that `input` is done after producing `total_batches` batches.
  ///
  /// The per-input totals are accumulated; when `input_count_.Increment()`
  /// returns true the accumulated total is reported to the output.
  Status InputFinished(ExecNode* input, int total_batches) override {
    ARROW_DCHECK(IsKnownInput(input));
    total_batches_.fetch_add(total_batches);
    if (!input_count_.Increment()) {
      return Status::OK();
    }
    return output_->InputFinished(this, total_batches_.load());
  }

  Status StartProducing() override {
    NoteStartProducing(ToStringExtra());
    return Status::OK();
  }

  /// \brief Relay a pause request to every input.
  void PauseProducing(ExecNode* output, int32_t counter) override {
    for (ExecNode* upstream : inputs_) {
      upstream->PauseProducing(this, counter);
    }
  }

  /// \brief Relay a resume request to every input.
  void ResumeProducing(ExecNode* output, int32_t counter) override {
    for (ExecNode* upstream : inputs_) {
      upstream->ResumeProducing(this, counter);
    }
  }

  /// Nothing to tear down in this node.
  Status StopProducingImpl() override { return Status::OK(); }

 private:
  /// True when `node` is one of this node's registered inputs (debug checks only).
  bool IsKnownInput(ExecNode* node) const {
    return std::find(inputs_.begin(), inputs_.end(), node) != inputs_.end();
  }

  /// Counts finished inputs; its total is set to the number of inputs.
  AtomicCounter input_count_;
  /// Running sum of the batch totals reported by finished inputs.
  std::atomic<int> total_batches_{0};
};
namespace internal {
/// \brief Register `UnionNode::Make` in `registry` under the factory name "union".
///
/// The status returned by the registration is checked with `DCHECK_OK`.
void RegisterUnionNode(ExecFactoryRegistry* registry) {
  const Status registration = registry->AddFactory("union", UnionNode::Make);
  DCHECK_OK(registration);
}
} // namespace internal
} // namespace compute
} // namespace arrow