
Commit 14b8ffa: Review comments

1 parent: d8741f1

6 files changed: +69 −54 lines

torch/csrc/jit/graph_executor.cpp

Lines changed: 1 addition & 1 deletion

```diff
@@ -350,6 +350,7 @@ struct GraphExecutorImpl {
     // do not work on variables

     // They also may assume that concrete sizes/strides are availiable
+    UnrollLoops(graph);

     //TODO: create peephole optimizations that are safe to run
     // when we are using variables, and when we do not know sizes.
@@ -358,7 +359,6 @@ struct GraphExecutorImpl {
       // it works fine on variables.
       BatchMM(graph);
       FuseGraph(graph);
-      UnrollLoops(graph);
     }
   }
   // we need to run some passes to ensure the graph will run correctly
```
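The change hoists UnrollLoops out of the tail of the fusion branch so it now runs before BatchMM and FuseGraph, in the group of passes that the surrounding comments say may assume concrete sizes/strides. A minimal, self-contained sketch of why ordering matters in such a pipeline (not the real GraphExecutor API; `Graph` and the pass list here are illustrative stand-ins):

```cpp
// Illustrative only: optimization passes run as an ordered sequence over a
// shared graph, so moving a pass earlier changes what later passes observe.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct Graph { std::vector<std::string> log; };  // stand-in for jit::Graph

int main() {
  using Pass = std::function<void(Graph&)>;
  auto named = [](std::string name) {
    return [name](Graph& g) { g.log.push_back(name); };
  };
  // Order as of this commit: unrolling precedes matmul batching and fusion,
  // so BatchMM/FuseGraph see the already-unrolled loop bodies.
  std::vector<Pass> pipeline = {
      named("UnrollLoops"), named("BatchMM"), named("FuseGraph")};
  Graph g;
  for (auto& pass : pipeline) pass(g);
  for (auto& n : g.log) std::cout << n << "\n";
}
```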

torch/csrc/jit/ir.cpp

Lines changed: 6 additions & 6 deletions

```diff
@@ -10,7 +10,6 @@
 #include <sstream>
 #include <algorithm>
 #include <string>
-#include <regex>

 namespace torch { namespace jit {

@@ -538,7 +537,6 @@ std::shared_ptr<Graph> Graph::copy() {
 }

 inline Value* Value::setUniqueName(const std::string & name) {
-  static std::regex numbered_name_regex("(.*)\\.([0-9]+)");
   if (name.size() > 0 && name.find_first_not_of("0123456789") == std::string::npos) {
     throw std::runtime_error("names may not be integers: " + name);
   }
@@ -560,10 +558,12 @@ inline Value* Value::setUniqueName(const std::string & name) {
   if(old_owner_of_name != names.end()) {
     size_t suffix = 1;
     std::string name_base = name;
-    std::smatch match;
-    if (std::regex_match(name, match, numbered_name_regex)) {
-      name_base = match[1];
-      suffix = std::stoll(match[2]);
+    auto last_dot_pos = name.find_last_of('.');
+    if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size()) {
+      if (name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) {
+        suffix = std::stoll(name.substr(last_dot_pos + 1));
+        name_base = name.substr(0, last_dot_pos);
+      }
     }
     std::string replacement_name;
     do {
```
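The regex that parsed numbered names like `hidden.7` is replaced with a plain string scan, avoiding `std::regex` construction and matching costs on every rename. A standalone sketch of the new logic, with `splitNumberedName` as a hypothetical helper name introduced only for this example:

```cpp
// Split a unique name into a base and a numeric suffix, mirroring the diff:
// a trailing ".<digits>" yields (base, suffix); anything else keeps the whole
// name and suffix 1, exactly like the fallback in setUniqueName.
#include <cstdint>
#include <iostream>
#include <string>
#include <utility>

std::pair<std::string, int64_t> splitNumberedName(const std::string& name) {
  std::string name_base = name;
  int64_t suffix = 1;
  auto last_dot_pos = name.find_last_of('.');
  // Require at least one character after the dot, and all of them digits.
  if (last_dot_pos != std::string::npos && last_dot_pos + 1 != name.size() &&
      name.find_first_not_of("0123456789", last_dot_pos + 1) == std::string::npos) {
    suffix = std::stoll(name.substr(last_dot_pos + 1));
    name_base = name.substr(0, last_dot_pos);
  }
  return {name_base, suffix};
}

int main() {
  for (const char* n : {"hidden.7", "hidden", "a.b.12", "x."}) {
    auto [base, suffix] = splitNumberedName(n);
    std::cout << n << " -> base=" << base << " suffix=" << suffix << "\n";
  }
}
```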
torch/csrc/jit/passes/dead_code_elimination.cpp

Lines changed: 29 additions & 13 deletions

```diff
@@ -1,33 +1,49 @@
 #include "torch/csrc/jit/passes/dead_code_elimination.h"

+#include <unordered_map>
+
 namespace torch { namespace jit {

-bool hasSideEffects(Node * node) {
+using bool_memo_type = std::unordered_map<Node*, bool>;
+
+bool hasSideEffects(Node * node, bool_memo_type& memo) {
   // FIXME: PythonOp and CppOp should be treated as having side effects as well!
   // Unfortunately ONNX depends on them getting removed in this pass, so it's not
   // a simple change.
-  return node->kind() == prim::Print ||
-         std::any_of(node->blocks().begin(), node->blocks().end(),
-                     [](Block *b) {
-                       return std::any_of(b->nodes().begin(), b->nodes().end(), hasSideEffects);
-                     });
-}
-
-void EliminateDeadCode(std::shared_ptr<Graph>& graph) {
-  EliminateDeadCode(graph->block());
+  auto it = memo.find(node);
+  if (it != memo.end())
+    return it->second;
+  bool has_side_effects = node->kind() == prim::Print ||
+    std::any_of(node->blocks().begin(), node->blocks().end(),
+                [&](Block *b) {
+                  return std::any_of(b->nodes().begin(), b->nodes().end(),
+                                     [&](Node *n) { return hasSideEffects(n, memo); });
+                });
+  memo.emplace(node, has_side_effects);
+  return has_side_effects;
 }

-void EliminateDeadCode(Block *block, bool recurse) {
+void EliminateDeadCode(Block *block, bool recurse, bool_memo_type& memo) {
   auto nodes = block->nodes().reverse();
   for (auto it = nodes.begin(); it != nodes.end(); it++) {
     auto node = *it;
     if (recurse) {
       for (Block * block : node->blocks())
-        EliminateDeadCode(block);
+        EliminateDeadCode(block, true, memo);
     }
-    if (!node->hasUses() && !hasSideEffects(node))
+    if (!node->hasUses() && !hasSideEffects(node, memo))
       it.destroyCurrent();
   }
 }

+void EliminateDeadCode(std::shared_ptr<Graph>& graph) {
+  bool_memo_type side_effect_memo;
+  EliminateDeadCode(graph->block(), true, side_effect_memo);
+}
+
+void EliminateDeadCode(Block *block, bool recurse) {
+  bool_memo_type side_effect_memo;
+  EliminateDeadCode(block, recurse, side_effect_memo);
+}
+
 }} // namespace torch::jit
```
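The new `bool_memo_type` caches the per-node answer so a node reached through several nesting levels is scanned only once; without the memo, each enclosing call re-walked the same sub-blocks, which is quadratic in deeply nested graphs. A minimal, self-contained sketch of the same memoization pattern, where `Node` is a stand-in struct rather than the JIT's Node class:

```cpp
// Memoize a recursive predicate over nested nodes: each node's answer is
// computed once and then served from the unordered_map.
#include <iostream>
#include <unordered_map>
#include <vector>

struct Node {
  bool is_print = false;        // stands in for kind() == prim::Print
  std::vector<Node*> children;  // stands in for the nodes of nested blocks
};

using bool_memo_type = std::unordered_map<Node*, bool>;

bool hasSideEffects(Node* node, bool_memo_type& memo) {
  auto it = memo.find(node);
  if (it != memo.end())
    return it->second;  // answered before: O(1) lookup instead of a re-walk
  bool result = node->is_print;
  for (Node* child : node->children)
    result = result || hasSideEffects(child, memo);  // short-circuits like any_of
  memo.emplace(node, result);
  return result;
}

int main() {
  Node leaf{true, {}};
  Node inner{false, {&leaf}};
  Node root{false, {&inner, &inner}};  // shared subtree is computed only once
  bool_memo_type memo;
  std::cout << std::boolalpha << hasSideEffects(&root, memo) << "\n";  // true
}
```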

torch/csrc/jit/passes/loop_unrolling.cpp

Lines changed: 15 additions & 20 deletions

```diff
@@ -2,6 +2,7 @@

 #include "torch/csrc/jit/interned_strings.h"
 #include "torch/csrc/jit/symbolic_variable.h"
+#include "torch/csrc/jit/tensor_conversions.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"

 namespace torch { namespace jit {
@@ -13,11 +14,8 @@ static constexpr int64_t kMaxBodySize = 16;
 static constexpr int64_t kMaxBodyRepeats = 64;

 bool isTrueConstant(Value *val) {
-  Node *producer = val->node();
-  if (producer->kind() != prim::Constant)
-    return false;
-  auto value = producer->t(attr::value);
-  return value.type() == at::CPU(at::kByte) && value.dim() == 0 && value.toCLong() == 1;
+  at::optional<bool> maybe_value = constant_as<bool>(val);
+  return maybe_value && *maybe_value;
 }

 bool isForLoop(Node* node) {
@@ -28,12 +26,13 @@ bool isForLoop(Node* node) {
   return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
 }

+// Counts the size of this block, stopping and returning once reaches limit instructions.
 int64_t limitedBlockSize(Block *body, int64_t limit) {
   auto it = body->nodes().begin();
   auto end = body->nodes().end();
   for (int64_t i = 0; i < limit; ++i, ++it) {
     for (Block *subblock : it->blocks()) {
-      i += limitedBlockSize(subblock, limit);
+      i += limitedBlockSize(subblock, limit - i);
     }
     if (it == end) {
       return i;
@@ -46,13 +45,6 @@ bool isSmallBlock(Block *body) {
   return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
 }

-at::optional<int64_t> getConstantLength(Node *loop) {
-  Value *trip_count = loop->inputs().at(0);
-  if (trip_count->node()->kind() != prim::Constant)
-    return at::nullopt;
-  return {trip_count->node()->t(attr::value).toCLong()};
-}
-
 // XXX: This function can only be called with a loop that is guaranteed to execute EXACTLY ONCE.
 void inlineBody(Node *loop) {
   auto graph = loop->owningGraph();
@@ -89,11 +81,13 @@ void inlineBody(Node *loop) {

 void repeatBody(Block *body, int64_t times) {
   // We will be adding nodes to the body, so cache the initial start and end.
-  // XXX: they are both inclusive
+  // XXX: they are both inclusive, because the exclusive body_end would point to
+  // return_node, which would move further away if we were to add nodes, and we
+  // would enter an infinite loop.
   auto body_start = body->nodes().begin();
   auto body_end = std::prev(body->nodes().end());
   auto graph = body->owningGraph();
-  WithInsertPoint insert_point_guard { body->return_node() };
+  WithInsertPoint insert_point_guard { body };

   std::unordered_map<Value*, Value*> value_map;
   auto get_value = [&](Value *v) {
@@ -123,12 +117,12 @@ void repeatBody(Block *body, int64_t times) {
   }

   // Update outputs of the body
-  const std::vector<Value*> orig_outputs = body->outputs();
-  for (int64_t i = orig_outputs.size() - 1; i >= 0; --i) {
+  const std::vector<Value*> new_outputs = fmap(body->outputs(), get_value);
+  for (int64_t i = new_outputs.size() - 1; i >= 0; --i) {
     body->eraseOutput(i);
   }
-  for (Value *output : orig_outputs) {
-    body->registerOutput(get_value(output));
+  for (Value *output : new_outputs) {
+    body->registerOutput(output);
   }

   // It's likely that we have some dead nodes now - for example the "true" constant
@@ -168,7 +162,8 @@ void unroll(Node *loop) {

   // Some optimization for constant-length loops. If we know they won't run too many
   // times, then we can unroll them entirely.
-  int64_t const_len = getConstantLength(loop).value_or(-1);
+  Value *trip_count = loop->inputs().at(0);
+  int64_t const_len = constant_as<int64_t>(trip_count).value_or(-1);
   if (const_len != -1 && const_len < kMaxBodyRepeats) {
     repeatBody(body, const_len);
     inlineBody(loop);
```
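Two things happen in this file: the ad-hoc constant checks (`isTrueConstant`, `getConstantLength`) are rewritten on top of the shared `constant_as<T>` helper that this commit moves into tensor_conversions.h, and `limitedBlockSize` now passes the remaining budget `limit - i` into recursive calls rather than the full `limit`. A self-contained sketch of that budget threading, restructured slightly for standalone use and with stand-in `Block`/`Node` types (requires C++17):

```cpp
// Passing the *remaining* budget down keeps the total work bounded by `limit`
// even for deeply nested blocks; passing the full `limit` would let every
// nesting level scan up to `limit` nodes again.
#include <cstdint>
#include <iostream>
#include <vector>

struct Node;
struct Block { std::vector<Node> nodes; };
struct Node  { std::vector<Block> blocks; };

int64_t limitedBlockSize(const Block& body, int64_t limit) {
  int64_t i = 0;
  for (const Node& node : body.nodes) {
    if (i >= limit)
      return i;  // budget exhausted: stop counting early
    ++i;
    for (const Block& subblock : node.blocks)
      i += limitedBlockSize(subblock, limit - i);  // remaining budget only
  }
  return i;
}

int main() {
  // A block with 3 nodes, the middle one containing a 5-node sub-block.
  Block inner{{{}, {}, {}, {}, {}}};
  Block body{{{}, {{inner}}, {}}};
  std::cout << limitedBlockSize(body, 4) << "\n";  // prints 4: stopped at the budget
}
```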

torch/csrc/jit/script/compiler.cpp

Lines changed: 0 additions & 14 deletions

```diff
@@ -290,20 +290,6 @@ static bool isTensorSubtype(Value* v) {
   return v->type()->isSubtypeOf(*DynamicType::get());
 }

-// if a value is a constant then try to turn into type T using the
-// same rules as the interpreter
-template<typename T>
-at::optional<T> constant_as(Value* v) {
-  if(v->node()->kind() != prim::Constant)
-    return at::nullopt;
-  auto tensor = v->node()->t(attr::value);
-  try {
-    return tensor_as<T>(std::move(tensor));
-  } catch (tensor_conversion_error& err) {
-    return at::nullopt;
-  }
-}
-
 at::optional<std::vector<int64_t>> getIntListAttribute(at::optional<int32_t> N, Value* input) {
   auto list = constant_as<at::IntList>(input);
   if(list)
```
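The `constant_as<T>` helper deleted here is not lost: the identical definition is re-added in torch/csrc/jit/tensor_conversions.h below, where both this file and loop_unrolling.cpp can include it.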

torch/csrc/jit/tensor_conversions.h

Lines changed: 18 additions & 0 deletions

```diff
@@ -123,4 +123,22 @@ inline at::Tensor as_variable(const T& t) {
   return autograd::make_variable(as_tensor(t));
 }

+//////////////////////////////////////////////////////////////////////////////////
+// Helper for retrieving constants
+//////////////////////////////////////////////////////////////////////////////////
+
+// if a value is a constant then try to turn into type T using the
+// same rules as the interpreter
+template<typename T>
+at::optional<T> constant_as(Value* v) {
+  if(v->node()->kind() != prim::Constant)
+    return at::nullopt;
+  auto tensor = v->node()->t(attr::value);
+  try {
+    return tensor_as<T>(std::move(tensor));
+  } catch (tensor_conversion_error& err) {
+    return at::nullopt;
+  }
+}
+
 }} // namespace torch::jit
```
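`constant_as<T>` folds a failed `tensor_as<T>` conversion into an empty optional, so call sites can write `value_or(sentinel)` instead of try/catch, as `unroll` does above. A self-contained sketch of the same pattern, with `std::optional` and a hypothetical `parse_as` standing in for `at::optional` and `tensor_as` (requires C++17):

```cpp
// Turn a throwing conversion into an optional-returning one, so callers can
// express "no usable constant" with value_or instead of exception handling.
#include <cstdint>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <string>

template <typename T>
T parse_as(const std::string& s);  // may throw, like tensor_as<T>

template <>
int64_t parse_as<int64_t>(const std::string& s) {
  size_t pos = 0;
  int64_t v = std::stoll(s, &pos);
  if (pos != s.size()) throw std::runtime_error("not an integer: " + s);
  return v;
}

template <typename T>
std::optional<T> constant_as(const std::string& s) {
  try {
    return parse_as<T>(s);
  } catch (const std::exception&) {
    return std::nullopt;  // conversion failed: report "no constant" instead
  }
}

int main() {
  // Mirrors unroll(): a sentinel via value_or, then a guarded fast path.
  int64_t const_len = constant_as<int64_t>("12").value_or(-1);   // 12
  int64_t not_const = constant_as<int64_t>("abc").value_or(-1);  // -1
  std::cout << const_len << " " << not_const << "\n";
}
```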
