Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5598,6 +5598,19 @@ def func(lim):
inputs = self._make_scalar_vars([10], torch.int64)
self.checkScript(func, inputs, optimize=True)

def test_fibb_totally_better(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lol

# Fixture passed to checkScript: starting from (prev, v) = (1, 1), applies
# the two-term recurrence (prev, v) -> (v, v + prev) x times and returns v
# (for x = 10 the result is 144).
def fib(x):
# type: (int) -> int
prev = 1
v = 1
for i in range(0, x):
# `save` keeps the old v so prev can take it after v is overwritten.
save = v
v = v + prev
prev = save
return v

self.checkScript(fib, (10,))

def test_if(self):
def func(a, b):
# type: (int, int) -> int
Expand Down
79 changes: 42 additions & 37 deletions torch/csrc/jit/passes/loop_unrolling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,57 +85,57 @@ void inlineBody(Node* loop) {
loop->destroy();
}

void repeatBody(Block* body, int64_t times) {
// We will be adding nodes to the body, so cache the initial start and end.
// XXX: they are both inclusive, because the exclusive body_end would point to
// return_node, which would move further away if we were to add nodes,
// and we would enter an infinite loop.
auto body_start = body->nodes().begin();
auto body_end = std::prev(body->nodes().end());
auto graph = body->owningGraph();
WithInsertPoint insert_point_guard{body};

// Inserts a copy of body, mapping the given inputs onto the inputs of the block.
// Returns a list of the Values corresponding to the outputs of the block.
std::vector<Value*> insertBlockCopy(
Graph& graph,
Block* body,
at::ArrayRef<Value*> inputs) {
TORCH_INTERNAL_ASSERT(inputs.size() == body->inputs().size());
std::unordered_map<Value*, Value*> value_map;
auto get_value = [&](Value* v) {
auto it = value_map.find(v);
if (it != value_map.end())
return it->second;
return v;
};

for (int64_t i = 1; i < times; ++i) {
// Update loop-carried values
// NB: note that we don't need to worry about the loop counter, because
// we've replaced it with a loop-carried variable
AT_ASSERT(body->inputs().size() == body->outputs().size());
for (size_t i = 1; i < body->inputs().size(); ++i) {
value_map[body->inputs()[i]] = get_value(body->outputs()[i]);
auto inputs_it = inputs.begin();
for (Value* input : body->inputs()) {
value_map[input] = *inputs_it++;
}
for (Node* node : body->nodes()) {
Node* new_node = graph.insertNode(graph.createClone(node, get_value));
auto outputs_it = new_node->outputs().begin();
for (Value* output : node->outputs()) {
value_map[output] = *outputs_it++;
}
}
return fmap(body->outputs(), get_value);
}

// Clone the nodes
for (auto it = body_start; it != std::next(body_end); ++it) {
Node* orig = *it;
Node* clone = graph->insertNode(graph->createClone(orig, get_value));
for (size_t i = 0; i < orig->outputs().size(); ++i) {
value_map[orig->outputs()[i]] = clone->outputs()[i];
}
}
void repeatBody(Block* body, size_t times, Block* dest) {
auto graph = body->owningGraph();
WithInsertPoint insert_point_guard(dest);
for (Value* input : body->inputs()) {
dest->addInput()->copyMetadata(input);
}

// Update outputs of the body
const std::vector<Value*> new_outputs = fmap(body->outputs(), get_value);
for (int64_t i = new_outputs.size() - 1; i >= 0; --i) {
body->eraseOutput(i);
std::vector<Value*> io = dest->inputs().vec();
TORCH_INTERNAL_ASSERT(
!body->inputs().at(0)->hasUses(), "loop counter should be unused");
for (size_t i = 0; i < times; ++i) {
io[0] = body->inputs().at(0);
io = insertBlockCopy(*graph, body, io);
}
for (Value* output : new_outputs) {
body->registerOutput(output);
for (Value* output : io) {
dest->registerOutput(output);
}

// It's likely that we have some dead nodes now - for example the "true"
// constant that prevents the loop from breaking. We shouldn't wait too long
// before removing them because they might artificially increase the loop size
// and prevent outer loop unrolling.
EliminateDeadCode(body, false);
EliminateDeadCode(dest, false);
}

// Replaces the builtin loop counter with a "mutable" variable outside of the
Expand Down Expand Up @@ -173,9 +173,11 @@ void unroll(Node* loop) {
// Some optimization for constant-length loops. If we know they won't run too
// many times, then we can unroll them entirely.
Value* trip_count = loop->inputs().at(0);
int64_t const_len = constant_as<int64_t>(trip_count).value_or(-1);
if (const_len != -1 && const_len < kMaxBodyRepeats) {
repeatBody(body, const_len);
c10::optional<int64_t> const_len = constant_as<int64_t>(trip_count);
if (const_len && *const_len < kMaxBodyRepeats) {
Block* dest = loop->addBlock();
repeatBody(body, *const_len, dest);
loop->eraseBlock(0);
inlineBody(loop);
return;
}
Expand All @@ -190,7 +192,10 @@ void unroll(Node* loop) {
loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
}

repeatBody(body, kUnrollFactor);
Block* dest = loop->addBlock();
repeatBody(body, kUnrollFactor, dest);
loop->eraseBlock(0);
body = dest;

// Change the iteration counts of both loops
Value* iter_count = loop->inputs().at(0);
Expand Down