Skip to content

Commit cf356a3

Browse files
zdevito authored and facebook-github-bot committed
Fix a bug in loop unrolling (#21239)
Summary:
Pull Request resolved: #21239
ghimport-source-id: 68256b7
Reviewed By: suo
Differential Revision: D15590901
Pulled By: zdevito
fbshipit-source-id: 8700aab723d4486fd20d3414df8160b36a3cc5da
1 parent 6e657c5 commit cf356a3

File tree

2 files changed

+55
-37
lines changed

2 files changed

+55
-37
lines changed

test/test_jit.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5608,6 +5608,19 @@ def func(lim):
56085608
inputs = self._make_scalar_vars([10], torch.int64)
56095609
self.checkScript(func, inputs, optimize=True)
56105610

5611+
def test_fibb_totally_better(self):
5612+
def fib(x):
5613+
# type: (int) -> int
5614+
prev = 1
5615+
v = 1
5616+
for i in range(0, x):
5617+
save = v
5618+
v = v + prev
5619+
prev = save
5620+
return v
5621+
5622+
self.checkScript(fib, (10,))
5623+
56115624
def test_if(self):
56125625
def func(a, b):
56135626
# type: (int, int) -> int

torch/csrc/jit/passes/loop_unrolling.cpp

Lines changed: 42 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -85,57 +85,57 @@ void inlineBody(Node* loop) {
8585
loop->destroy();
8686
}
8787

88-
void repeatBody(Block* body, int64_t times) {
89-
// We will be adding nodes to the body, so cache the initial start and end.
90-
// XXX: they are both inclusive, because the exclusive body_end would point to
91-
// return_node, which would move further away if we were to add nodes,
92-
// and we would enter an infinite loop.
93-
auto body_start = body->nodes().begin();
94-
auto body_end = std::prev(body->nodes().end());
95-
auto graph = body->owningGraph();
96-
WithInsertPoint insert_point_guard{body};
97-
88+
// inserts a copy of body, passing inputs to the inputs of the block
89+
// it returns a list of the Values for the outputs of the block
90+
std::vector<Value*> insertBlockCopy(
91+
Graph& graph,
92+
Block* body,
93+
at::ArrayRef<Value*> inputs) {
94+
TORCH_INTERNAL_ASSERT(inputs.size() == body->inputs().size());
9895
std::unordered_map<Value*, Value*> value_map;
9996
auto get_value = [&](Value* v) {
10097
auto it = value_map.find(v);
10198
if (it != value_map.end())
10299
return it->second;
103100
return v;
104101
};
105-
106-
for (int64_t i = 1; i < times; ++i) {
107-
// Update loop-carried values
108-
// NB: note that we don't need to worry about the loop counter, because
109-
// we've replaced it with a loop-carried variable
110-
AT_ASSERT(body->inputs().size() == body->outputs().size());
111-
for (size_t i = 1; i < body->inputs().size(); ++i) {
112-
value_map[body->inputs()[i]] = get_value(body->outputs()[i]);
102+
auto inputs_it = inputs.begin();
103+
for (Value* input : body->inputs()) {
104+
value_map[input] = *inputs_it++;
105+
}
106+
for (Node* node : body->nodes()) {
107+
Node* new_node = graph.insertNode(graph.createClone(node, get_value));
108+
auto outputs_it = new_node->outputs().begin();
109+
for (Value* output : node->outputs()) {
110+
value_map[output] = *outputs_it++;
113111
}
112+
}
113+
return fmap(body->outputs(), get_value);
114+
}
114115

115-
// Clone the nodes
116-
for (auto it = body_start; it != std::next(body_end); ++it) {
117-
Node* orig = *it;
118-
Node* clone = graph->insertNode(graph->createClone(orig, get_value));
119-
for (size_t i = 0; i < orig->outputs().size(); ++i) {
120-
value_map[orig->outputs()[i]] = clone->outputs()[i];
121-
}
122-
}
116+
void repeatBody(Block* body, size_t times, Block* dest) {
117+
auto graph = body->owningGraph();
118+
WithInsertPoint insert_point_guard(dest);
119+
for (Value* input : body->inputs()) {
120+
dest->addInput()->copyMetadata(input);
123121
}
124122

125-
// Update outputs of the body
126-
const std::vector<Value*> new_outputs = fmap(body->outputs(), get_value);
127-
for (int64_t i = new_outputs.size() - 1; i >= 0; --i) {
128-
body->eraseOutput(i);
123+
std::vector<Value*> io = dest->inputs().vec();
124+
TORCH_INTERNAL_ASSERT(
125+
!body->inputs().at(0)->hasUses(), "loop counter should be unused");
126+
for (size_t i = 0; i < times; ++i) {
127+
io[0] = body->inputs().at(0);
128+
io = insertBlockCopy(*graph, body, io);
129129
}
130-
for (Value* output : new_outputs) {
131-
body->registerOutput(output);
130+
for (Value* output : io) {
131+
dest->registerOutput(output);
132132
}
133133

134134
// It's likely that we have some dead nodes now - for example the "true"
135135
// constant that prevents the loop from breaking. We shouldn't wait too long
136136
// before removing them because they might artificially increase the loop size
137137
// and prevent outer loop unrolling.
138-
EliminateDeadCode(body, false);
138+
EliminateDeadCode(dest, false);
139139
}
140140

141141
// Replaces the builtin loop counter with a "mutable" variable outside of the
@@ -173,9 +173,11 @@ void unroll(Node* loop) {
173173
// Some optimization for constant-length loops. If we know they won't run too
174174
// many times, then we can unroll them entirely.
175175
Value* trip_count = loop->inputs().at(0);
176-
int64_t const_len = constant_as<int64_t>(trip_count).value_or(-1);
177-
if (const_len != -1 && const_len < kMaxBodyRepeats) {
178-
repeatBody(body, const_len);
176+
c10::optional<int64_t> const_len = constant_as<int64_t>(trip_count);
177+
if (const_len && *const_len < kMaxBodyRepeats) {
178+
Block* dest = loop->addBlock();
179+
repeatBody(body, *const_len, dest);
180+
loop->eraseBlock(0);
179181
inlineBody(loop);
180182
return;
181183
}
@@ -190,7 +192,10 @@ void unroll(Node* loop) {
190192
loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
191193
}
192194

193-
repeatBody(body, kUnrollFactor);
195+
Block* dest = loop->addBlock();
196+
repeatBody(body, kUnrollFactor, dest);
197+
loop->eraseBlock(0);
198+
body = dest;
194199

195200
// Change the iteration counts of both loops
196201
Value* iter_count = loop->inputs().at(0);

0 commit comments

Comments
 (0)