@@ -14,9 +14,9 @@ namespace script {
1414void moveBlockBeforeNode (Node* before_node, Block* block);
1515
1616/* *
17- * This pass transforms the graph_ so that break & continue statements are
18- * removed. We transform the graph_ so that ops following a break or continue
19- * are not run.
17+ * This pass transforms the graph so that break & continue statements are
18+ * removed. We transform the graph so that ops following a break or continue are
19+ * not run.
2020 */
2121
2222// Will a block or node continue or break
@@ -32,21 +32,33 @@ struct LoopTransformer {
3232 true_val_ = graph_->insertConstant (true );
3333 false_val_ = graph_->insertConstant (false );
3434 transform_ = transform;
35+ incrementCurString ();
3536 };
3637
37- const std::string& getVarname () {
38+ const std::string getVarname () {
39+ return cur_string;
40+ }
41+
42+ void incrementCurString () {
3843 static const std::string& break_name = " $did_break" ;
3944 static const std::string& continue_name = " $did_continue" ;
40- return transform_ == BREAKS ? break_name : continue_name;
45+ const auto & name = transform_ == BREAKS ? break_name : continue_name;
46+ loop_count++;
47+ cur_string = name + std::to_string (loop_count);
48+ }
49+
50+ void setCurString (const std::string& new_string) {
51+ cur_string = new_string;
4152 }
4253
4354 Symbol transformKind () {
4455 return transform_ == BREAKS ? prim::BreakStmt : prim::ContinueStmt;
4556 }
4657
47- // Recurses on the if node and returns its return status
48- // If status != WONT, sets the block_return_val and sentinel val
49- // of its parent block before exit
58+ // Recursively transform both blocks of the if node.
59+ // If both blocks have hit the transform variable, then the return status,
60+ // is WILL, if both will not hit the transform variable it is false.
61+ // Otherwise we may have hit it.
5062 LoopStatus handleIf (Node* node) {
5163 auto true_block = node->blocks ().at (0 );
5264 auto false_block = node->blocks ().at (1 );
@@ -63,23 +75,24 @@ struct LoopTransformer {
6375 }
6476 }
6577
78+ // if an if node might hit a break or continue statement,
79+ // we guard all subsequent nodes in the block, and only execute them
80+ // if the transform is false.
81+ // The LoopStatus is the result of recursing on the newly created if.
6682 LoopStatus guardBlockNodes (
6783 Block* block,
68- generic_graph_node_list_iterator<Node>& iter) {
69- // if an if node might hit a break or continue statement,
70- // we guard all subsequent nodes in the block, and only execute them
71- // if we did break / did continue is false.
72-
73- auto new_if = graph_->create (prim::If, 0 )->insertBefore (*iter);
84+ graph_node_list::iterator& remaining_block_nodes) {
85+ auto new_if =
86+ graph_->create (prim::If, 0 )->insertBefore (*remaining_block_nodes);
7487 auto sentinel =
7588 graph_->createLoad (getVarname (), BoolType::get ())->insertBefore (new_if);
7689 new_if->addInput (sentinel->output ());
7790
7891 auto hit_control_flow_block = new_if->addBlock ();
7992 auto guard_block = new_if->addBlock ();
8093
81- while (iter != block->nodes ().end ()) {
82- auto node = *iter ++;
94+ while (remaining_block_nodes != block->nodes ().end ()) {
95+ auto node = *remaining_block_nodes ++;
8396 node->moveBefore (guard_block->return_node ());
8497 }
8598
@@ -89,11 +102,14 @@ struct LoopTransformer {
89102 // In a graph like:
90103 // for i in range(3):
91104 // if cond == 2:
105+ // k : Optional[int] = None
92106 // if cond == 2:
93107 // m = 2
94108 // break
95109 // k = 1
110+ // j = 2
96111 // else:
112+ // j = 1
97113 // k = 2
98114 // m += k
99115 // We transform the inner cond == 2 block to look like:
@@ -130,65 +146,59 @@ struct LoopTransformer {
130146 iter->destroy ();
131147 }
132148
133- void inlineLoopConditionIntoLoopBody (Node* n) {
134- auto body_block = n->blocks ().at (0 );
135- auto pre_header = n->blocks ().at (1 );
136- moveBlockBeforeNode (body_block->return_node (), pre_header);
137- body_block->insertOutput (0 , pre_header->outputs ().at (0 ));
138- n->eraseBlock (1 );
139- }
140-
141149 void handleLoop (Node* loop_node) {
142- // transform the loop, then ensure that that it does not accidentally
143- // pick up or assign the current transform variable outside of the loop.
150+ const std::string prev_string = getVarname ();
151+ // Give current loop unique identifier
152+ incrementCurString ();
153+ // transform the loop
144154 transformLoop (loop_node);
145- Block* body_block = loop_node->blocks ().at (0 );
146- graph_->createStore (getVarname (), false_val_)
147- ->insertAfter (body_block->param_node ());
148- graph_->createStore (getVarname (), false_val_)
149- ->insertBefore (body_block->return_node ());
150- }
151155
152- void transformLoop (Node* n) {
153- Block* body_block = n->blocks ().at (0 );
154- auto ret_status = handleTransforms (body_block);
155-
156- // When we're transforming breaks:
157- // the body condition has not yet been inlined. If we we are not breaking
158- // we need to inline the condition block into the end of the loop.
159- // if we might break, we create an if statement and only execute the loop
160- // header if we did not break.
161- // Since we run the continue pass before the break pass,
162- // we do not need to do any additional work in continues; guardBlock nodes
163- // ensures that we do not execute any ops present in the block after a
164- // continue, and loop condition is inlined after.
165-
166- if (transform_ == CONTINUES) {
167- return ;
168- }
169-
170- if (ret_status == WONT) {
171- inlineLoopConditionIntoLoopBody (n);
172- return ;
173- }
156+ // restore previous identifier
157+ setCurString (prev_string);
158+ }
174159
175- WithInsertPoint insert (body_block);
160+ // Create a check for the current transform variable.
161+ // if transform is true, loop continue condition is false, otherwise
162+ // run original condition
163+ void guardConditionBlock (Block* condition_block) {
164+ WithInsertPoint insert (*condition_block->nodes ().begin ());
176165 auto did_break =
177166 graph_->insertNode (graph_->createLoad (getVarname (), BoolType::get ()))
178167 ->output ();
179-
180168 auto new_loop_condition = graph_->insertNode (graph_->create (prim::If));
181169 new_loop_condition->addInput (did_break);
182170 new_loop_condition->output ()->setType (BoolType::get ());
183-
184- // if we did break, we do not continue
185171 new_loop_condition->addBlock ()->registerOutput (false_val_);
186172 auto original_condition = new_loop_condition->addBlock ();
187- auto pre_header = n->blocks ().at (1 );
188- moveBlockBeforeNode (original_condition->return_node (), pre_header);
189- original_condition->insertOutput (0 , pre_header->outputs ().at (0 ));
190- n->eraseBlock (1 );
191- body_block->registerOutput (new_loop_condition->output ());
173+
174+ Node* n = new_loop_condition;
175+ for (n = n->next (); n != condition_block->return_node ();) {
176+ auto cur = n;
177+ n = n->next ();
178+ cur->moveBefore (original_condition->return_node ());
179+ }
180+ original_condition->insertOutput (0 , condition_block->outputs ().at (0 ));
181+ condition_block->eraseOutput (0 );
182+ condition_block->registerOutput (new_loop_condition->output ());
183+ }
184+
185+ void transformLoop (Node* n) {
186+ Block* body_block = n->blocks ().at (0 );
187+ auto ret_status = handleTransforms (body_block);
188+
189+ // loop header should run even if we have continued
190+ if (transform_ == CONTINUES || ret_status == WONT) {
191+ return ;
192+ }
193+
194+ // because the condition block will get inlined as the start loop condition,
195+ // we need to make sure that it is defined before the loop executes
196+ // (and false so original condition is run). Also insert it into the block
197+ // so it is not an unneccessary loop carried var.
198+ graph_->createStore (getVarname (), false_val_)->insertBefore (n);
199+ graph_->createStore (getVarname (), false_val_)
200+ ->insertAfter (body_block->param_node ());
201+ guardConditionBlock (n->blocks ().at (1 ));
192202 };
193203
194204 LoopStatus handleTransforms (Block* block) {
@@ -205,7 +215,6 @@ struct LoopTransformer {
205215 if (node->kind () != transformKind ()) {
206216 continue ;
207217 }
208- WithInsertPoint b (block);
209218 node->destroy ();
210219 loop_status = WILL;
211220 } break ;
@@ -244,81 +253,26 @@ struct LoopTransformer {
244253 handleTransforms (graph_->block ());
245254 }
246255
256+ size_t loop_count = 0 ;
247257 Transform transform_;
248258 Value* true_val_ = nullptr ;
249259 Value* false_val_ = nullptr ;
260+ std::string cur_string = " " ;
250261
251262 std::shared_ptr<Graph> graph_;
252263};
253264
254- void moveBlockBeforeNode (Node* before_node, Block* block) {
255- for (auto it = block->nodes ().begin (); it != block->nodes ().end ();) {
256- auto block_node = *it++;
257- block_node->moveBefore (before_node);
258- }
259- }
260-
261- // The loop node is initially emitted as:
262- // Loop(max_trip_count)
263- // block0(loop_counter) {
264- // <body>
265- // }
266- // block1 {
267- // <loop condition>
268- // -> (condition)
269- // }
270- // Here, we inline the loop condition into:
271- // Loop(max_trip_count, start_condition)
272- // block0(loop_counter) {
273- // <body>
274- // }
275- // block1 {
276- // <loop condition>
277- // -> (condition)
278- // }
279-
280- void inlineLoopStartCondition (Node* n) {
281- auto pre_header = n->blocks ().at (1 );
282- auto header_block = n->addBlock ();
283- header_block->cloneFrom (pre_header, [](Value* v) { return v; });
284- moveBlockBeforeNode (n, header_block);
285- n->addInput (header_block->outputs ().at (0 ));
286- n->eraseBlock (2 );
287- }
265+ // These passes are run before SSA, so they need to handle before the
266+ // Loop body and loop condition as a separate block.
288267
289- void inlineLoopStartCondition (Block* block) {
290- for (Node* n : block->nodes ()) {
291- switch (n->kind ()) {
292- case prim::If:
293- case prim::Function: {
294- for (auto b : n->blocks ()) {
295- inlineLoopStartCondition (b);
296- }
297- } break ;
298- case prim::Loop: {
299- inlineLoopStartCondition (n->blocks ().at (0 ));
300- inlineLoopStartCondition (n);
301- } break ;
302- }
303- }
268+ void TransformBreaks (std::shared_ptr<Graph>& graph) {
269+ LoopTransformer breaks (graph, BREAKS);
270+ breaks.run ();
304271}
305272
306- // First we inline the loop input condition.
307- // Then, we transform the continues. Because the loop body condition
308- // has not yet been inlined, we can safely ignore it in the continue pass.
309- // Then, we transform breaks, inlining the loop body condition as part of the
310- // pass. Because they have not been inlined yet, we can generated nice graph_s
311- // of the form
312- // if did_break
313- // ... loop_continue = False
314- // else:
315- // ... loop_continue = original_condition
316- void TransformBreaks (std::shared_ptr<Graph>& graph) {
317- inlineLoopStartCondition (graph->block ());
273+ void TransformContinues (std::shared_ptr<Graph>& graph) {
318274 LoopTransformer continues (graph, CONTINUES);
319275 continues.run ();
320- LoopTransformer breaks (graph, BREAKS);
321- breaks.run ();
322276}
323277
324278} // namespace script
0 commit comments