Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion test/cpp/jit/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <test/cpp/jit/test_constant_propagation.h>
#include <test/cpp/jit/test_create_autodiff_subgraphs.h>
#include <test/cpp/jit/test_custom_operators.h>
#include <test/cpp/jit/test_dce.h>
#include <test/cpp/jit/test_dynamic_dag.h>
#include <test/cpp/jit/test_fuser.h>
#include <test/cpp/jit/test_graph_executor.h>
Expand Down Expand Up @@ -88,7 +89,8 @@ namespace jit {
_(QualifiedName) \
_(ClassImport) \
_(ScriptObject) \
_(SaveExtraFilesHook)
_(SaveExtraFilesHook) \
_(DCE)

#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
Expand Down
57 changes: 57 additions & 0 deletions test/cpp/jit/test_dce.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#pragma once

#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>

#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {
void testDCE() {
auto graph = std::make_shared<Graph>();

// Consider the following loop:
// for i in range(3):
// tot += a[0][0]
// b = a[0]
// b[0] += 1
// print(tot)
// We want to check that b[0] and b are properly marked as live and thus not
// DCE'd.
const std::string input =
R"IR(
graph():
%48 : None = prim::Constant()
%50 : bool = prim::Constant[value=1]()
%10 : bool? = prim::Constant()
%8 : Device? = prim::Constant()
%4 : int? = prim::Constant()
%0 : int = prim::Constant[value=2]()
%12 : int = prim::Constant[value=1]()
%24 : int = prim::Constant[value=3]()
%31 : int = prim::Constant[value=0]()
%2 : int[] = prim::ListConstruct(%0, %0)
%a.1 : Tensor = aten::ones(%2, %4, %4, %8, %10)
%14 : int[] = prim::ListConstruct(%12)
%tot.1 : Tensor = aten::zeros(%14, %4, %4, %8, %10)
%tot : Tensor = prim::Loop(%24, %50, %tot.1)
block0(%i : int, %tot.6 : Tensor):
%33 : Tensor = aten::select(%a.1, %31, %31)
%35 : Tensor = aten::select(%33, %31, %31)
# CHECK: add_
%tot.3 : Tensor = aten::add_(%tot.6, %35, %12)
%b.1 : Tensor = aten::select(%a.1, %31, %31)
%44 : Tensor = aten::select(%b.1, %31, %31)
# CHECK: add_
%46 : Tensor = aten::add_(%44, %12, %12)
-> (%50, %tot.3)
return (%tot)
)IR";
script::parseIR(input, graph.get());
EliminateDeadCode(graph);
// Check that dead code elimin
testing::FileCheck().run(input, *graph);
}
} // namespace jit
} // namespace torch
73 changes: 61 additions & 12 deletions torch/csrc/jit/passes/dead_code_elimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ class DeadCodeEliminator {
// We want to be able to DCE all the %b stuff. So when processing block
// returns, we only mark producers for values that "live" (i.e. used outside
// the block).
void markReturnNode(Node* node) {
//
// Returns true iff this marked something we haven't marked before.
bool markReturnNode(Node* node) {
if (marked_.count(node)) {
return;
return false;
}

AT_ASSERT(node->owningBlock()->return_node() == node);
Expand Down Expand Up @@ -132,30 +134,74 @@ class DeadCodeEliminator {
}

marked_.insert(node);
return true;
}

void mark(Block* block) {
// Loops are special, because we need to run them to convergence.
// Consider the following loop:
// for i in range(3):
// tot += a[0][0]
// b = a[0]
// b[0] += 1
// print(tot)
//
// If we only process the loop block once, we will conclude that `b[0]` and
// `b` are dead, even though `b[0] += 1` mutates a live memory location (since
// `b[0]` is an alias of `a`). i.e. `a` is used to compute `tot` in the next
// iteration
//
// We need to mark the loop again with the information that `a` is live, and
Copy link
Contributor

@Krovatkin Krovatkin Jul 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// We need to mark the loop again with the information that a is live, and

we could also add

i.e. `a` is used to compute `tot` in the next iteration

// repeat until we're not marking new stuff anymore.
//
// Returns true iff this marked something we haven't marked before.
bool markLoop(Node* node) {
TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop);
// Did a single iteration over the loop block mark anything new?
// If this is false, we've converged.
bool marked = false;
// Did we ever mark anything new?
bool anyMarked = false;
do {
marked = mark(node->blocks().at(0));
anyMarked |= marked;
} while (marked);
return anyMarked;
}

// Returns true iff this marked something we haven't marked before.
bool mark(Block* block) {
bool anyMarked = false;
// Mark all nodes with side effects.
for (auto node : block->nodes()) {
if (sideEffectPolicy_ == DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS && hasSideEffects(node)) {
mark(node);
if (sideEffectPolicy_ ==
DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS &&
hasSideEffects(node)) {
anyMarked |= mark(node);
}
}

// Initialize by marking the return node
markReturnNode(block->return_node());
anyMarked |= markReturnNode(block->return_node());

for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) {
auto node = *it;
for (auto subBlock : node->blocks()) {
mark(subBlock);
if (node->kind() == prim::Loop) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check for c10::onnx::loop here like we do in the other place?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

onnx is purely functional and has no aliasing, so we wouldn't need the special behavior

// Special casing for loops, see comment in markLoop.
anyMarked |= markLoop(node);
} else {
// Other nodes with sub-blocks get marked normally.
for (auto subBlock : node->blocks()) {
anyMarked |= mark(subBlock);
}
}
markIfLive(node);
anyMarked |= markIfLive(node);
}
return anyMarked;
}

// If we output or write to a live memory location, mark this node
void markIfLive(Node* node) {
// Returns true iff this marked something we haven't marked before.
bool markIfLive(Node* node) {
for (const auto output : node->outputs()) {
if (liveValues_.count(output)) {
return mark(node);
Expand All @@ -167,13 +213,15 @@ class DeadCodeEliminator {
return mark(node);
}
}
return false;
}

// Mark this node as live and add this node's inputs and aliases to the live
// value sets.
void mark(Node* node) {
// Returns true iff this marked something we haven't marked before.
bool mark(Node* node) {
if (marked_.count(node)) {
return;
return false;
}

marked_.insert(node);
Expand All @@ -196,6 +244,7 @@ class DeadCodeEliminator {
}
liveValues_.insert(input);
}
return true;
}

// Delete all unmarked nodes.
Expand Down