Skip to content

Commit d190683

Browse files
tensorflower-gardenergunan
authored andcommitted
Deal with parallel concat as a registered graph optimization pass.
Change: 144905301
1 parent a4218e7 commit d190683

2 files changed

Lines changed: 126 additions & 83 deletions

File tree

tensorflow/core/common_runtime/graph_optimizer.cc

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -22,84 +22,6 @@ limitations under the License.
2222
#include "tensorflow/core/graph/optimizer_cse.h"
2323

2424
namespace tensorflow {
25-
namespace {
26-
27-
// Replaces occurrences of parallel_concat with the implementation based on
28-
// unsafe ops. Sets removed_any to true if any parallel_concats were removed;
29-
// leaves it untouched otherwise.
30-
Status RemoveParallelConcat(bool* removed_any, Graph* g) {
31-
gtl::InlinedVector<Node*, 2> matches;
32-
for (Node* n : g->nodes()) {
33-
if (n->type_string() == "ParallelConcat") {
34-
matches.push_back(n);
35-
}
36-
}
37-
for (Node* n : matches) {
38-
AttrSlice n_attrs(n->def());
39-
auto make_node = [n, g, &n_attrs](string op) {
40-
NodeBuilder node_builder(g->NewName(n->name()), op);
41-
node_builder.Device(n->def().device());
42-
string colo;
43-
if (GetNodeAttr(n_attrs, "_class", &colo).ok()) {
44-
node_builder.Attr("_class", colo);
45-
}
46-
return node_builder;
47-
};
48-
DataType dtype;
49-
TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
50-
TensorShapeProto shape;
51-
TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape));
52-
53-
// Add the start node
54-
Node* start;
55-
TF_RETURN_IF_ERROR(make_node("_ParallelConcatStart")
56-
.Attr("shape", shape)
57-
.Attr("dtype", dtype)
58-
.Finalize(g, &start));
59-
60-
// Add all the inplace_updates.
61-
std::vector<Node*> control_nodes;
62-
int64 i = 0;
63-
for (const Edge* input_edge : n->in_edges()) {
64-
if (input_edge->IsControlEdge()) {
65-
g->AddControlEdge(input_edge->src(), start);
66-
continue;
67-
}
68-
69-
Node* update;
70-
TF_RETURN_IF_ERROR(make_node("_ParallelConcatUpdate")
71-
.Attr("loc", i)
72-
.Input(start)
73-
.Input(input_edge->src(), input_edge->src_output())
74-
.Finalize(g, &update));
75-
control_nodes.push_back(update);
76-
77-
++i;
78-
}
79-
80-
// Add the final identity.
81-
NodeBuilder identity_def = make_node("Identity");
82-
identity_def.Input(start, 0);
83-
for (Node* s : control_nodes) {
84-
identity_def.ControlInput(s);
85-
}
86-
Node* identity_node;
87-
TF_RETURN_IF_ERROR(identity_def.Finalize(g, &identity_node));
88-
89-
// Remove the node and redirect edges.
90-
for (auto* e : n->out_edges()) {
91-
if (e->IsControlEdge()) {
92-
g->AddControlEdge(identity_node, e->dst());
93-
} else {
94-
g->AddEdge(identity_node, 0, e->dst(), e->dst_input());
95-
}
96-
}
97-
g->RemoveNode(n);
98-
*removed_any = true;
99-
}
100-
return Status::OK();
101-
}
102-
}
10325

