Skip to content

Commit 73cea49

Browse files
author
eellison
committed
make pass more composable
1 parent 2244713 commit 73cea49

File tree

9 files changed

+178
-141
lines changed

9 files changed

+178
-141
lines changed

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
429429
${TORCH_SRC_DIR}/csrc/jit/testing/file_check.cpp
430430
${TORCH_SRC_DIR}/csrc/jit/script/final_returns.cpp
431431
${TORCH_SRC_DIR}/csrc/jit/script/convert_to_ssa.cpp
432+
${TORCH_SRC_DIR}/csrc/jit/script/inline_loop_condition.cpp
432433
${TORCH_SRC_DIR}/csrc/jit/script/schema_matching.cpp
433434
${TORCH_SRC_DIR}/csrc/jit/script/script_type_parser.cpp
434435
${TORCH_SRC_DIR}/csrc/jit/script/sugared_value.cpp

tools/build_variables.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@
116116
"torch/csrc/jit/script/logging.cpp",
117117
"torch/csrc/jit/script/final_returns.cpp",
118118
"torch/csrc/jit/script/convert_to_ssa.cpp",
119+
"torch/csrc/jit/script/inline_loop_condition.cpp",
119120
"torch/csrc/jit/script/canonicalize_modified_loop.cpp",
120121
"torch/csrc/jit/script/script_type_parser.cpp",
121122
"torch/csrc/jit/script/sugared_value.cpp",

