Skip to content

Commit ec1b669

Browse files
suo authored and facebook-github-bot committed
fix dce over loops
Summary: Pull Request resolved: #22632
Test Plan: Imported from OSS
Differential Revision: D16184469
Pulled By: suo
fbshipit-source-id: b7cc2d20a7dd8b287e1b6128ddb70d3936032a7e
1 parent 9b8d771 commit ec1b669

File tree

3 files changed

+121
-13
lines changed

3 files changed

+121
-13
lines changed

test/cpp/jit/test.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <test/cpp/jit/test_constant_propagation.h>
2020
#include <test/cpp/jit/test_create_autodiff_subgraphs.h>
2121
#include <test/cpp/jit/test_custom_operators.h>
22+
#include <test/cpp/jit/test_dce.h>
2223
#include <test/cpp/jit/test_dynamic_dag.h>
2324
#include <test/cpp/jit/test_fuser.h>
2425
#include <test/cpp/jit/test_graph_executor.h>
@@ -88,7 +89,8 @@ namespace jit {
8889
_(QualifiedName) \
8990
_(ClassImport) \
9091
_(ScriptObject) \
91-
_(SaveExtraFilesHook)
92+
_(SaveExtraFilesHook) \
93+
_(DCE)
9294

9395
#define TH_FORALL_TESTS_CUDA(_) \
9496
_(ArgumentSpec) \

test/cpp/jit/test_dce.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#pragma once
2+
3+
#include <test/cpp/jit/test_base.h>
4+
#include <test/cpp/jit/test_utils.h>
5+
6+
#include <torch/csrc/jit/passes/dead_code_elimination.h>
7+
#include <torch/csrc/jit/testing/file_check.h>
8+
9+
namespace torch {
10+
namespace jit {
11+
// Test for dead code elimination over loops: parses the IR below into a fresh
// graph, runs EliminateDeadCode on it, and then uses FileCheck to verify the
// resulting graph against the `# CHECK:` lines embedded in the same IR text.
void testDCE() {
12+
auto graph = std::make_shared<Graph>();
13+
14+
// Consider the following loop:
15+
// for i in range(3):
16+
// tot += a[0][0]
17+
// b = a[0]
18+
// b[0] += 1
19+
// print(tot)
20+
// We want to check that b[0] and b are properly marked as live and thus not
21+
// DCE'd.
22+
const std::string input =
23+
R"IR(
24+
graph():
25+
%48 : None = prim::Constant()
26+
%50 : bool = prim::Constant[value=1]()
27+
%10 : bool? = prim::Constant()
28+
%8 : Device? = prim::Constant()
29+
%4 : int? = prim::Constant()
30+
%0 : int = prim::Constant[value=2]()
31+
%12 : int = prim::Constant[value=1]()
32+
%24 : int = prim::Constant[value=3]()
33+
%31 : int = prim::Constant[value=0]()
34+
%2 : int[] = prim::ListConstruct(%0, %0)
35+
%a.1 : Tensor = aten::ones(%2, %4, %4, %8, %10)
36+
%14 : int[] = prim::ListConstruct(%12)
37+
%tot.1 : Tensor = aten::zeros(%14, %4, %4, %8, %10)
38+
%tot : Tensor = prim::Loop(%24, %50, %tot.1)
39+
block0(%i : int, %tot.6 : Tensor):
40+
%33 : Tensor = aten::select(%a.1, %31, %31)
41+
%35 : Tensor = aten::select(%33, %31, %31)
42+
# CHECK: add_
43+
%tot.3 : Tensor = aten::add_(%tot.6, %35, %12)
44+
%b.1 : Tensor = aten::select(%a.1, %31, %31)
45+
%44 : Tensor = aten::select(%b.1, %31, %31)
46+
# CHECK: add_
47+
%46 : Tensor = aten::add_(%44, %12, %12)
48+
-> (%50, %tot.3)
49+
return (%tot)
50+
)IR";
51+
script::parseIR(input, graph.get());
52+
EliminateDeadCode(graph);
53+
// Check that dead code elimination kept both in-place `add_` nodes: the
// `# CHECK: add_` lines inside `input` are matched against the graph that
// remains after the pass.
54+
testing::FileCheck().run(input, *graph);
55+
}
56+
} // namespace jit
57+
} // namespace torch

