#include <torch/csrc/jit/passes/loop_unrolling.h>

#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>

#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>

namespace torch::jit {

namespace {

static constexpr int64_t kUnrollFactor = 8;
static constexpr int64_t kMaxBodySize = 32;
static constexpr int64_t kMaxBodyRepeats = 64;

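// True if `val` is a constant bool with the value `true`.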
bool isTrueConstant(Value* val) {
  std::optional<bool> maybe_value = constant_as<bool>(val);
  return maybe_value && *maybe_value;
}

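// A prim::Loop node is a "for" loop when both its initial condition and the
// continue condition produced by its body are the constant `true`, i.e. only
// the trip count limits iteration.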
bool isForLoop(Node* node) {
  if (node->kind() != prim::Loop)
    return false;
  Value* start_cond = node->inputs().at(1);
  Value* continue_cond = node->blocks().at(0)->outputs().at(0);
  return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
}

// Counts the size of this block, stopping and returning once it reaches
// `limit` instructions.
int64_t limitedBlockSize(Block* body, int64_t limit) {
  auto it = body->nodes().begin();
  auto end = body->nodes().end();
  for (int64_t i = 0; i < limit; ++it) {
    for (Block* subblock : it->blocks()) {
      i += limitedBlockSize(subblock, limit - i);
    }
    if (!it->notExecutedOp()) {
      ++i;
    }
    if (it == end) {
      return i;
    }
  }
  return limit;
}

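// True if `body` contains at most kMaxBodySize counted instructions,
// including those in nested blocks.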
bool isSmallBlock(Block* body) {
  return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
}

// XXX: This function can only be called with a loop that is guaranteed to
// execute EXACTLY ONCE.
void inlineBody(Node* loop) {
  auto graph = loop->owningGraph();
  auto body = loop->blocks().at(0);
  WithInsertPoint insert_point_guard{loop};

  std::unordered_map<Value*, Value*> value_map;
  auto get_value = [&](Value* v) {
    auto it = value_map.find(v);
    if (it != value_map.end())
      return it->second;
    return v;
  };

  // Loop node has extra (max_iters, initial_cond) inputs,
  // body has an extra (loop_counter) input.
  for (size_t i = 2; i < loop->inputs().size(); ++i) {
    value_map[body->inputs()[i - 1]] = loop->inputs()[i];
  }

  for (Node* orig : body->nodes()) {
    Node* clone = graph->insertNode(graph->createClone(orig, get_value));
    for (size_t i = 0; i < orig->outputs().size(); ++i) {
      value_map[orig->outputs()[i]] = clone->outputs()[i];
    }
  }
  for (size_t i = 0; i < loop->outputs().size(); ++i) {
    loop->outputs().at(i)->replaceAllUsesWith(
        get_value(body->outputs().at(i + 1)));
  }
  // XXX: it is extremely important to destroy the loop here. DCE might not
  // be able to conclude that it's safe, because the loop might contain side
  // effects.
  loop->destroy();
}

// Inserts a copy of `body`, passing `inputs` to the inputs of the block.
// Returns a list of the Values produced as the outputs of the block.
std::vector<Value*> insertBlockCopy(
    Graph& graph,
    Block* body,
    at::ArrayRef<Value*> inputs) {
  TORCH_INTERNAL_ASSERT(inputs.size() == body->inputs().size());
  std::unordered_map<Value*, Value*> value_map;
  auto get_value = [&](Value* v) {
    auto it = value_map.find(v);
    if (it != value_map.end())
      return it->second;
    return v;
  };
  auto inputs_it = inputs.begin();
  for (Value* input : body->inputs()) {
    value_map[input] = *inputs_it++;
  }
  for (Node* node : body->nodes()) {
    Node* new_node = graph.insertNode(graph.createClone(node, get_value));
    auto outputs_it = new_node->outputs().begin();
    for (Value* output : node->outputs()) {
      value_map[output] = *outputs_it++;
    }
  }
  return fmap(body->outputs(), get_value);
}

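// Copies `body` into `dest` `times` times in sequence, threading the
// loop-carried values from one copy into the next. `dest` receives the same
// inputs as `body` and registers the outputs of the final copy.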
void repeatBody(Block* body, size_t times, Block* dest) {
  auto graph = body->owningGraph();
  WithInsertPoint insert_point_guard(dest);
  for (Value* input : body->inputs()) {
    dest->addInput()->copyMetadata(input);
  }

  std::vector<Value*> io = dest->inputs().vec();
  TORCH_INTERNAL_ASSERT(
      !body->inputs().at(0)->hasUses(), "loop counter should be unused");
  for (const auto i : c10::irange(times)) {
    (void)i; // Suppress unused variable warning
    io[0] = body->inputs().at(0);
    io = insertBlockCopy(*graph, body, io);
  }
  for (Value* output : io) {
    dest->registerOutput(output);
  }

  // It's likely that we have some dead nodes now - for example the "true"
  // constant that prevents the loop from breaking. We shouldn't wait too long
  // before removing them because they might artificially increase the loop
  // size and prevent outer loop unrolling.
  EliminateDeadCode(dest, false);
}

// Replaces the builtin loop counter with a "mutable" variable outside of the
// loop.
void replaceLoopCounter(Node* loop) {
  Graph* graph = loop->owningGraph();
  Block* body = loop->blocks().at(0);
  WithInsertPoint guard(loop);
  Value* init_counter = graph->insertConstant(0);

  loop->insertInput(2, init_counter);
  loop->insertOutput(0)->setType(IntType::get());

  Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
  body->inputs()[0]->replaceAllUsesWith(internal_counter);

  WithInsertPoint insertPointGuard{body->return_node()};
  Value* result = graph->insert(aten::add, {internal_counter, 1});
  body->insertOutput(1, result);
}

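// Unrolls a single for-loop. Loops with a small constant trip count are
// repeated that many times and inlined in place of the loop node; all other
// loops get their body repeated kUnrollFactor times, with a cloned epilogue
// loop handling the leftover iterations.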
void unroll(Node* loop) {
  Graph* graph = loop->owningGraph();
  Block* body = loop->blocks().at(0);

  // We will be using a "mutable" counter outside of the loop instead of the
  // default one, because this will allow us to share it between the unrolled
  // loop and its epilogue. This is necessary only if the loop counter is
  // actually used in the body.
  if (!body->inputs()[0]->uses().empty())
    replaceLoopCounter(loop);

  // Some optimization for constant-length loops. If we know they won't run too
  // many times, then we can unroll them entirely.
  Value* trip_count = loop->inputs().at(0);
  std::optional<int64_t> const_len = constant_as<int64_t>(trip_count);
  if (const_len && *const_len < kMaxBodyRepeats) {
    Block* dest = loop->addBlock();
    repeatBody(body, *const_len, dest);
    loop->eraseBlock(0);
    inlineBody(loop);
    return;
  }

  WithInsertPoint insert_point_guard{loop};

  // Clone the loop before we unroll it. The clone will become the epilogue.
  Node* loop_epilogue =
      graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
  for (size_t i = 0; i < loop->outputs().size(); ++i) {
    loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
    loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
  }

  Block* dest = loop->addBlock();
  repeatBody(body, kUnrollFactor, dest);
  loop->eraseBlock(0);

  // Change the iteration counts of both loops
  Value* iter_count = loop->inputs().at(0);
  Value* unrolled_iter_count = graph->insert(
      aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
  loop->replaceInput(0, unrolled_iter_count);
  loop_epilogue->replaceInput(
      0,
      graph->insert(
          aten::sub,
          {iter_count,
           graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
}

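// Walks `block` recursively and unrolls every for-loop that qualifies. When
// `constant_only` is set, only loops with a constant trip count are unrolled;
// otherwise, only loops with a sufficiently small body are. Returns true if
// anything changed.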
bool UnrollLoops(Block* block, bool constant_only) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    // XXX: unroll might destroy the current node, so we need to pre-increment
    // the iterator
    Node* node = *it;
    ++it;
    for (Block* subblock : node->blocks()) {
      changed |= UnrollLoops(subblock, constant_only);
    }
    if (!isForLoop(node)) {
      continue;
    }
    if (constant_only) {
      if (node->inputs().at(0)->node()->kind() != prim::Constant) {
        continue;
      }
    } else if (!isSmallBlock(node->blocks().at(0))) {
      continue;
    }

    unroll(node);
    changed = true;
  }
  return changed;
}

} // anonymous namespace

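// Threads the loop's input condition through as an extra loop input and body
// input, and registers the body's continue condition as an extra loop output,
// so callers can observe whether the loop would have kept iterating.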
static void addCondAsOutput(Node* loop) {
  LoopView loop_view(loop);
  loop->addInput(loop_view.inputCond());
  auto block_cond_input = loop_view.bodyBlock()->addInput();
  block_cond_input->copyMetadata(loop_view.inputCond());
  auto cond_output_index =
      loop_view.bodyBlock()->registerOutput(loop_view.nextCond());
  loop_view.bodyBlock()->outputs()[cond_output_index]->copyMetadata(
      loop_view.nextCond());
  auto cond_output = loop->addOutput();
  cond_output->copyMetadata(loop_view.nextCond());
}

bool LoopsPeeler::run(const std::shared_ptr<Graph>& graph) {
  GRAPH_DUMP("Before LoopsPeeler", graph);
  collectLoops(graph->block());
  peelLoops();
  GRAPH_DUMP("After LoopsPeeler", graph);
  return true;
}

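// If the callback matches node `n`, records the innermost enclosing loop as a
// candidate for peeling; each enclosing loop is recorded at most once.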
void LoopsPeeler::collectLoop(Node* n) {
  if (callback_(n)) {
    if (in_loop_) {
      GRAPH_DEBUG("Loop ", getHeader(in_loop_), " will be peeled");
      loops_to_peel_.push_back(in_loop_);
      in_loop_ = nullptr;
    }
  }
}

void LoopsPeeler::collectLoops(Block* block) {
  // we do a pre-order traversal to reduce the number
  // of peeled loops.
  for (auto n : block->nodes()) {
    collectLoop(n);
  }
  collectLoop(block->return_node());

  // process child blocks
  for (auto n : block->nodes()) {
    auto old_in_loop_ = in_loop_;
    if (n->kind() == prim::Loop) {
      in_loop_ = n;
    }
    for (auto b : n->blocks()) {
      collectLoops(b);
    }
    in_loop_ = old_in_loop_;
  }
}

void LoopsPeeler::peelLoops() {
  for (auto loop : loops_to_peel_) {
    PeelLoop(loop, num_iterations_);
  }
}

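// Peels every loop whose node takes at least one Tensor input.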
bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
  auto peel_predicate = [](Node* n) {
    for (auto i : n->inputs()) {
      if (i->type()->isSubtypeOf(*TensorType::get())) {
        return true;
      }
    }

    return false;
  };

  LoopsPeeler lp(peel_predicate);
  return lp.run(graph);
}

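// Peels up to `times` iterations off the front of loop `n`: a clone of the
// loop runs min(times, maxTripCount) iterations, and the original loop is
// rewired to consume the clone's outputs and run the remaining iterations.
// Returns the peeled (cloned) loop node.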
Node* PeelLoop(Node* n, size_t times) {
  GRAPH_DEBUG("Peeling the loop ", getHeader(n), " ", times, " times");

  auto graph = n->owningGraph();
  auto orig_loop = LoopView(n);

  WithInsertPoint wip(n);
  auto times_const = graph->insertConstant(static_cast<int64_t>(times));
  // N.B. even though a caller may request to peel `times` iterations,
  // `maxTripCount` of the original loop might be less than that,
  // so we should take the minimum of the two
  auto min_trip_count =
      graph->insert(prim::min, {orig_loop.maxTripCount(), times_const});

  // make the peeled clone
  auto peeled_copy = graph->createClone(n, [](Value* v) { return v; });
  addCondAsOutput(peeled_copy);

  LoopView new_lv(peeled_copy);
  graph->insertNode(peeled_copy);
  // only run until the peeled count
  new_lv.replaceMaxTripCount(min_trip_count);

  // reduce `maxTripCount` of the original loop by the number of iterations
  // the peeled loop runs
  auto new_max_trip_count =
      graph->insert(aten::sub, {orig_loop.maxTripCount(), min_trip_count});
  orig_loop.replaceMaxTripCount(new_max_trip_count);
  // update the termination condition
  auto cond_index = peeled_copy->outputs().size() - 1;
  orig_loop.replaceInputCondition(peeled_copy->output(cond_index));

  static const size_t LOOP_DEPS_WITH_COND_OFFSET = 2;
  for (size_t i = 0; i < peeled_copy->outputs().size() -
           1 /* leave off the termination condition */;
       i++) {
    n->replaceInput(LOOP_DEPS_WITH_COND_OFFSET + i, peeled_copy->output(i));
  }

  // the induction variable also needs to be adjusted by the number of
  // iterations the peeled loop runs
  {
    WithInsertPoint peeled_wip(*orig_loop.bodyBlock()->nodes().begin());
    // we can't build `new_counter = old_counter + min_trip_count` directly,
    // because running `old_counter->replaceAllUsesWith(new_counter)` would
    // turn it into `new_counter = new_counter + min_trip_count`; so we build
    // the add with a placeholder operand and patch the input afterwards
    auto adjusted_iter_counter =
        graph->insert(aten::add, {min_trip_count, min_trip_count});
    orig_loop.currentTripCount()->replaceAllUsesWith(adjusted_iter_counter);
    adjusted_iter_counter->node()->replaceInput(
        0, orig_loop.currentTripCount());
  }

  return peeled_copy;
}

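// Entry points: UnrollLoops unrolls any sufficiently small for-loop in the
// graph, while UnrollConstantLoops only unrolls loops whose trip count is a
// compile-time constant. Both run dead code elimination afterwards if
// anything changed.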
bool UnrollLoops(std::shared_ptr<Graph>& graph) {
  bool changed = UnrollLoops(graph->block(), false);
  if (changed) {
    EliminateDeadCode(graph);
  }
  return changed;
}

bool UnrollConstantLoops(std::shared_ptr<Graph>& graph) {
  bool changed = UnrollLoops(graph->block(), true);
  if (changed) {
    EliminateDeadCode(graph);
  }
  return changed;
}

} // namespace torch::jit