1 #include <torch/csrc/jit/frontend/exit_transforms.h>
2
3 #include <ATen/core/jit_type.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/ir_views.h>
7 #include <torch/csrc/jit/runtime/graph_iterator.h>
8
9 namespace torch::jit {
10
// Exit status of a node or block with respect to the exit kind that is
// currently being transformed.
enum class ExitStatus {
  // The node/block must hit the exit.
  WILL,
  // The node/block may or may not hit the exit.
  MIGHT,
  // The node/block will not hit the exit.
  WONT,
  // The node/block always throws. This lets us create better graphs by not
  // conditionalizing execution when it is not necessary. It is purely an
  // optimization; replacing it with WONT would preserve graph semantics.
  THROWS,
};

// Names the two exit kinds handled by this file: returns and loop
// continuations.
enum class Transform { Returns, LoopContinuations };
20
21 // hasExited() indicates whether or not an exit has been hit.
22 // The ExitTransform pass maintains a false boolean false_val_ && a true boolean
23 // true_val_, and an uninitialized boolean throws_val_.
24 // if hasExited() == true_val_ then we have exited, if hasExited() == false_val_
25 // we have not, hasExited() == throws_val_ we have hit a block that throws.
26 // Otherwise, we might have exited.
27 // exitValues() are the values that we are propagating to a destination block.
28 // this is used for block outputs of loops and outputs of functions & closures
29 struct ExitPair : public std::pair<Value*, std::vector<Value*>> {
30 using pair::pair;
31
ExitPairtorch::jit::ExitPair32 ExitPair(Value* exit_v, at::ArrayRef<Value*> exit_val_ref) {
33 std::vector<Value*> exit_vals;
34 for (Value* v : exit_val_ref) {
35 exit_vals.push_back(v);
36 }
37 AT_ASSERT(exit_v->type() == BoolType::get());
38 this->first = exit_v;
39 this->second = std::move(exit_vals);
40 }
41
hasExitedtorch::jit::ExitPair42 Value* hasExited() const {
43 return this->first;
44 }
45
exitValuestorch::jit::ExitPair46 std::vector<Value*> exitValues() const {
47 return this->second;
48 }
49 };
50
51 /**
52 * This pass currently transforms the Graph so that all exit nodes targeting
53 * a block location are removed from the graph and unified.
54 * The exit node for breaks/continues is LoopContinuation, and the exit for
55 * Graphs & Closures is ReturnStmt.
56 *
57 * Once we hit an Exit Node, we do not execute any further instructions
58 * until the exit target has been reached.
59 *
60 * For blocks and control flow nodes that have an exit statement that may
61 * have been hit, we conditionalize all execution on a boolean value that
62 * indicates whether we have hit the exit, hasExited().
63 *
64 * The pass keeps tracks of blocks that always throw, so that we can construct
65 * simpler graphs. For example, if in one block of an if statement we return
66 * and in the other we throw, we can treat the node as always returning instead
67 * of conditionalizing execution in the remainder of the block.
68 */
69 struct ExitTransformer {
ExitTransformertorch::jit::ExitTransformer70 ExitTransformer(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
71 WithInsertPoint guard(graph_->block()->nodes().front());
72 true_val_ = graph_->insertConstant(true);
73 false_val_ = graph_->insertConstant(false);
74 // this value will never be used, since we will always throw before it is
75 // accessed
76 throws_val_ = getUnitValue(BoolType::get());
77 };
78
transformReturnStmtstorch::jit::ExitTransformer79 void transformReturnStmts() {
80 current_exit_kind_ = prim::ReturnStmt;
81 transformExits(graph_->block());
82 }
83
transformLoopContinuationstorch::jit::ExitTransformer84 void transformLoopContinuations() {
85 current_exit_kind_ = prim::LoopContinuation;
86 transformExits(graph_->block());
87 }
88
89 private:
constructThrowsExitPairtorch::jit::ExitTransformer90 ExitPair constructThrowsExitPair() {
91 return ExitPair(throws_val_, std::vector<Value*>({}));
92 }
constructWontExitPairtorch::jit::ExitTransformer93 ExitPair constructWontExitPair() {
94 return ExitPair(false_val_, std::vector<Value*>({}));
95 }
constructWillExitPairtorch::jit::ExitTransformer96 ExitPair constructWillExitPair(at::ArrayRef<Value*> exit_val_ref) {
97 return ExitPair(true_val_, exit_val_ref);
98 }
99
getExitStatustorch::jit::ExitTransformer100 ExitStatus getExitStatus(ExitPair& exit_pair) {
101 Value* exit_v = exit_pair.hasExited();
102 if (exit_v == true_val_) {
103 return ExitStatus::WILL;
104 } else if (exit_v == false_val_) {
105 return ExitStatus::WONT;
106 } else if (exit_v == throws_val_) {
107 return ExitStatus::THROWS;
108 } else {
109 return ExitStatus::MIGHT;
110 }
111 }
112
owningNodeKindtorch::jit::ExitTransformer113 static Symbol owningNodeKind(Block* block) {
114 if (block->owningNode()) {
115 return block->owningNode()->kind();
116 }
117 return Symbol();
118 }
119
isGraphOrClosureBlocktorch::jit::ExitTransformer120 static bool isGraphOrClosureBlock(Block* block) {
121 return block->owningNode() == nullptr ||
122 owningNodeKind(block) == prim::Closure;
123 }
124
removeOutputstorch::jit::ExitTransformer125 static void removeOutputs(Block* b) {
126 while (!b->outputs().empty()) {
127 b->eraseOutput(0);
128 }
129 }
130
registerBlockOutputstorch::jit::ExitTransformer131 static void registerBlockOutputs(Block* b, at::ArrayRef<Value*> outs) {
132 for (Value* out : outs) {
133 b->registerOutput(out);
134 }
135 }
136
replaceBlockOutputstorch::jit::ExitTransformer137 static void replaceBlockOutputs(Block* b, at::ArrayRef<Value*> outs) {
138 removeOutputs(b);
139 registerBlockOutputs(b, outs);
140 }
141
addIfOutputstorch::jit::ExitTransformer142 static void addIfOutputs(
143 Node* n,
144 at::ArrayRef<Value*> true_outs,
145 at::ArrayRef<Value*> false_outs) {
146 IfView if_view(n);
147 registerBlockOutputs(if_view.thenBlock(), true_outs);
148 registerBlockOutputs(if_view.elseBlock(), false_outs);
149 for (const auto i : c10::irange(true_outs.size())) {
150 auto out_type = unifyTypes(
151 true_outs.at(i)->type(),
152 false_outs.at(i)->type(),
153 /*default_to_union=*/true);
154 n->addOutput()->setType(*out_type);
155 }
156 }
157
158 // creates a vector of uninitialized values of the same type as the
159 // values_to_match
matchValuesWithUnitializedtorch::jit::ExitTransformer160 std::vector<Value*> matchValuesWithUnitialized(
161 at::ArrayRef<Value*> values_to_match) {
162 std::vector<Value*> match_values;
163 for (Value* val : values_to_match) {
164 match_values.push_back(getUnitValue(val->type()));
165 }
166 return match_values;
167 }
168
transformLooptorch::jit::ExitTransformer169 ExitPair transformLoop(Node* node) {
170 LoopView loop(node);
171 Block* body = loop.bodyBlock();
172 auto exit_pair = transformExits(body);
173 // if we're not exiting to outside the loop we don't need to do any work.
174 // since we may not enter the loop return WONT for the THROWS case.
175
176 if (getExitStatus(exit_pair) == ExitStatus::WONT ||
177 getExitStatus(exit_pair) == ExitStatus::THROWS) {
178 return constructWontExitPair();
179 }
180
181 // if we are, we need to update the loop continue condition so that
182 // we exit the loop if we've hit an exit
183 // and we need to propagate hasExited() and exitValues() outside the loop
184
185 // example:
186 // while i < 5:
187 // i += 1
188 // if j == 4:
189 // return 5
190 // -> becomes
191 //
192 // loop_continue = i < 5
193 // has_exited = false
194 // ret_val = uninitialized(int)
195 // while loop_continue:
196 // i += 1
197 // if j == 4:
198 // ret_val = 5
199 // has_exited = True
200 // else:
201 // ret_val = uninitialized(int)
202 // has_exited = False
203 // if has_exited:
204 // loop_continue = False
205 // else:
206 // loop_continue = i < 5
207
208 // update loop continuation condition so that we exit if we hit an exit
209 WithInsertPoint insert(body);
210 auto new_if = graph_->insertNode(graph_->create(prim::If, 0));
211 new_if->addInput(exit_pair.hasExited());
212 new_if->addBlock()->registerOutput(false_val_);
213 new_if->addBlock()->registerOutput(loop.nextCond());
214 auto new_condition = new_if->addOutput()->setType(BoolType::get());
215 loop.bodyBlock()->eraseOutput(0);
216 loop.bodyBlock()->insertOutput(0, new_condition);
217
218 // add hasExited() to loop outputs, we didn't exit if we didn't enter the
219 // loop
220 node->addInput(false_val_);
221 body->addInput()->setType(BoolType::get());
222 body->registerOutput(exit_pair.hasExited());
223 Value* new_has_exited = node->addOutput()->setType(BoolType::get());
224
225 // add exit values
226 for (Value* exit_value : exit_pair.exitValues()) {
227 auto typ = exit_value->type();
228 node->addInput(getUnitValue(typ));
229 node->addOutput()->setType(typ);
230 body->addInput()->setType(typ);
231 body->registerOutput(exit_value);
232 }
233
234 auto exit_vals = node->outputs().slice(
235 node->outputs().size() - exit_pair.exitValues().size());
236
237 return ExitPair(new_has_exited, exit_vals);
238 }
239
calcIfExitStatustorch::jit::ExitTransformer240 ExitStatus calcIfExitStatus(ExitStatus then_status, ExitStatus else_status) {
241 // if one branch throws, we can take the status of the other
242 if (then_status == ExitStatus::THROWS) {
243 return else_status;
244 } else if (else_status == ExitStatus::THROWS) {
245 return then_status;
246 }
247
248 if (then_status == ExitStatus::WONT && else_status == ExitStatus::WONT) {
249 return ExitStatus::WONT;
250 }
251
252 if (then_status == ExitStatus::WILL && else_status == ExitStatus::WILL) {
253 return ExitStatus::WILL;
254 }
255
256 return ExitStatus::MIGHT;
257 }
258
259 // Recursively transforms the if node
transformIftorch::jit::ExitTransformer260 ExitPair transformIf(Node* node) {
261 auto then_block = node->blocks().at(0);
262 auto else_block = node->blocks().at(1);
263
264 auto then_pair = transformExits(then_block);
265 auto else_pair = transformExits(else_block);
266 auto then_status = getExitStatus(then_pair);
267 auto else_status = getExitStatus(else_pair);
268
269 auto if_status = calcIfExitStatus(then_status, else_status);
270
271 if (if_status == ExitStatus::THROWS) {
272 return constructThrowsExitPair();
273 }
274 if (if_status == ExitStatus::WONT) {
275 return constructWontExitPair();
276 }
277
278 // The exit values of the block that is not exiting will not get
279 // used, so we create uninitialized values of the same type as the other
280 // block.
281 if (then_status == ExitStatus::WONT || then_status == ExitStatus::THROWS) {
282 std::vector<Value*> exit_vals =
283 matchValuesWithUnitialized(else_pair.exitValues());
284 then_pair = ExitPair(then_pair.hasExited(), exit_vals);
285 } else if (
286 else_status == ExitStatus::WONT || else_status == ExitStatus::THROWS) {
287 std::vector<Value*> exit_vals =
288 matchValuesWithUnitialized(then_pair.exitValues());
289 else_pair = ExitPair(else_pair.hasExited(), exit_vals);
290 }
291
292 Value* has_exited = nullptr;
293 if (if_status == ExitStatus::WILL) {
294 // Need to maintain the invariant that if hasExited() == true_val_
295 // then we have exited.
296 has_exited = true_val_;
297 } else {
298 addIfOutputs(node, {then_pair.hasExited()}, {else_pair.hasExited()});
299 has_exited = node->outputs().at(node->outputs().size() - 1);
300 }
301 addIfOutputs(node, then_pair.exitValues(), else_pair.exitValues());
302 size_t num_exit_vals = then_pair.exitValues().size();
303 auto exit_vals =
304 node->outputs().slice(node->outputs().size() - num_exit_vals);
305 return ExitPair(has_exited, exit_vals);
306 }
307
308 // Recursively transforms the With node.
transformWithtorch::jit::ExitTransformer309 ExitPair transformWith(Node* node) {
310 auto body_block = node->blocks().at(0);
311 auto body_pair = transformExits(body_block);
312 return body_pair;
313 }
314
315 // Guards the remaining nodes in the block with an if node that takes
316 // the has exited value as its conditional
guardBlockNodestorch::jit::ExitTransformer317 ExitPair guardBlockNodes(
318 Block* block,
319 const ExitPair& exit_pair,
320 graph_node_list_iterator& iter) {
321 auto new_if = graph_->create(prim::If, 0)->insertBefore(*iter);
322 new_if->addInput(exit_pair.hasExited());
323
324 auto exit_block = new_if->addBlock();
325 auto guard_block = new_if->addBlock();
326
327 // Move all remaining nodes into the guard block
328 while (iter != block->nodes().end()) {
329 auto node = *iter++;
330 node->moveBefore(guard_block->return_node());
331 }
332
333 std::vector<Value*> exit_block_vals;
334 // after an exit, the only values that will get used
335 // are the hasExited() and exitValues(), so we match the existing
336 // block outputs with unitialized
337 exit_block_vals = matchValuesWithUnitialized(block->outputs());
338
339 // Set the new if to have the same outputs of the original block,
340 // then replace the original block outputs with new if's outputs
341 for (size_t i = 0; i < block->outputs().size(); ++i) {
342 exit_block->registerOutput(exit_block_vals.at(i));
343 guard_block->registerOutput(block->outputs().at(i));
344 new_if->addOutput()->setType(block->outputs().at(i)->type());
345 }
346
347 while (!block->outputs().empty()) {
348 block->eraseOutput(0);
349 }
350 for (auto out : new_if->outputs()) {
351 block->registerOutput(out);
352 }
353
354 graph_->create(current_exit_kind_, {exit_pair.exitValues()}, 0)
355 ->insertBefore(exit_block->return_node());
356 return transformIf(new_if);
357 }
358
359 // these nodes my have uses,
360 // such as in the case:
361 // if i == 1:
362 // break
363 // j = j + 1
364 // where the j + 1 value will be a block output, but since they will
365 // never be used, it is safe to replace them with unitialized value
destroyNodeAfterExittorch::jit::ExitTransformer366 void destroyNodeAfterExit(Node* n) {
367 for (auto output : n->outputs()) {
368 if (!output->uses().empty()) {
369 output->replaceAllUsesWith(getUnitValue(output->type()));
370 }
371 }
372 n->destroy();
373 }
374
deleteAfterExitNodestorch::jit::ExitTransformer375 void deleteAfterExitNodes(Block* block, graph_node_list_iterator& iter) {
376 if (iter == block->nodes().end()) {
377 return;
378 }
379 WithInsertPoint insert(*block->nodes().begin());
380 // need to destroy in reverse order so nodes have no uses when destroyed
381 for (auto it = block->nodes().reverse().begin(); it != iter;) {
382 Node* n = *it++;
383 if (*it != block->return_node()) {
384 destroyNodeAfterExit(n);
385 }
386 }
387 destroyNodeAfterExit(*iter);
388 }
389
390 // if we're entering a Loop block & transforming LoopContinuations, or if
391 // we're entering a Closure/Graph block and we're transforming ReturnStmts,
392 // then we update target_block_ to be the new block.
393 // otherwise, target_block_ remains the same.
updateTargetBlocktorch::jit::ExitTransformer394 void updateTargetBlock(Block* block) {
395 if (owningNodeKind(block) == prim::Loop &&
396 // NOLINTNEXTLINE(bugprone-branch-clone)
397 current_exit_kind_ == prim::LoopContinuation) {
398 target_block_ = block;
399 } else if (
400 isGraphOrClosureBlock(block) &&
401 current_exit_kind_ == prim::ReturnStmt) {
402 target_block_ = block;
403 }
404 }
405
transformExitstorch::jit::ExitTransformer406 ExitPair transformExits(Block* block) {
407 Block* prev_target_block = target_block_;
408 updateTargetBlock(block);
409 ExitPair exit_pair = constructWontExitPair();
410
411 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
412 Node* node = *it;
413 it++;
414 switch (node->kind()) {
415 case prim::RaiseException: {
416 exit_pair = constructThrowsExitPair();
417 } break;
418 case prim::ReturnStmt:
419 case prim::LoopContinuation: {
420 if (node->kind() == current_exit_kind_) {
421 exit_pair = constructWillExitPair(node->inputs());
422 node->destroy();
423 }
424 } break;
425 case prim::If: {
426 exit_pair = transformIf(node);
427 } break;
428 case prim::With: {
429 exit_pair = transformWith(node);
430 } break;
431 case prim::Closure: {
432 // exits of closure declaration stay local to the closure
433 transformExits(node->blocks().at(0));
434 } break;
435 case prim::Loop: {
436 exit_pair = transformLoop(node);
437 } break;
438 }
439
440 // if we have hit a node that might exit, we need to conditionally execute
441 // all subsequent nodes in the block. if we've hit a node that will exit
442 // we can remove all subsequent nodes.
443 ExitStatus status = getExitStatus(exit_pair);
444 if (status == ExitStatus::WILL || status == ExitStatus::THROWS) {
445 deleteAfterExitNodes(block, it);
446 break;
447 }
448 if (status == ExitStatus::MIGHT) {
449 if (it != block->nodes().end()) {
450 exit_pair = guardBlockNodes(block, exit_pair, it);
451 }
452 break;
453 }
454 }
455
456 // if we are targeting this block, update the output values to the
457 // exit values. since the exit does not extend outside this block,
458 // update returned exit to false. then, reset the target_block to whatever
459 // it was previously
460 if (target_block_ == block) {
461 // if we might have exited, use the new exit values if we did exit,
462 // otherwise use the existing block outputs.
463 if (getExitStatus(exit_pair) == ExitStatus::MIGHT) {
464 auto new_if =
465 graph_->create(prim::If, 0)->insertBefore(block->return_node());
466 new_if->addBlock();
467 new_if->addBlock();
468 new_if->addInput(exit_pair.hasExited());
469 addIfOutputs(new_if, exit_pair.exitValues(), block->outputs());
470 replaceBlockOutputs(block, new_if->outputs());
471 } else if (getExitStatus(exit_pair) == ExitStatus::WILL) {
472 replaceBlockOutputs(block, exit_pair.exitValues());
473 }
474
475 // reset the exiting status. an exit should only reach its target block.
476 // e.g. a continue only affects most recent loop, return in closure
477 // does not affect enclosing graph.
478 // Exceptions do not propagate from Loops bc we might not enter the loop,
479 // and not from closures bc the Function node is a declaration and not
480 // an invocation.
481 exit_pair = constructWontExitPair();
482 }
483 target_block_ = prev_target_block;
484 return exit_pair;
485 }
486
getUnitValuetorch::jit::ExitTransformer487 Value* getUnitValue(const TypePtr& type) {
488 auto maybe_val = unit_values_.find(type);
489 if (maybe_val != unit_values_.end()) {
490 return maybe_val->second;
491 }
492 auto unit = graph_->createUninitialized(type)
493 ->insertAfter(graph_->param_node())
494 ->output();
495 unit_values_[type] = unit;
496 return unit;
497 }
498
499 // we create one uninitialized value per type, cache it here and reuse it
500 std::unordered_map<TypePtr, Value*> unit_values_;
501
502 // can either be LoopContinuation/ReturnStmt
503 Symbol current_exit_kind_;
504 Value* true_val_;
505 Value* false_val_;
506 Value* throws_val_;
507
508 // when we see current_exit_kind_, this is the block that the values are
509 // exiting to. For example when we are transforming LoopContinuations
510 // for i in range(5):
511 // while i < 3:
512 // continue
513 // break
514 // when we transform the for loop block, target_block_ will be set the for
515 // block. then, when we enter the while loop, target_block_ will be the while
516 // loop block. when we are done transforming the while it will be set back to
517 // the for block.
518 Block* target_block_ = nullptr;
519 std::shared_ptr<Graph> graph_;
520 };
521
inlineConsecutiveIfs(Node * node)522 static bool inlineConsecutiveIfs(Node* node) {
523 if (node->kind() != prim::If || node->next()->kind() != prim::If) {
524 return false;
525 }
526
527 IfView first_if(node);
528 IfView second_if(node->next());
529
530 // the second if must depend on a value outputted in the first if for us to
531 // inline the second if
532 if (second_if.cond()->node() != node) {
533 return false;
534 }
535
536 // both blocks must output a constant value for us to inline, and those values
537 // must be different. if the values are the same, then the subsequent if node
538 // will get constant prop'd away, and inlining it into the first node would
539 // double code size
540
541 auto input_offset = second_if.cond()->offset();
542 auto maybe_then_value = toIValue(first_if.thenOutputs().at(input_offset));
543 auto maybe_else_value = toIValue(first_if.elseOutputs().at(input_offset));
544 if (!maybe_then_value || !maybe_else_value ||
545 maybe_then_value->toBool() == maybe_else_value->toBool()) {
546 return false;
547 }
548
549 bool then_value = maybe_then_value->toBool();
550 bool else_value = maybe_else_value->toBool();
551
552 for (const auto i : c10::irange(2)) {
553 Block* first_if_block = nullptr;
554 Block* second_if_block = nullptr;
555
556 if (i == 0) {
557 first_if_block = first_if.thenBlock();
558 second_if_block =
559 then_value ? second_if.thenBlock() : second_if.elseBlock();
560 } else {
561 first_if_block = first_if.elseBlock();
562 second_if_block =
563 else_value ? second_if.thenBlock() : second_if.elseBlock();
564 ;
565 }
566
567 // we need to replace values that were used in the second if that were
568 // outputs of the first if with the equivalent value in the scope of the
569 // block we're copying into
570 auto value_map = [&](Value* v) {
571 if (v->node() != first_if.node()) {
572 return v;
573 }
574 auto offset = v->offset();
575 return first_if_block->outputs().at(offset);
576 };
577
578 // clone from also copies block outputs from second_if_block onto
579 // first_if_block
580 first_if_block->cloneFrom(second_if_block, value_map);
581 }
582
583 for (Value* output : second_if.outputs()) {
584 auto new_out = first_if.node()->addOutput()->copyMetadata(output);
585 output->replaceAllUsesWith(new_out);
586 }
587 second_if.node()->destroy();
588 return true;
589 }
590
591 // After an early return, we conditionalize all further execution
592 // This means code like the following:
593 // if x:
594 // return 1
595 // return 2
596 // Gets generated as one if statement checking `if x`, and then a second if
597 // statement that conditionalizes execution. We can rewrite cases like these
598 // into one if statement, so that the above examples gets rewritten to look
599 // like: if x:
600 // return 1
601 // else:
602 // return 2
inlineConsecutiveIfs(Block * block)603 static void inlineConsecutiveIfs(Block* block) {
604 for (auto it = block->nodes().begin(), end = block->nodes().end();
605 it != end;) {
606 for (Block* b : it->blocks()) {
607 inlineConsecutiveIfs(b);
608 }
609
610 // if we fused two ifs, we need to check current node and new next node
611 if (!inlineConsecutiveIfs(*it)) {
612 it++;
613 }
614 }
615 }
616
617 // Adds prim::With nodes to a graph to help handle early exits between
618 // prim::Enter and prim::Exit nodes. More specifically, it transforms
619 // IR that looks like this:
620 //
621 // %a = prim::Enter(%b)
622 // <code>
623 // %c = prim::Exit(%b)
624 //
625 // to this:
626 //
627 // %a = prim::Enter(%b)
628 // = prim::With()
629 // block0():
630 // <code>
631 // -> ()
632 // block1():
633 // %c = prim::Exit(%b)
634 // -> ()
635 //
convertEnterExitNodesToWithBlocks(std::shared_ptr<Graph> & graph)636 static void convertEnterExitNodesToWithBlocks(std::shared_ptr<Graph>& graph) {
637 // First, find all Enter-Exit pairs up front to avoid iterator invalidation
638 // issues later when moving nodes around. Do this by iterating through the
639 // nodes of the graph while keeping a stack of encountered Enter nodes. Each
640 // time an Exit node is seen, its corresponding Enter node must be at the
641 // top of the stack. Pop it and record the pair.
642 std::vector<std::pair<Node*, Node*>> enter_exit_pairs;
643 std::vector<Node*> enter_node_stack;
644
645 DepthFirstGraphNodeIterator it(graph);
646 Node* node = it.next();
647
648 while (node) {
649 if (node->kind() == prim::Enter) {
650 enter_node_stack.emplace_back(node);
651 } else if (node->kind() == prim::Exit) {
652 // enter_node_stack should not be empty.
653 TORCH_INTERNAL_ASSERT(!enter_node_stack.empty());
654 // The input to this Exit node should be the same as that of the Enter
655 // node on the top of the enter_node_stack.
656 TORCH_INTERNAL_ASSERT(
657 enter_node_stack.back()->input(0) == node->input(0));
658 // Record the pair.
659 enter_exit_pairs.emplace_back(enter_node_stack.back(), node);
660 enter_node_stack.pop_back();
661 }
662
663 node = it.next();
664 }
665
666 // The stack should be empty; an Exit should have been found for every Enter.
667 TORCH_INTERNAL_ASSERT(enter_node_stack.empty());
668
669 // Now, add a With block for each Enter-Exit pair. The innermost pairs were
670 // found first, so they will be converted first.
671 for (auto& pair : enter_exit_pairs) {
672 Node* enter = pair.first;
673 Node* exit = pair.second;
674
675 auto* with = graph->create(prim::With, /*num_outputs=*/0);
676 auto* body_block = with->addBlock();
677 auto* exit_block = with->addBlock();
678
679 // Insert the With after the Enter.
680 Node* cur = enter->next();
681 Node* insert_point = body_block->param_node();
682
683 // Move all of the nodes between the Enter and Exit into the body block.
684 while (cur != exit) {
685 auto* next = cur->next();
686 cur->moveAfter(insert_point);
687 insert_point = insert_point->next();
688 cur = next;
689 }
690
691 // Move the Exit node into the exit block.
692 exit->moveAfter(exit_block->param_node());
693 with->insertAfter(enter);
694 }
695 }
696
697 // Removes prim::With nodes from a graph. More specifically, it transforms
698 // IR that looks like this:
699 //
700 // %a = prim::Enter(%b)
701 // = prim::With()
702 // block0():
703 // <code>
704 // -> ()
705 // block1():
706 // %c = prim::Exit(%b)
707 // ->()
708 //
709 // to this:
710 // %a = prim::Enter(%b)
711 // <code>
712 // %c = prim::Exit(%b)
713 //
convertWithBlocksToEnterExitNodes(std::shared_ptr<Graph> & graph)714 static void convertWithBlocksToEnterExitNodes(std::shared_ptr<Graph>& graph) {
715 // First, find all With blocks to avoid iterator invalidation issues when
716 // moving nodes around later.
717 std::vector<Node*> with_nodes;
718
719 DepthFirstGraphNodeIterator it(graph);
720 Node* node = it.next();
721
722 while (node) {
723 if (node->kind() == prim::With) {
724 with_nodes.emplace_back(node);
725 }
726 node = it.next();
727 }
728
729 // For each With node:
730 for (auto& node : with_nodes) {
731 auto* body_block = node->blocks().at(0);
732 auto* exit_block = node->blocks().at(1);
733
734 std::vector<Node*> to_append;
735
736 // Record all nodes that need to be appended after the Enter that precedes
737 // the With block to avoid iterator invalidation issues later when moving
738 // nodes around.
739 for (auto body_node : body_block->nodes()) {
740 to_append.emplace_back(body_node);
741 }
742
743 for (auto exit_node : exit_block->nodes()) {
744 to_append.emplace_back(exit_node);
745 }
746
747 Node* cur = node->prev();
748
749 // Move all nodes inside the with block outside of it.
750 for (auto& node : to_append) {
751 node->moveAfter(cur);
752 cur = node;
753 }
754 node->destroy();
755 }
756 }
757
758 // This pass takes in a graph where LoopContinuation & ReturnStmts exist in the
759 // graph and erases them in the graph, correctly setting block outputs.
760 // prim::LoopContinuation(*vals) means that the values are targeting the most
761 // recent loop block. prim::ReturnStmt(*vals) means that the values are
762 // targeting the most recent Closure or Graph Block. Once we hit an exit node,
763 // we do not execute any further instructions until the block exit reaches its
764 // destination. If we encounter a node that contains nested blocks that may
765 // have hit an exit node, such as an if statement that exits in one block
766 // and does not exit in the other, we use a boolean value to indicate if the
767 // exit has been hit or not. Then, we conditionalize further execution.
768 //
769 // Python example:
770 // while i < 5:
771 // if i == 3:
772 // i += 1
773 // continue
774 // i += 2
775 //
776 // -> transforms to:
777 //
778 // continue_loop = i < 5
779 // while continue_loop:
780 // if i == 3:
781 // i = i + 1
782 // continue_loop = i < 5
783 // did_exit = True
784 // if did_exit:
785 // pass
786 // else:
787 // i = i + 2
788 // continue_loop = i < 5
789 // IR as it enters pass:
790 // %36 : bool = aten::lt(%i.1, %3)
791 // %i : int = prim::Loop(%1, %36, %i.1)
792 // block0(%5 : int, %i.17 : int):
793 // %8 : bool = aten::eq(%i.17, %7)
794 // %i.16 : int = prim::If(%8)
795 // block0():
796 // %i.6 : int = aten::add(%i.17, %11)
797 // %33 : bool = aten::lt(%i.6, %3)
798 // = prim::LoopContinuation(%33, %i.6)
799 // -> (%i.6)
800 // block1():
801 // -> (%i.17)
802 // %i.13 : int = aten::add(%i.16, %19)
803 // %4 : bool = aten::lt(%i.13, %3)
804 // -> (%4, %i.13)
805 // return (%i)
806 //
807 // -> transforms to
808 //
809 // %false_val : bool = prim::Constant[value=0]()
810 // %true_val : bool = prim::Constant[value=1]()
811 // %40 : int = prim::Uninitialized()
812 // %39 : bool = prim::Uninitialized()
813 // %36 : bool = aten::lt(%i.1, %3)
814 // %i : int = prim::Loop(%1, %36, %i.1)
815 // block0(%5 : int, %i.17 : int):
816 // %8 : bool = aten::eq(%i.17, %7)
817 // %did_exit : bool, %continue_loop : bool, %43 : int, %i.16 : int =
818 // prim::If(%8)
819 // block0():
820 // %i.6 : int = aten::add(%i.17, %11)
821 // %33 : bool = aten::lt(%i.6, %3)
822 // -> (%true_val, %33, %i.6, %i.6)
823 // block1():
824 // -> (%false_val, %39, %40, %i.17)
825 // %44 : bool, %i : int = prim::If(%did_exit)
826 // block0():
827 // -> (%continue_loop, %43)
828 // block1():
829 // %i.13 : int = aten::add(%i.16, %19)
830 // %4 : bool = aten::lt(%i.13, %3)
831 // -> (%4, %i.13)
832 // -> (%44, %i)
833
TransformExits(std::shared_ptr<Graph> & graph)834 void TransformExits(std::shared_ptr<Graph>& graph) {
835 convertEnterExitNodesToWithBlocks(graph);
836 ExitTransformer e_loop(graph);
837 e_loop.transformLoopContinuations();
838 ExitTransformer e_ret(graph);
839 e_ret.transformReturnStmts();
840 inlineConsecutiveIfs(graph->block());
841 convertWithBlocksToEnterExitNodes(graph);
842 }
843
844 } // namespace torch::jit
845