torch/csrc/jit/ir_views.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,6 @@ struct LoopView {
5050
Block* bodyBlock() const {
5151
return node_->blocks().at(0);
5252
}
53-
Value* cond() const {
54-
return node_->input(0);
55-
}
5653
Value* maxTripCount() const {
5754
return node_->input(0);
5855
}
@@ -100,10 +97,10 @@ struct LoopView {
10097
}
10198

10299
void replaceMaxTripCount(Value* new_max_trip_count) {
103-
replaceInput(0, new_max_trip_count);
100+
node_->replaceInput(0, new_max_trip_count);
104101
}
105102
void replaceInputCondition(Value* new_input_condition) {
106-
replaceInput(1, new_input_condition);
103+
node_->replaceInput(1, new_input_condition);
107104
}
108105

109106
// our way of encoding loops makes them difficult to turn back into python
@@ -140,11 +137,6 @@ struct LoopView {
140137
private:
141138
Node* node_;
142139

143-
void replaceInput(size_t index, Value* new_input) {
144-
node_->removeInput(index);
145-
node_->insertInput(index, new_input);
146-
}
147-
148140
// adjust index_ordering by adding indices 0 - thorugh adjust, and
149141
// incrementing all existing inputs by adjust
150142
static std::vector<size_t> adjustIndices(

torch/csrc/jit/passes/break_continue_transform.cpp

Lines changed: 81 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ namespace script {
1414
void moveBlockBeforeNode(Node* before_node, Block* block);
1515

1616
/**
17-
* This pass transforms the graph_ so that break & continue statements are
18-
* removed. We transform the graph_ so that ops following a break or continue
19-
* are not run.
17+
* This pass transforms the graph so that break & continue statements are
18+
* removed. We transform the graph so that ops following a break or continue are
19+
* not run.
2020
*/
2121

2222
// Will a block or node continue or break
@@ -32,21 +32,33 @@ struct LoopTransformer {
3232
true_val_ = graph_->insertConstant(true);
3333
false_val_ = graph_->insertConstant(false);
3434
transform_ = transform;
35+
incrementCurString();
3536
};
3637

37-
const std::string& getVarname() {
38+
const std::string getVarname() {
39+
return cur_string;
40+
}
41+
42+
void incrementCurString() {
3843
static const std::string& break_name = "$did_break";
3944
static const std::string& continue_name = "$did_continue";
40-
return transform_ == BREAKS ? break_name : continue_name;
45+
const auto& name = transform_ == BREAKS ? break_name : continue_name;
46+
loop_count++;
47+
cur_string = name + std::to_string(loop_count);
48+
}
49+
50+
void setCurString(const std::string& new_string) {
51+
cur_string = new_string;
4152
}
4253

4354
Symbol transformKind() {
4455
return transform_ == BREAKS ? prim::BreakStmt : prim::ContinueStmt;
4556
}
4657

47-
// Recurses on the if node and returns its return status
48-
// If status != WONT, sets the block_return_val and sentinel val
49-
// of its parent block before exit
58+
// Recursively transform both blocks of the if node.
59+
// If both blocks have hit the transform variable, then the return status,
60+
// is WILL, if both will not hit the transform variable it is false.
61+
// Otherwise we may have hit it.
5062
LoopStatus handleIf(Node* node) {
5163
auto true_block = node->blocks().at(0);
5264
auto false_block = node->blocks().at(1);
@@ -63,23 +75,24 @@ struct LoopTransformer {
6375
}
6476
}
6577

78+
// if an if node might hit a break or continue statement,
79+
// we guard all subsequent nodes in the block, and only execute them
80+
// if the transform is false.
81+
// The LoopStatus is the result of recursing on the newly created if.
6682
LoopStatus guardBlockNodes(
6783
Block* block,
68-
generic_graph_node_list_iterator<Node>& iter) {
69-
// if an if node might hit a break or continue statement,
70-
// we guard all subsequent nodes in the block, and only execute them
71-
// if we did break / did continue is false.
72-
73-
auto new_if = graph_->create(prim::If, 0)->insertBefore(*iter);
84+
graph_node_list::iterator& remaining_block_nodes) {
85+
auto new_if =
86+
graph_->create(prim::If, 0)->insertBefore(*remaining_block_nodes);
7487
auto sentinel =
7588
graph_->createLoad(getVarname(), BoolType::get())->insertBefore(new_if);
7689
new_if->addInput(sentinel->output());
7790

7891
auto hit_control_flow_block = new_if->addBlock();
7992
auto guard_block = new_if->addBlock();
8093

81-
while (iter != block->nodes().end()) {
82-
auto node = *iter++;
94+
while (remaining_block_nodes != block->nodes().end()) {
95+
auto node = *remaining_block_nodes++;
8396
node->moveBefore(guard_block->return_node());
8497
}
8598

@@ -89,11 +102,14 @@ struct LoopTransformer {
89102
// In a graph like:
90103
// for i in range(3):
91104
// if cond == 2:
105+
// k : Optional[int] = None
92106
// if cond == 2:
93107
// m = 2
94108
// break
95109
// k = 1
110+
// j = 2
96111
// else:
112+
// j = 1
97113
// k = 2
98114
// m += k
99115
// We transform the inner cond == 2 block to look like:
@@ -130,65 +146,59 @@ struct LoopTransformer {
130146
iter->destroy();
131147
}
132148

133-
void inlineLoopConditionIntoLoopBody(Node* n) {
134-
auto body_block = n->blocks().at(0);
135-
auto pre_header = n->blocks().at(1);
136-
moveBlockBeforeNode(body_block->return_node(), pre_header);
137-
body_block->insertOutput(0, pre_header->outputs().at(0));
138-
n->eraseBlock(1);
139-
}
140-
141149
void handleLoop(Node* loop_node) {
142-
// transform the loop, then ensure that that it does not accidentally
143-
// pick up or assign the current transform variable outside of the loop.
150+
const std::string prev_string = getVarname();
151+
// Give current loop unique identifier
152+
incrementCurString();
153+
// transform the loop
144154
transformLoop(loop_node);
145-
Block* body_block = loop_node->blocks().at(0);
146-
graph_->createStore(getVarname(), false_val_)
147-
->insertAfter(body_block->param_node());
148-
graph_->createStore(getVarname(), false_val_)
149-
->insertBefore(body_block->return_node());
150-
}
151155

152-
void transformLoop(Node* n) {
153-
Block* body_block = n->blocks().at(0);
154-
auto ret_status = handleTransforms(body_block);
155-
156-
// When we're transforming breaks:
157-
// the body condition has not yet been inlined. If we we are not breaking
158-
// we need to inline the condition block into the end of the loop.
159-
// if we might break, we create an if statement and only execute the loop
160-
// header if we did not break.
161-
// Since we run the continue pass before the break pass,
162-
// we do not need to do any additional work in continues; guardBlock nodes
163-
// ensures that we do not execute any ops present in the block after a
164-
// continue, and loop condition is inlined after.
165-
166-
if (transform_ == CONTINUES) {
167-
return;
168-
}
169-
170-
if (ret_status == WONT) {
171-
inlineLoopConditionIntoLoopBody(n);
172-
return;
173-
}
156+
// restore previous identifier
157+
setCurString(prev_string);
158+
}
174159

175-
WithInsertPoint insert(body_block);
160+
// Create a check for the current transform variable.
161+
// if transform is true, loop continue condition is false, otherwise
162+
// run original condition
163+
void guardConditionBlock(Block* condition_block) {
164+
WithInsertPoint insert(*condition_block->nodes().begin());
176165
auto did_break =
177166
graph_->insertNode(graph_->createLoad(getVarname(), BoolType::get()))
178167
->output();
179-
180168
auto new_loop_condition = graph_->insertNode(graph_->create(prim::If));
181169
new_loop_condition->addInput(did_break);
182170
new_loop_condition->output()->setType(BoolType::get());
183-
184-
// if we did break, we do not continue
185171
new_loop_condition->addBlock()->registerOutput(false_val_);
186172
auto original_condition = new_loop_condition->addBlock();
187-
auto pre_header = n->blocks().at(1);
188-
moveBlockBeforeNode(original_condition->return_node(), pre_header);
189-
original_condition->insertOutput(0, pre_header->outputs().at(0));
190-
n->eraseBlock(1);
191-
body_block->registerOutput(new_loop_condition->output());
173+
174+
Node* n = new_loop_condition;
175+
for (n = n->next(); n != condition_block->return_node();) {
176+
auto cur = n;
177+
n = n->next();
178+
cur->moveBefore(original_condition->return_node());
179+
}
180+
original_condition->insertOutput(0, condition_block->outputs().at(0));
181+
condition_block->eraseOutput(0);
182+
condition_block->registerOutput(new_loop_condition->output());
183+
}
184+
185+
void transformLoop(Node* n) {
186+
Block* body_block = n->blocks().at(0);
187+
auto ret_status = handleTransforms(body_block);
188+
189+
// loop header should run even if we have continued
190+
if (transform_ == CONTINUES || ret_status == WONT) {
191+
return;
192+
}
193+
194+
// because the condition block will get inlined as the start loop condition,
195+
// we need to make sure that it is defined before the loop executes
196+
// (and false so original condition is run). Also insert it into the block
197+
// so it is not an unneccessary loop carried var.
198+
graph_->createStore(getVarname(), false_val_)->insertBefore(n);
199+
graph_->createStore(getVarname(), false_val_)
200+
->insertAfter(body_block->param_node());
201+
guardConditionBlock(n->blocks().at(1));
192202
};
193203

194204
LoopStatus handleTransforms(Block* block) {
@@ -205,7 +215,6 @@ struct LoopTransformer {
205215
if (node->kind() != transformKind()) {
206216
continue;
207217
}
208-
WithInsertPoint b(block);
209218
node->destroy();
210219
loop_status = WILL;
211220
} break;
@@ -244,81 +253,26 @@ struct LoopTransformer {
244253
handleTransforms(graph_->block());
245254
}
246255

256+
size_t loop_count = 0;
247257
Transform transform_;
248258
Value* true_val_ = nullptr;
249259
Value* false_val_ = nullptr;
260+
std::string cur_string = "";
250261

251262
std::shared_ptr<Graph> graph_;
252263
};
253264

254-
void moveBlockBeforeNode(Node* before_node, Block* block) {
255-
for (auto it = block->nodes().begin(); it != block->nodes().end();) {
256-
auto block_node = *it++;
257-
block_node->moveBefore(before_node);
258-
}
259-
}
260-
261-
// The loop node is initially emitted as:
262-
// Loop(max_trip_count)
263-
// block0(loop_counter) {
264-
// <body>
265-
// }
266-
// block1 {
267-
// <loop condition>
268-
// -> (condition)
269-
// }
270-
// Here, we inline the loop condition into:
271-
// Loop(max_trip_count, start_condition)
272-
// block0(loop_counter) {
273-
// <body>
274-
// }
275-
// block1 {
276-
// <loop condition>
277-
// -> (condition)
278-
// }
279-
280-
void inlineLoopStartCondition(Node* n) {
281-
auto pre_header = n->blocks().at(1);
282-
auto header_block = n->addBlock();
283-
header_block->cloneFrom(pre_header, [](Value* v) { return v; });
284-
moveBlockBeforeNode(n, header_block);
285-
n->addInput(header_block->outputs().at(0));
286-
n->eraseBlock(2);
287-
}
265+
// These passes are run before SSA, so they need to handle before the
266+
// Loop body and loop condition as a separate block.
288267

289-
void inlineLoopStartCondition(Block* block) {
290-
for (Node* n : block->nodes()) {
291-
switch (n->kind()) {
292-
case prim::If:
293-
case prim::Function: {
294-
for (auto b : n->blocks()) {
295-
inlineLoopStartCondition(b);
296-
}
297-
} break;
298-
case prim::Loop: {
299-
inlineLoopStartCondition(n->blocks().at(0));
300-
inlineLoopStartCondition(n);
301-
} break;
302-
}
303-
}
268+
void TransformBreaks(std::shared_ptr<Graph>& graph) {
269+
LoopTransformer breaks(graph, BREAKS);
270+
breaks.run();
304271
}
305272

306-
// First we inline the loop input condition.
307-
// Then, we transform the continues. Because the loop body condition
308-
// has not yet been inlined, we can safely ignore it in the continue pass.
309-
// Then, we transform breaks, inlining the loop body condition as part of the
310-
// pass. Because they have not been inlined yet, we can generated nice graph_s
311-
// of the form
312-
// if did_break
313-
// ... loop_continue = False
314-
// else:
315-
// ... loop_continue = original_condition
316-
void TransformBreaks(std::shared_ptr<Graph>& graph) {
317-
inlineLoopStartCondition(graph->block());
273+
void TransformContinues(std::shared_ptr<Graph>& graph) {
318274
LoopTransformer continues(graph, CONTINUES);
319275
continues.run();
320-
LoopTransformer breaks(graph, BREAKS);
321-
breaks.run();
322276
}
323277

324278
} // namespace script

torch/csrc/jit/passes/break_continue_transform.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace jit {
88
namespace script {
99

1010
TORCH_API void TransformBreaks(std::shared_ptr<Graph>& graph);
11+
TORCH_API void TransformContinues(std::shared_ptr<Graph>& graph);
1112

1213
} // namespace script
1314
} // namespace jit

torch/csrc/jit/script/canonicalize_modified_loop.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ void canonicalizeModifiedLoops(Node* n) {
2020
return;
2121
}
2222

23-
TORCH_INTERNAL_ASSERT(
24-
constant_as<bool>(n->inputs().at(1)).value_or(false) == true,
25-
"non-constant loop input condition NYI");
26-
2723
auto g = n->owningGraph();
2824
WithInsertPoint node_insert(n);
2925
auto zero = g->insertConstant(0);
@@ -32,6 +28,11 @@ void canonicalizeModifiedLoops(Node* n) {
3228
auto condition = g->insert(aten::gt, {max_trip_count, zero});
3329
loop.replaceMaxTripCount(
3430
g->insertConstant(std::numeric_limits<int64_t>::max()));
31+
32+
auto inp_condition = toIValue(loop.inputCond());
33+
if (inp_condition == c10::nullopt || inp_condition->toInt() == false) {
34+
condition = g->insert(aten::__and__, {condition, loop.inputCond()});
35+
}
3536
loop.replaceInputCondition(condition);
3637
n->addOutput()->setType(IntType::get());
3738
WithInsertPoint loop_insert(loop.bodyBlock());

0 commit comments

Comments
 (0)