@@ -89,9 +89,11 @@ class DeadCodeEliminator {
8989 // We want to be able to DCE all the %b stuff. So when processing block
9090 // returns, we only mark producers for values that "live" (i.e. used outside
9191 // the block).
92- void markReturnNode (Node* node) {
92+ //
93+ // Returns true iff this marked something we haven't marked before.
94+ bool markReturnNode (Node* node) {
9395 if (marked_.count (node)) {
94- return ;
96+ return false ;
9597 }
9698
9799 AT_ASSERT (node->owningBlock ()->return_node () == node);
@@ -132,30 +134,74 @@ class DeadCodeEliminator {
132134 }
133135
134136 marked_.insert (node);
137+ return true ;
135138 }
136139
137- void mark (Block* block) {
140+ // Loops are special, because we need to run them to convergence.
141+ // Consider the following loop:
142+ // for i in range(3):
143+ // tot += a[0][0]
144+ // b = a[0]
145+ // b[0] += 1
146+ // print(tot)
147+ //
148+ // If we only process the loop block once, we will conclude that `b[0]` and
149+ // `b` are dead, even though `b[0] += 1` mutates a live memory location (since
150+ // `b[0]` is an alias of `a`). i.e. `a` is used to compute `tot` in the next
151+ // iteration
152+ //
153+ // We need to mark the loop again with the information that `a` is live, and
154+ // repeat until we're not marking new stuff anymore.
155+ //
156+ // Returns true iff this marked something we haven't marked before.
157+ bool markLoop (Node* node) {
158+ TORCH_INTERNAL_ASSERT (node->kind () == prim::Loop);
159+ // Did a single iteration over the loop block mark anything new?
160+ // If this is false, we've converged.
161+ bool marked = false ;
162+ // Did we ever mark anything new?
163+ bool anyMarked = false ;
164+ do {
165+ marked = mark (node->blocks ().at (0 ));
166+ anyMarked |= marked;
167+ } while (marked);
168+ return anyMarked;
169+ }
170+
171+ // Returns true iff this marked something we haven't marked before.
172+ bool mark (Block* block) {
173+ bool anyMarked = false ;
138174 // Mark all nodes with side effects.
139175 for (auto node : block->nodes ()) {
140- if (sideEffectPolicy_ == DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS && hasSideEffects (node)) {
141- mark (node);
176+ if (sideEffectPolicy_ ==
177+ DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS &&
178+ hasSideEffects (node)) {
179+ anyMarked |= mark (node);
142180 }
143181 }
144182
145183 // Initialize by marking the return node
146- markReturnNode (block->return_node ());
184+ anyMarked |= markReturnNode (block->return_node ());
147185
148186 for (auto it = block->nodes ().rbegin (); it != block->nodes ().rend (); ++it) {
149187 auto node = *it;
150- for (auto subBlock : node->blocks ()) {
151- mark (subBlock);
188+ if (node->kind () == prim::Loop) {
189+ // Special casing for loops, see comment in markLoop.
190+ anyMarked |= markLoop (node);
191+ } else {
192+ // Other nodes with sub-blocks get marked normally.
193+ for (auto subBlock : node->blocks ()) {
194+ anyMarked |= mark (subBlock);
195+ }
152196 }
153- markIfLive (node);
197+ anyMarked |= markIfLive (node);
154198 }
199+ return anyMarked;
155200 }
156201
157202 // If we output or write to a live memory location, mark this node
158- void markIfLive (Node* node) {
203+ // Returns true iff this marked something we haven't marked before.
204+ bool markIfLive (Node* node) {
159205 for (const auto output : node->outputs ()) {
160206 if (liveValues_.count (output)) {
161207 return mark (node);
@@ -167,13 +213,15 @@ class DeadCodeEliminator {
167213 return mark (node);
168214 }
169215 }
216+ return false ;
170217 }
171218
172219 // Mark this node as live and add this node's inputs and aliases to the live
173220 // value sets.
174- void mark (Node* node) {
221+ // Returns true iff this marked something we haven't marked before.
222+ bool mark (Node* node) {
175223 if (marked_.count (node)) {
176- return ;
224+ return false ;
177225 }
178226
179227 marked_.insert (node);
@@ -196,6 +244,7 @@ class DeadCodeEliminator {
196244 }
197245 liveValues_.insert (input);
198246 }
247+ return true ;
199248 }
200249
201250 // Delete all unmarked nodes.
0 commit comments