10426
GraphOptimizer::GraphOptimizer(const OptimizerOptions& opts) : opts_(opts) {
10527
if (opts_.opt_level() >= OptimizerOptions::L1) {
@@ -123,11 +45,6 @@ void GraphOptimizer::Optimize(FunctionLibraryRuntime* runtime, Env* env,
12345
DumpGraph("RemoveListArrayConverter", g);
12446
changed = true;
12547
}
126-
auto s = RemoveParallelConcat(&changed, g);
127-
if (!s.ok()) {
128-
// TODO(apassos): figure out how to halt here.
129-
LOG(WARNING) << s;
130-
}
13148
if (opts_.do_function_inlining() && RemoveDeadNodes(g)) {
13249
DumpGraph("RemoveDeadNodes", g);
13350
changed = true;
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow/core/common_runtime/graph_optimizer.h"
17+
18+
#include "tensorflow/core/common_runtime/constant_folding.h"
19+
#include "tensorflow/core/common_runtime/function.h"
20+
#include "tensorflow/core/common_runtime/optimization_registry.h"
21+
#include "tensorflow/core/graph/algorithm.h"
22+
#include "tensorflow/core/graph/node_builder.h"
23+
#include "tensorflow/core/graph/optimizer_cse.h"
24+
25+
namespace tensorflow {
26+
namespace {
27+
28+
// Replaces occurrences of parallel_concat with the implementation based on
29+
// unsafe ops. Sets removed_any to true if any parallel_concats were removed;
30+
// leaves it untouched otherwise.
31+
class ParallelConcatRemovePass : public GraphOptimizationPass {
32+
public:
33+
Status Run(const GraphOptimizationPassOptions& options) override {
34+
if (options.graph == nullptr) {
35+
// TODO(apassos) returning OK feels weird here as we can't do anything
36+
// without a graph, but some tests require this.
37+
return Status::OK();
38+
}
39+
Graph* g = options.graph->get();
40+
if (g == nullptr) {
41+
return errors::Internal(
42+
"Parallel concat removal should happen before partitioning and a "
43+
"graph should be available.");
44+
}
45+
gtl::InlinedVector<Node*, 2> matches;
46+
for (Node* n : g->nodes()) {
47+
if (n->type_string() == "ParallelConcat") {
48+
matches.push_back(n);
49+
}
50+
}
51+
for (Node* n : matches) {
52+
AttrSlice n_attrs(n->def());
53+
auto base_make_node = [n, g, &n_attrs](const string& op,
54+
const string& name) {
55+
NodeBuilder node_builder(name, op);
56+
node_builder.Device(n->def().device());
57+
string colo;
58+
if (GetNodeAttr(n_attrs, "_class", &colo).ok()) {
59+
node_builder.Attr("_class", colo);
60+
}
61+
return node_builder;
62+
};
63+
auto make_node = [n, g, &n_attrs, &base_make_node](string op) {
64+
return base_make_node(
65+
op, g->NewName(strings::StrCat(n->name(), "/Internal")));
66+
};
67+
DataType dtype;
68+
TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "T", &dtype));
69+
TensorShapeProto shape;
70+
TF_RETURN_IF_ERROR(GetNodeAttr(n_attrs, "shape", &shape));
71+
72+
// Add the start node
73+
Node* start;
74+
TF_RETURN_IF_ERROR(make_node("_ParallelConcatStart")
75+
.Attr("shape", shape)
76+
.Attr("dtype", dtype)
77+
.Finalize(g, &start));
78+
79+
// Add all the inplace_updates.
80+
std::vector<Node*> control_nodes;
81+
int64 i = 0;
82+
for (const Edge* input_edge : n->in_edges()) {
83+
if (input_edge->IsControlEdge()) {
84+
g->AddControlEdge(input_edge->src(), start);
85+
continue;
86+
}
87+
88+
Node* update;
89+
TF_RETURN_IF_ERROR(
90+
make_node("_ParallelConcatUpdate")
91+
.Attr("loc", i)
92+
.Input(start)
93+
.Input(input_edge->src(), input_edge->src_output())
94+
.Finalize(g, &update));
95+
control_nodes.push_back(update);
96+
97+
++i;
98+
}
99+
100+
// Add the final identity.
101+
NodeBuilder identity_def = base_make_node("Identity", n->name());
102+
identity_def.Input(start, 0);
103+
for (Node* s : control_nodes) {
104+
identity_def.ControlInput(s);
105+
}
106+
Node* identity_node;
107+
TF_RETURN_IF_ERROR(identity_def.Finalize(g, &identity_node));
108+
109+
// Remove the node and redirect edges.
110+
for (auto* e : n->out_edges()) {
111+
if (e->IsControlEdge()) {
112+
g->AddControlEdge(identity_node, e->dst());
113+
} else {
114+
g->AddEdge(identity_node, 0, e->dst(), e->dst_input());
115+
}
116+
}
117+
g->RemoveNode(n);
118+
}
119+
return Status::OK();
120+
}
121+
};
122+
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
123+
ParallelConcatRemovePass);
124+
125+
} // namespace
126+
} // namespace tensorflow

0 commit comments

Comments
 (0)