22
33#include " torch/csrc/jit/interned_strings.h"
44#include " torch/csrc/jit/symbolic_variable.h"
5+ #include " torch/csrc/jit/tensor_conversions.h"
56#include " torch/csrc/jit/passes/dead_code_elimination.h"
67
78namespace torch { namespace jit {
@@ -13,11 +14,8 @@ static constexpr int64_t kMaxBodySize = 16;
1314static constexpr int64_t kMaxBodyRepeats = 64 ;
1415
1516bool isTrueConstant (Value *val) {
16- Node *producer = val->node ();
17- if (producer->kind () != prim::Constant)
18- return false ;
19- auto value = producer->t (attr::value);
20- return value.type () == at::CPU (at::kByte ) && value.dim () == 0 && value.toCLong () == 1 ;
17+ at::optional<bool > maybe_value = constant_as<bool >(val);
18+ return maybe_value && *maybe_value;
2119}
2220
2321bool isForLoop (Node* node) {
@@ -28,12 +26,13 @@ bool isForLoop(Node* node) {
2826 return isTrueConstant (start_cond) && isTrueConstant (continue_cond);
2927}
3028
29+ // Counts the size of this block, stopping and returning once it reaches limit instructions.
3130int64_t limitedBlockSize (Block *body, int64_t limit) {
3231 auto it = body->nodes ().begin ();
3332 auto end = body->nodes ().end ();
3433 for (int64_t i = 0 ; i < limit; ++i, ++it) {
3534 for (Block *subblock : it->blocks ()) {
36- i += limitedBlockSize (subblock, limit);
35+ i += limitedBlockSize (subblock, limit - i );
3736 }
3837 if (it == end) {
3938 return i;
@@ -46,13 +45,6 @@ bool isSmallBlock(Block *body) {
4645 return limitedBlockSize (body, kMaxBodySize + 1 ) <= kMaxBodySize ;
4746}
4847
49- at::optional<int64_t > getConstantLength (Node *loop) {
50- Value *trip_count = loop->inputs ().at (0 );
51- if (trip_count->node ()->kind () != prim::Constant)
52- return at::nullopt ;
53- return {trip_count->node ()->t (attr::value).toCLong ()};
54- }
55-
5648// XXX: This function can only be called with a loop that is guaranteed to execute EXACTLY ONCE.
5749void inlineBody (Node *loop) {
5850 auto graph = loop->owningGraph ();
@@ -89,11 +81,13 @@ void inlineBody(Node *loop) {
8981
9082void repeatBody (Block *body, int64_t times) {
9183 // We will be adding nodes to the body, so cache the initial start and end.
92- // XXX: they are both inclusive
84+ // XXX: they are both inclusive, because the exclusive body_end would point to
85+ // return_node, which would move further away if we were to add nodes, and we
86+ // would enter an infinite loop.
9387 auto body_start = body->nodes ().begin ();
9488 auto body_end = std::prev (body->nodes ().end ());
9589 auto graph = body->owningGraph ();
96- WithInsertPoint insert_point_guard { body-> return_node () };
90+ WithInsertPoint insert_point_guard { body };
9791
9892 std::unordered_map<Value*, Value*> value_map;
9993 auto get_value = [&](Value *v) {
@@ -123,12 +117,12 @@ void repeatBody(Block *body, int64_t times) {
123117 }
124118
125119 // Update outputs of the body
126- const std::vector<Value*> orig_outputs = body->outputs ();
127- for (int64_t i = orig_outputs .size () - 1 ; i >= 0 ; --i) {
120+ const std::vector<Value*> new_outputs = fmap ( body->outputs (), get_value );
121+ for (int64_t i = new_outputs .size () - 1 ; i >= 0 ; --i) {
128122 body->eraseOutput (i);
129123 }
130- for (Value *output : orig_outputs ) {
131- body->registerOutput (get_value ( output) );
124+ for (Value *output : new_outputs ) {
125+ body->registerOutput (output);
132126 }
133127
134128 // It's likely that we have some dead nodes now - for example the "true" constant
@@ -168,7 +162,8 @@ void unroll(Node *loop) {
168162
169163 // Some optimization for constant-length loops. If we know they won't run too many
170164 // times, then we can unroll them entirely.
171- int64_t const_len = getConstantLength (loop).value_or (-1 );
165+ Value *trip_count = loop->inputs ().at (0 );
166+ int64_t const_len = constant_as<int64_t >(trip_count).value_or (-1 );
172167 if (const_len != -1 && const_len < kMaxBodyRepeats ) {
173168 repeatBody (body, const_len);
174169 inlineBody (loop);
0 commit comments