torch/csrc/jit/passes/dead_code_elimination.cpp

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,11 @@ class DeadCodeEliminator {
8989
// We want to be able to DCE all the %b stuff. So when processing block
9090
// returns, we only mark producers for values that "live" (i.e. used outside
9191
// the block).
92-
void markReturnNode(Node* node) {
92+
//
93+
// Returns true iff this marked something we haven't marked before.
94+
bool markReturnNode(Node* node) {
9395
if (marked_.count(node)) {
94-
return;
96+
return false;
9597
}
9698

9799
AT_ASSERT(node->owningBlock()->return_node() == node);
@@ -132,30 +134,74 @@ class DeadCodeEliminator {
132134
}
133135

134136
marked_.insert(node);
137+
return true;
135138
}
136139

137-
void mark(Block* block) {
140+
// Loops are special, because we need to run them to convergence.
141+
// Consider the following loop:
142+
// for i in range(3):
143+
// tot += a[0][0]
144+
// b = a[0]
145+
// b[0] += 1
146+
// print(tot)
147+
//
148+
// If we only process the loop block once, we will conclude that `b[0]` and
149+
// `b` are dead, even though `b[0] += 1` mutates a live memory location (since
150+
// `b[0]` is an alias of `a`). i.e. `a` is used to compute `tot` in the next
151+
// iteration
152+
//
153+
// We need to mark the loop again with the information that `a` is live, and
154+
// repeat until we're not marking new stuff anymore.
155+
//
156+
// Returns true iff this marked something we haven't marked before.
157+
// Repeatedly marks the body block of a prim::Loop node until a pass marks
// nothing new. Returns true iff any pass marked something.
bool markLoop(Node* node) {
158+
// Callers must only pass prim::Loop nodes.
TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop);
159+
// Did a single iteration over the loop block mark anything new?
160+
// If this is false, we've converged.
161+
bool marked = false;
162+
// Did we ever mark anything new?
163+
bool anyMarked = false;
164+
// The body is always marked at least once (do/while); only the node's first
// block (index 0) is processed here.
do {
165+
marked = mark(node->blocks().at(0));
166+
anyMarked |= marked;
167+
} while (marked);
168+
return anyMarked;
169+
}
170+
171+
// Returns true iff this marked something we haven't marked before.
172+
bool mark(Block* block) {
173+
bool anyMarked = false;
138174
// Mark all nodes with side effects.
139175
for (auto node : block->nodes()) {
140-
if (sideEffectPolicy_ == DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS && hasSideEffects(node)) {
141-
mark(node);
176+
if (sideEffectPolicy_ ==
177+
DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS &&
178+
hasSideEffects(node)) {
179+
anyMarked |= mark(node);
142180
}
143181
}
144182

145183
// Initialize by marking the return node
146-
markReturnNode(block->return_node());
184+
anyMarked |= markReturnNode(block->return_node());
147185

148186
for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
149187
auto node = *it;
150-
for (auto subBlock : node->blocks()) {
151-
mark(subBlock);
188+
if (node->kind() == prim::Loop) {
189+
// Special casing for loops, see comment in markLoop.
190+
anyMarked |= markLoop(node);
191+
} else {
192+
// Other nodes with sub-blocks get marked normally.
193+
for (auto subBlock : node->blocks()) {
194+
anyMarked |= mark(subBlock);
195+
}
152196
}
153-
markIfLive(node);
197+
anyMarked |= markIfLive(node);
154198
}
199+
return anyMarked;
155200
}
156201

157202
// If we output or write to a live memory location, mark this node
158-
void markIfLive(Node* node) {
203+
// Returns true iff this marked something we haven't marked before.
204+
bool markIfLive(Node* node) {
159205
for (const auto output : node->outputs()) {
160206
if (liveValues_.count(output)) {
161207
return mark(node);
@@ -167,13 +213,15 @@ class DeadCodeEliminator {
167213
return mark(node);
168214
}
169215
}
216+
return false;
170217
}
171218

172219
// Mark this node as live and add this node's inputs and aliases to the live
173220
// value sets.
174-
void mark(Node* node) {
221+
// Returns true iff this marked something we haven't marked before.
222+
bool mark(Node* node) {
175223
if (marked_.count(node)) {
176-
return;
224+
return false;
177225
}
178226

179227
marked_.insert(node);
@@ -196,6 +244,7 @@ class DeadCodeEliminator {
196244
}
197245
liveValues_.insert(input);
198246
}
247+
return true;
199248
}
200249

201250
// Delete all unmarked nodes.

0 commit comments

Comments
 (0)