xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/loop_unrolling.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/loop_unrolling.h>
2 
3 #include <ATen/core/symbol.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/irange.h>
6 
7 #include <torch/csrc/jit/ir/constants.h>
8 #include <torch/csrc/jit/ir/ir_views.h>
9 #include <torch/csrc/jit/jit_log.h>
10 #include <torch/csrc/jit/passes/dead_code_elimination.h>
11 
12 namespace torch::jit {
13 
14 namespace {
15 
// Number of body copies placed in the main loop when a loop is partially
// unrolled.
constexpr int64_t kUnrollFactor = 8;
// Largest body size (as measured by limitedBlockSize) that is still
// considered small enough to unroll.
constexpr int64_t kMaxBodySize = 32;
// Loops with a constant trip count strictly below this value are unrolled
// completely.
constexpr int64_t kMaxBodyRepeats = 64;
19 
isTrueConstant(Value * val)20 bool isTrueConstant(Value* val) {
21   std::optional<bool> maybe_value = constant_as<bool>(val);
22   return maybe_value && *maybe_value;
23 }
24 
isForLoop(Node * node)25 bool isForLoop(Node* node) {
26   if (node->kind() != prim::Loop)
27     return false;
28   Value* start_cond = node->inputs().at(1);
29   Value* continue_cond = node->blocks().at(0)->outputs().at(0);
30   return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
31 }
32 
33 // Counts the size of this block, stopping and returning once reaches limit
34 // instructions.
limitedBlockSize(Block * body,int64_t limit)35 int64_t limitedBlockSize(Block* body, int64_t limit) {
36   auto it = body->nodes().begin();
37   auto end = body->nodes().end();
38   for (int64_t i = 0; i < limit; ++it) {
39     for (Block* subblock : it->blocks()) {
40       i += limitedBlockSize(subblock, limit - i);
41     }
42     if (!it->notExecutedOp()) {
43       ++i;
44     }
45     if (it == end) {
46       return i;
47     }
48   }
49   return limit;
50 }
51 
isSmallBlock(Block * body)52 bool isSmallBlock(Block* body) {
53   return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
54 }
55 
56 // XXX: This function can only be called with a loop that is guaranteed to
57 // execute EXACTLY ONCE.
inlineBody(Node * loop)58 void inlineBody(Node* loop) {
59   auto graph = loop->owningGraph();
60   auto body = loop->blocks().at(0);
61   WithInsertPoint insert_point_guard{loop};
62 
63   std::unordered_map<Value*, Value*> value_map;
64   auto get_value = [&](Value* v) {
65     auto it = value_map.find(v);
66     if (it != value_map.end())
67       return it->second;
68     return v;
69   };
70 
71   // Loop node has extra (max_iters, initial_cond) inputs,
72   // body has an extra (loop_counter) input.
73   for (size_t i = 2; i < loop->inputs().size(); ++i) {
74     value_map[body->inputs()[i - 1]] = loop->inputs()[i];
75   }
76 
77   for (Node* orig : body->nodes()) {
78     Node* clone = graph->insertNode(graph->createClone(orig, get_value));
79     for (size_t i = 0; i < orig->outputs().size(); ++i) {
80       value_map[orig->outputs()[i]] = clone->outputs()[i];
81     }
82   }
83   for (size_t i = 0; i < loop->outputs().size(); ++i) {
84     loop->outputs().at(i)->replaceAllUsesWith(
85         get_value(body->outputs().at(i + 1)));
86   }
87   // XXX: it is extremely important to destroy the loop in here. DCE might not
88   // be able to conclude that it's safe, because the loop might contain side
89   // effects.
90   loop->destroy();
91 }
92 
93 // inserts a copy of body, passing inputs to the inputs of the block
94 // it returns the a list of the Values for the output of the block
insertBlockCopy(Graph & graph,Block * body,at::ArrayRef<Value * > inputs)95 std::vector<Value*> insertBlockCopy(
96     Graph& graph,
97     Block* body,
98     at::ArrayRef<Value*> inputs) {
99   TORCH_INTERNAL_ASSERT(inputs.size() == body->inputs().size());
100   std::unordered_map<Value*, Value*> value_map;
101   auto get_value = [&](Value* v) {
102     auto it = value_map.find(v);
103     if (it != value_map.end())
104       return it->second;
105     return v;
106   };
107   auto inputs_it = inputs.begin();
108   for (Value* input : body->inputs()) {
109     value_map[input] = *inputs_it++;
110   }
111   for (Node* node : body->nodes()) {
112     Node* new_node = graph.insertNode(graph.createClone(node, get_value));
113     auto outputs_it = new_node->outputs().begin();
114     for (Value* output : node->outputs()) {
115       value_map[output] = *outputs_it++;
116     }
117   }
118   return fmap(body->outputs(), get_value);
119 }
120 
repeatBody(Block * body,size_t times,Block * dest)121 void repeatBody(Block* body, size_t times, Block* dest) {
122   auto graph = body->owningGraph();
123   WithInsertPoint insert_point_guard(dest);
124   for (Value* input : body->inputs()) {
125     dest->addInput()->copyMetadata(input);
126   }
127 
128   std::vector<Value*> io = dest->inputs().vec();
129   TORCH_INTERNAL_ASSERT(
130       !body->inputs().at(0)->hasUses(), "loop counter should be unused");
131   for (const auto i : c10::irange(times)) {
132     (void)i; // Suppress unused variable warning
133     io[0] = body->inputs().at(0);
134     io = insertBlockCopy(*graph, body, io);
135   }
136   for (Value* output : io) {
137     dest->registerOutput(output);
138   }
139 
140   // It's likely that we have some dead nodes now - for example the "true"
141   // constant that prevents the loop from breaking. We shouldn't wait too long
142   // before removing them because they might artificially increase the loop size
143   // and prevent outer loop unrolling.
144   EliminateDeadCode(dest, false);
145 }
146 
147 // Replaces the builtin loop counter with a "mutable" variable outside of the
148 // loop.
replaceLoopCounter(Node * loop)149 void replaceLoopCounter(Node* loop) {
150   Graph* graph = loop->owningGraph();
151   Block* body = loop->blocks().at(0);
152   WithInsertPoint guard(loop);
153   Value* init_counter = graph->insertConstant(0);
154 
155   loop->insertInput(2, init_counter);
156   loop->insertOutput(0)->setType(IntType::get());
157 
158   Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
159   body->inputs()[0]->replaceAllUsesWith(internal_counter);
160 
161   WithInsertPoint insertPointGuard{body->return_node()};
162   Value* result = graph->insert(aten::add, {internal_counter, 1});
163   body->insertOutput(1, result);
164 }
165 
unroll(Node * loop)166 void unroll(Node* loop) {
167   Graph* graph = loop->owningGraph();
168   Block* body = loop->blocks().at(0);
169 
170   // We will be using a "mutable" counter outside of the loop instead of the
171   // default one, because this will allow us to share it between the unrolled
172   // loop and its epilogue. This is necessary only if the loop counter is
173   // actually used in the body.
174   if (!body->inputs()[0]->uses().empty())
175     replaceLoopCounter(loop);
176 
177   // Some optimization for constant-length loops. If we know they won't run too
178   // many times, then we can unroll them entirely.
179   Value* trip_count = loop->inputs().at(0);
180   std::optional<int64_t> const_len = constant_as<int64_t>(trip_count);
181   if (const_len && *const_len < kMaxBodyRepeats) {
182     Block* dest = loop->addBlock();
183     repeatBody(body, *const_len, dest);
184     loop->eraseBlock(0);
185     inlineBody(loop);
186     return;
187   }
188 
189   WithInsertPoint insert_point_guard{loop};
190 
191   // Clone the loop before we unroll it. The clone will become the epilogue.
192   Node* loop_epilogue =
193       graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
194   for (size_t i = 0; i < loop->outputs().size(); ++i) {
195     loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
196     loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
197   }
198 
199   Block* dest = loop->addBlock();
200   repeatBody(body, kUnrollFactor, dest);
201   loop->eraseBlock(0);
202 
203   // Change the iteration counts of both loops
204   Value* iter_count = loop->inputs().at(0);
205   Value* unrolled_iter_count = graph->insert(
206       aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
207   loop->replaceInput(0, unrolled_iter_count);
208   loop_epilogue->replaceInput(
209       0,
210       graph->insert(
211           aten::sub,
212           {iter_count,
213            graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
214 }
215 
UnrollLoops(Block * block,bool constant_only)216 bool UnrollLoops(Block* block, bool constant_only) {
217   bool changed = false;
218   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
219     // XXX: unroll might destroy the current node, so we need to pre-increment
220     // the iterator
221     Node* node = *it;
222     ++it;
223     for (Block* subblock : node->blocks()) {
224       changed |= UnrollLoops(subblock, constant_only);
225     }
226     if (!isForLoop(node)) {
227       continue;
228     }
229     if (constant_only) {
230       if (node->inputs().at(0)->node()->kind() != prim::Constant) {
231         continue;
232       }
233     } else if (!isSmallBlock(node->blocks().at(0))) {
234       continue;
235     }
236 
237     unroll(node);
238     changed = true;
239   }
240   return changed;
241 }
242 
243 } // anonymous namespace
244 
addCondAsOutput(Node * loop)245 static void addCondAsOutput(Node* loop) {
246   LoopView loop_view(loop);
247   loop->addInput(loop_view.inputCond());
248   auto block_cond_input = loop_view.bodyBlock()->addInput();
249   block_cond_input->copyMetadata(loop_view.inputCond());
250   auto cond_output_index =
251       loop_view.bodyBlock()->registerOutput(loop_view.nextCond());
252   loop_view.bodyBlock()->outputs()[cond_output_index]->copyMetadata(
253       loop_view.nextCond());
254   auto cond_output = loop->addOutput();
255   cond_output->copyMetadata(loop_view.nextCond());
256 }
257 
run(const std::shared_ptr<Graph> & graph)258 bool LoopsPeeler::run(const std::shared_ptr<Graph>& graph) {
259   GRAPH_DUMP("Before LoopsPeeler", graph);
260   collectLoops(graph->block());
261   peelLoops();
262   GRAPH_DUMP("After LoopsPeeler", graph);
263   return true;
264 }
265 
collectLoop(Node * n)266 void LoopsPeeler::collectLoop(Node* n) {
267   if (callback_(n)) {
268     if (in_loop_) {
269       GRAPH_DEBUG("Loop ", getHeader(in_loop_), " will be unrolled");
270       loops_to_peel_.push_back(in_loop_);
271       in_loop_ = nullptr;
272     }
273   }
274 }
275 
collectLoops(Block * block)276 void LoopsPeeler::collectLoops(Block* block) {
277   // we do a pre-order traversal to reduce the number
278   // of peeled loops.
279   for (auto n : block->nodes()) {
280     collectLoop(n);
281   }
282   collectLoop(block->return_node());
283 
284   // process child blocks
285   for (auto n : block->nodes()) {
286     auto old_in_loop_ = in_loop_;
287     if (n->kind() == prim::Loop) {
288       in_loop_ = n;
289     }
290     for (auto b : n->blocks()) {
291       collectLoops(b);
292     }
293     in_loop_ = old_in_loop_;
294   }
295 }
296 
peelLoops()297 void LoopsPeeler::peelLoops() {
298   for (auto loop : loops_to_peel_) {
299     PeelLoop(loop, num_iterations_);
300   }
301 }
302 
PeelProfilingLoops(const std::shared_ptr<Graph> & graph)303 bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
304   auto peel_predicate = [](Node* n) {
305     for (auto i : n->inputs()) {
306       if (i->type()->isSubtypeOf(*TensorType::get())) {
307         return true;
308       }
309     }
310 
311     return false;
312   };
313 
314   LoopsPeeler lp(peel_predicate);
315   return lp.run(graph);
316 }
317 
PeelLoop(Node * n,size_t times)318 Node* PeelLoop(Node* n, size_t times) {
319   GRAPH_DEBUG("Peeling the loop ", getHeader(n), " ", times, " times");
320 
321   auto graph = n->owningGraph();
322   auto orig_loop = LoopView(n);
323 
324   WithInsertPoint wip(n);
325   auto times_const = graph->insertConstant(static_cast<int64_t>(times));
326   // N.B. even though a caller may request to peel `times` iterations
327   // `maxTripCount` of the original loop might be less than that
328   // so we should take the minimum of the two
329   auto min_trip_count =
330       graph->insert(prim::min, {orig_loop.maxTripCount(), times_const});
331 
332   // make the peeled clone
333   auto peeled_copy = graph->createClone(n, [](Value* v) { return v; });
334   addCondAsOutput(peeled_copy);
335 
336   LoopView new_lv(peeled_copy);
337   graph->insertNode(peeled_copy);
338   // only run until the peeled count
339   new_lv.replaceMaxTripCount(min_trip_count);
340 
341   // subtract `maxTripCount` of the original loop by the number iterations
342   // the peeled loop runs
343   auto new_max_trip_count =
344       graph->insert(aten::sub, {orig_loop.maxTripCount(), min_trip_count});
345   orig_loop.replaceMaxTripCount(new_max_trip_count);
346   // update the termination condition
347   auto cond_index = peeled_copy->outputs().size() - 1;
348   orig_loop.replaceInputCondition(peeled_copy->output(cond_index));
349 
350   static const size_t LOOP_DEPS_WITH_COND_OFFSET = 2;
351   for (size_t i = 0; i < peeled_copy->outputs().size() -
352            1 /* leave off the termination condition */;
353        i++) {
354     n->replaceInput(LOOP_DEPS_WITH_COND_OFFSET + i, peeled_copy->output(i));
355   }
356 
357   // the induction variable also needs to be adjusted by the number of
358   // iterations the peeled loop runs
359   {
360     WithInsertPoint peeled_wip(*orig_loop.bodyBlock()->nodes().begin());
361     // we can't create the expression: `new_counter` = `old_counter` + 1 yet
362     // because when we
363     // run `old_counter->replaceAllUsesWith(new_counter)`, we will get
364     // `new_counter = new_counter + 1`
365     auto adjusted_iter_counter =
366         graph->insert(aten::add, {min_trip_count, min_trip_count});
367     orig_loop.currentTripCount()->replaceAllUsesWith(adjusted_iter_counter);
368     adjusted_iter_counter->node()->replaceInput(
369         0, orig_loop.currentTripCount());
370   }
371 
372   return peeled_copy;
373 }
374 
UnrollLoops(std::shared_ptr<Graph> & graph)375 bool UnrollLoops(std::shared_ptr<Graph>& graph) {
376   bool changed = UnrollLoops(graph->block(), false);
377   if (changed) {
378     EliminateDeadCode(graph);
379   }
380   return changed;
381 }
382 
UnrollConstantLoops(std::shared_ptr<Graph> & graph)383 bool UnrollConstantLoops(std::shared_ptr<Graph>& graph) {
384   bool changed = UnrollLoops(graph->block(), true);
385   if (changed) {
386     EliminateDeadCode(graph);
387   }
388   return changed;
389 }
390 
391 } // namespace torch::jit
392