-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[jit] fix dce over loops #22632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[jit] fix dce over loops #22632
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| #pragma once | ||
|
|
||
| #include <test/cpp/jit/test_base.h> | ||
| #include <test/cpp/jit/test_utils.h> | ||
|
|
||
| #include <torch/csrc/jit/passes/dead_code_elimination.h> | ||
| #include <torch/csrc/jit/testing/file_check.h> | ||
|
|
||
| namespace torch { | ||
| namespace jit { | ||
| // Regression test for dead code elimination over loops: parses an IR graph | ||
| // containing a prim::Loop, runs EliminateDeadCode on it, and then uses | ||
| // FileCheck to confirm that both in-place `aten::add_` nodes are still there. | ||
| void testDCE() { | ||
| // Start from an empty graph; it is populated from the IR text below. | ||
| auto graph = std::make_shared<Graph>(); | ||
|
|
||
| // Consider the following loop: | ||
| // for i in range(3): | ||
| // tot += a[0][0] | ||
| // b = a[0] | ||
| // b[0] += 1 | ||
| // print(tot) | ||
| // We want to check that b[0] and b are properly marked as live and thus not | ||
| // DCE'd. | ||
| // The `# CHECK:` lines inside the IR text are part of the string itself; the | ||
| // same string is later handed to FileCheck (presumably as its check | ||
| // directives -- confirm against the FileCheck API). | ||
| const std::string input = | ||
| R"IR( | ||
| graph(): | ||
| %48 : None = prim::Constant() | ||
| %50 : bool = prim::Constant[value=1]() | ||
| %10 : bool? = prim::Constant() | ||
| %8 : Device? = prim::Constant() | ||
| %4 : int? = prim::Constant() | ||
| %0 : int = prim::Constant[value=2]() | ||
| %12 : int = prim::Constant[value=1]() | ||
| %24 : int = prim::Constant[value=3]() | ||
| %31 : int = prim::Constant[value=0]() | ||
| %2 : int[] = prim::ListConstruct(%0, %0) | ||
| %a.1 : Tensor = aten::ones(%2, %4, %4, %8, %10) | ||
| %14 : int[] = prim::ListConstruct(%12) | ||
| %tot.1 : Tensor = aten::zeros(%14, %4, %4, %8, %10) | ||
| %tot : Tensor = prim::Loop(%24, %50, %tot.1) | ||
| block0(%i : int, %tot.6 : Tensor): | ||
| %33 : Tensor = aten::select(%a.1, %31, %31) | ||
| %35 : Tensor = aten::select(%33, %31, %31) | ||
| # CHECK: add_ | ||
| %tot.3 : Tensor = aten::add_(%tot.6, %35, %12) | ||
| %b.1 : Tensor = aten::select(%a.1, %31, %31) | ||
| %44 : Tensor = aten::select(%b.1, %31, %31) | ||
| # CHECK: add_ | ||
| %46 : Tensor = aten::add_(%44, %12, %12) | ||
| -> (%50, %tot.3) | ||
| return (%tot) | ||
| )IR"; | ||
| // Build the graph from the IR text, then run the pass under test on it. | ||
| script::parseIR(input, graph.get()); | ||
| EliminateDeadCode(graph); | ||
| // Check that dead code elimination kept both in-place `add_` nodes. | ||
| testing::FileCheck().run(input, *graph); | ||
| } | ||
| } // namespace jit | ||
| } // namespace torch |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -89,9 +89,11 @@ class DeadCodeEliminator { | |
| // We want to be able to DCE all the %b stuff. So when processing block | ||
| // returns, we only mark producers for values that are "live" (i.e. used outside | ||
| // the block). | ||
| void markReturnNode(Node* node) { | ||
| // | ||
| // Returns true iff this marked something we haven't marked before. | ||
| bool markReturnNode(Node* node) { | ||
| if (marked_.count(node)) { | ||
| return; | ||
| return false; | ||
| } | ||
|
|
||
| AT_ASSERT(node->owningBlock()->return_node() == node); | ||
|
|
@@ -132,30 +134,74 @@ class DeadCodeEliminator { | |
| } | ||
|
|
||
| marked_.insert(node); | ||
| return true; | ||
| } | ||
|
|
||
| void mark(Block* block) { | ||
| // Loops are special, because we need to run them to convergence. | ||
| // Consider the following loop: | ||
| // for i in range(3): | ||
| // tot += a[0][0] | ||
| // b = a[0] | ||
| // b[0] += 1 | ||
| // print(tot) | ||
| // | ||
| // If we only process the loop block once, we will conclude that `b[0]` and | ||
| // `b` are dead, even though `b[0] += 1` mutates a live memory location (since | ||
| // `b[0]` aliases `a`, and `a` is used to compute `tot` in the next | ||
| // iteration). | ||
| // | ||
| // We need to mark the loop again with the information that `a` is live, and | ||
| // repeat until we're not marking new stuff anymore. | ||
| // | ||
| // Returns true iff this marked something we haven't marked before. | ||
| bool markLoop(Node* node) { | ||
| TORCH_INTERNAL_ASSERT(node->kind() == prim::Loop); | ||
| // The loop's block is sub-block 0 of the prim::Loop node. | ||
| Block* const loopBody = node->blocks().at(0); | ||
| // Becomes true once any pass over the loop block marks something new. | ||
| bool markedAnything = false; | ||
| // Keep re-marking the loop block until a pass adds nothing new, i.e. until | ||
| // the set of marked nodes has converged. | ||
| while (mark(loopBody)) { | ||
| markedAnything = true; | ||
| } | ||
| return markedAnything; | ||
| } | ||
|
|
||
| // Returns true iff this marked something we haven't marked before. | ||
| bool mark(Block* block) { | ||
| bool anyMarked = false; | ||
| // Mark all nodes with side effects. | ||
| for (auto node : block->nodes()) { | ||
| if (sideEffectPolicy_ == DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS && hasSideEffects(node)) { | ||
| mark(node); | ||
| if (sideEffectPolicy_ == | ||
| DCESideEffectPolicy::DONT_DELETE_NODES_WITH_SIDE_EFFECTS && | ||
| hasSideEffects(node)) { | ||
| anyMarked |= mark(node); | ||
| } | ||
| } | ||
|
|
||
| // Initialize by marking the return node | ||
| markReturnNode(block->return_node()); | ||
| anyMarked |= markReturnNode(block->return_node()); | ||
|
|
||
| for (auto it = block->nodes().rbegin(); it != block->nodes().rend(); ++it) { | ||
| auto node = *it; | ||
| for (auto subBlock : node->blocks()) { | ||
| mark(subBlock); | ||
| if (node->kind() == prim::Loop) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to check for
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. onnx is purely functional and has no aliasing, so we wouldn't need the special behavior |
||
| // Special casing for loops, see comment in markLoop. | ||
| anyMarked |= markLoop(node); | ||
| } else { | ||
| // Other nodes with sub-blocks get marked normally. | ||
| for (auto subBlock : node->blocks()) { | ||
| anyMarked |= mark(subBlock); | ||
| } | ||
| } | ||
| markIfLive(node); | ||
| anyMarked |= markIfLive(node); | ||
| } | ||
| return anyMarked; | ||
| } | ||
|
|
||
| // If we output or write to a live memory location, mark this node | ||
| void markIfLive(Node* node) { | ||
| // Returns true iff this marked something we haven't marked before. | ||
| bool markIfLive(Node* node) { | ||
| for (const auto output : node->outputs()) { | ||
| if (liveValues_.count(output)) { | ||
| return mark(node); | ||
|
|
@@ -167,13 +213,15 @@ class DeadCodeEliminator { | |
| return mark(node); | ||
| } | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| // Mark this node as live and add this node's inputs and aliases to the live | ||
| // value sets. | ||
| void mark(Node* node) { | ||
| // Returns true iff this marked something we haven't marked before. | ||
| bool mark(Node* node) { | ||
| if (marked_.count(node)) { | ||
| return; | ||
| return false; | ||
| } | ||
|
|
||
| marked_.insert(node); | ||
|
|
@@ -196,6 +244,7 @@ class DeadCodeEliminator { | |
| } | ||
| liveValues_.insert(input); | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| // Delete all unmarked nodes. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we could also add