@@ -85,57 +85,57 @@ void inlineBody(Node* loop) {
8585 loop->destroy ();
8686}
8787
88- void repeatBody (Block* body, int64_t times) {
89- // We will be adding nodes to the body, so cache the initial start and end.
90- // XXX: they are both inclusive, because the exclusive body_end would point to
91- // return_node, which would move further away if we were to add nodes,
92- // and we would enter an infinite loop.
93- auto body_start = body->nodes ().begin ();
94- auto body_end = std::prev (body->nodes ().end ());
95- auto graph = body->owningGraph ();
96- WithInsertPoint insert_point_guard{body};
97-
88+ // inserts a copy of body, passing inputs to the inputs of the block
89+ // it returns the a list of the Values for the output of the block
90+ std::vector<Value*> insertBlockCopy (
91+ Graph& graph,
92+ Block* body,
93+ at::ArrayRef<Value*> inputs) {
94+ TORCH_INTERNAL_ASSERT (inputs.size () == body->inputs ().size ());
9895 std::unordered_map<Value*, Value*> value_map;
9996 auto get_value = [&](Value* v) {
10097 auto it = value_map.find (v);
10198 if (it != value_map.end ())
10299 return it->second ;
103100 return v;
104101 };
105-
106- for (int64_t i = 1 ; i < times; ++i) {
107- // Update loop-carried values
108- // NB: note that we don't need to worry about the loop counter, because
109- // we've replaced it with a loop-carried variable
110- AT_ASSERT (body->inputs ().size () == body->outputs ().size ());
111- for (size_t i = 1 ; i < body->inputs ().size (); ++i) {
112- value_map[body->inputs ()[i]] = get_value (body->outputs ()[i]);
102+ auto inputs_it = inputs.begin ();
103+ for (Value* input : body->inputs ()) {
104+ value_map[input] = *inputs_it++;
105+ }
106+ for (Node* node : body->nodes ()) {
107+ Node* new_node = graph.insertNode (graph.createClone (node, get_value));
108+ auto outputs_it = new_node->outputs ().begin ();
109+ for (Value* output : node->outputs ()) {
110+ value_map[output] = *outputs_it++;
113111 }
112+ }
113+ return fmap (body->outputs (), get_value);
114+ }
114115
115- // Clone the nodes
116- for (auto it = body_start; it != std::next (body_end); ++it) {
117- Node* orig = *it;
118- Node* clone = graph->insertNode (graph->createClone (orig, get_value));
119- for (size_t i = 0 ; i < orig->outputs ().size (); ++i) {
120- value_map[orig->outputs ()[i]] = clone->outputs ()[i];
121- }
122- }
116+ void repeatBody (Block* body, size_t times, Block* dest) {
117+ auto graph = body->owningGraph ();
118+ WithInsertPoint insert_point_guard (dest);
119+ for (Value* input : body->inputs ()) {
120+ dest->addInput ()->copyMetadata (input);
123121 }
124122
125- // Update outputs of the body
126- const std::vector<Value*> new_outputs = fmap (body->outputs (), get_value);
127- for (int64_t i = new_outputs.size () - 1 ; i >= 0 ; --i) {
128- body->eraseOutput (i);
123+ std::vector<Value*> io = dest->inputs ().vec ();
124+ TORCH_INTERNAL_ASSERT (
125+ !body->inputs ().at (0 )->hasUses (), " loop counter should be unused" );
126+ for (size_t i = 0 ; i < times; ++i) {
127+ io[0 ] = body->inputs ().at (0 );
128+ io = insertBlockCopy (*graph, body, io);
129129 }
130- for (Value* output : new_outputs ) {
131- body ->registerOutput (output);
130+ for (Value* output : io ) {
131+ dest ->registerOutput (output);
132132 }
133133
134134 // It's likely that we have some dead nodes now - for example the "true"
135135 // constant that prevents the loop from breaking. We shouldn't wait too long
136136 // before removing them because they might artificially increase the loop size
137137 // and prevent outer loop unrolling.
138- EliminateDeadCode (body , false );
138+ EliminateDeadCode (dest , false );
139139}
140140
141141// Replaces the builtin loop counter with a "mutable" variable outside of the
@@ -173,9 +173,11 @@ void unroll(Node* loop) {
173173 // Some optimization for constant-length loops. If we know they won't run too
174174 // many times, then we can unroll them entirely.
175175 Value* trip_count = loop->inputs ().at (0 );
176- int64_t const_len = constant_as<int64_t >(trip_count).value_or (-1 );
177- if (const_len != -1 && const_len < kMaxBodyRepeats ) {
178- repeatBody (body, const_len);
176+ c10::optional<int64_t > const_len = constant_as<int64_t >(trip_count);
177+ if (const_len && *const_len < kMaxBodyRepeats ) {
178+ Block* dest = loop->addBlock ();
179+ repeatBody (body, *const_len, dest);
180+ loop->eraseBlock (0 );
179181 inlineBody (loop);
180182 return ;
181183 }
@@ -190,7 +192,10 @@ void unroll(Node* loop) {
190192 loop_epilogue->replaceInput (i + 2 , loop->outputs ()[i]);
191193 }
192194
193- repeatBody (body, kUnrollFactor );
195+ Block* dest = loop->addBlock ();
196+ repeatBody (body, kUnrollFactor , dest);
197+ loop->eraseBlock (0 );
198+ body = dest;
194199
195200 // Change the iteration counts of both loops
196201 Value* iter_count = loop->inputs ().at (0 );
0 commit comments