xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/frontend/exit_transforms.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/frontend/exit_transforms.h>
2 
3 #include <ATen/core/jit_type.h>
4 #include <c10/util/irange.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/ir/ir_views.h>
7 #include <torch/csrc/jit/runtime/graph_iterator.h>
8 
9 namespace torch::jit {
10 
// Describes whether a node/block reaches the exit currently being transformed:
//   WILL   - the exit is always hit.
//   MIGHT  - the exit may or may not be hit.
//   WONT   - the exit is never hit.
//   THROWS - the node/block always raises an exception.
// THROWS exists only as an optimization: it lets the pass avoid
// conditionalizing execution that can only follow a throw. Replacing it with
// WONT everywhere would still preserve graph semantics.
enum class ExitStatus {
  WILL,
  MIGHT,
  WONT,
  THROWS,
};
18 
// The two kinds of exits this file knows how to transform.
// NOTE(review): this enum is not referenced anywhere else in this file.
enum class Transform {
  Returns,
  LoopContinuations,
};
20 
21 // hasExited() indicates whether or not an exit has been hit.
22 // The ExitTransform pass maintains a false boolean false_val_ && a true boolean
23 // true_val_, and an uninitialized boolean throws_val_.
24 // if hasExited() == true_val_ then we have exited, if hasExited() == false_val_
25 // we have not, hasExited() == throws_val_ we have hit a block that throws.
26 // Otherwise, we might have exited.
27 // exitValues() are the values that we are propagating to a destination block.
28 // this is used for block outputs of loops and outputs of functions & closures
29 struct ExitPair : public std::pair<Value*, std::vector<Value*>> {
30   using pair::pair;
31 
ExitPairtorch::jit::ExitPair32   ExitPair(Value* exit_v, at::ArrayRef<Value*> exit_val_ref) {
33     std::vector<Value*> exit_vals;
34     for (Value* v : exit_val_ref) {
35       exit_vals.push_back(v);
36     }
37     AT_ASSERT(exit_v->type() == BoolType::get());
38     this->first = exit_v;
39     this->second = std::move(exit_vals);
40   }
41 
hasExitedtorch::jit::ExitPair42   Value* hasExited() const {
43     return this->first;
44   }
45 
exitValuestorch::jit::ExitPair46   std::vector<Value*> exitValues() const {
47     return this->second;
48   }
49 };
50 
51 /**
52  * This pass currently transforms the Graph so that all exit nodes targeting
53  * a block location are removed from the graph and unified.
54  * The exit node for breaks/continues is LoopContinuation, and the exit for
55  * Graphs & Closures is ReturnStmt.
56  *
57  * Once we hit an Exit Node, we do not execute any further instructions
58  * until the exit target has been reached.
59  *
60  * For blocks and control flow nodes that have an exit statement that may
61  * have been hit, we conditionalize all execution on a boolean value that
62  * indicates whether we have hit the exit, hasExited().
63  *
64  * The pass keeps tracks of blocks that always throw, so that we can construct
65  * simpler graphs. For example, if in one block of an if statement we return
66  * and in the other we throw, we can treat the node as always returning instead
67  * of conditionalizing execution in the remainder of the block.
68  */
69 struct ExitTransformer {
ExitTransformertorch::jit::ExitTransformer70   ExitTransformer(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {
71     WithInsertPoint guard(graph_->block()->nodes().front());
72     true_val_ = graph_->insertConstant(true);
73     false_val_ = graph_->insertConstant(false);
74     // this value will never be used, since we will always throw before it is
75     // accessed
76     throws_val_ = getUnitValue(BoolType::get());
77   };
78 
transformReturnStmtstorch::jit::ExitTransformer79   void transformReturnStmts() {
80     current_exit_kind_ = prim::ReturnStmt;
81     transformExits(graph_->block());
82   }
83 
transformLoopContinuationstorch::jit::ExitTransformer84   void transformLoopContinuations() {
85     current_exit_kind_ = prim::LoopContinuation;
86     transformExits(graph_->block());
87   }
88 
89  private:
constructThrowsExitPairtorch::jit::ExitTransformer90   ExitPair constructThrowsExitPair() {
91     return ExitPair(throws_val_, std::vector<Value*>({}));
92   }
constructWontExitPairtorch::jit::ExitTransformer93   ExitPair constructWontExitPair() {
94     return ExitPair(false_val_, std::vector<Value*>({}));
95   }
constructWillExitPairtorch::jit::ExitTransformer96   ExitPair constructWillExitPair(at::ArrayRef<Value*> exit_val_ref) {
97     return ExitPair(true_val_, exit_val_ref);
98   }
99 
getExitStatustorch::jit::ExitTransformer100   ExitStatus getExitStatus(ExitPair& exit_pair) {
101     Value* exit_v = exit_pair.hasExited();
102     if (exit_v == true_val_) {
103       return ExitStatus::WILL;
104     } else if (exit_v == false_val_) {
105       return ExitStatus::WONT;
106     } else if (exit_v == throws_val_) {
107       return ExitStatus::THROWS;
108     } else {
109       return ExitStatus::MIGHT;
110     }
111   }
112 
owningNodeKindtorch::jit::ExitTransformer113   static Symbol owningNodeKind(Block* block) {
114     if (block->owningNode()) {
115       return block->owningNode()->kind();
116     }
117     return Symbol();
118   }
119 
isGraphOrClosureBlocktorch::jit::ExitTransformer120   static bool isGraphOrClosureBlock(Block* block) {
121     return block->owningNode() == nullptr ||
122         owningNodeKind(block) == prim::Closure;
123   }
124 
removeOutputstorch::jit::ExitTransformer125   static void removeOutputs(Block* b) {
126     while (!b->outputs().empty()) {
127       b->eraseOutput(0);
128     }
129   }
130 
registerBlockOutputstorch::jit::ExitTransformer131   static void registerBlockOutputs(Block* b, at::ArrayRef<Value*> outs) {
132     for (Value* out : outs) {
133       b->registerOutput(out);
134     }
135   }
136 
replaceBlockOutputstorch::jit::ExitTransformer137   static void replaceBlockOutputs(Block* b, at::ArrayRef<Value*> outs) {
138     removeOutputs(b);
139     registerBlockOutputs(b, outs);
140   }
141 
addIfOutputstorch::jit::ExitTransformer142   static void addIfOutputs(
143       Node* n,
144       at::ArrayRef<Value*> true_outs,
145       at::ArrayRef<Value*> false_outs) {
146     IfView if_view(n);
147     registerBlockOutputs(if_view.thenBlock(), true_outs);
148     registerBlockOutputs(if_view.elseBlock(), false_outs);
149     for (const auto i : c10::irange(true_outs.size())) {
150       auto out_type = unifyTypes(
151           true_outs.at(i)->type(),
152           false_outs.at(i)->type(),
153           /*default_to_union=*/true);
154       n->addOutput()->setType(*out_type);
155     }
156   }
157 
158   // creates a vector of uninitialized values of the same type as the
159   // values_to_match
matchValuesWithUnitializedtorch::jit::ExitTransformer160   std::vector<Value*> matchValuesWithUnitialized(
161       at::ArrayRef<Value*> values_to_match) {
162     std::vector<Value*> match_values;
163     for (Value* val : values_to_match) {
164       match_values.push_back(getUnitValue(val->type()));
165     }
166     return match_values;
167   }
168 
transformLooptorch::jit::ExitTransformer169   ExitPair transformLoop(Node* node) {
170     LoopView loop(node);
171     Block* body = loop.bodyBlock();
172     auto exit_pair = transformExits(body);
173     // if we're not exiting to outside the loop we don't need to do any work.
174     // since we may not enter the loop return WONT for the THROWS case.
175 
176     if (getExitStatus(exit_pair) == ExitStatus::WONT ||
177         getExitStatus(exit_pair) == ExitStatus::THROWS) {
178       return constructWontExitPair();
179     }
180 
181     // if we are, we need to update the loop continue condition so that
182     // we exit the loop if we've hit an exit
183     // and we need to propagate hasExited() and exitValues() outside the loop
184 
185     // example:
186     // while i < 5:
187     //    i += 1
188     //    if j == 4:
189     //      return 5
190     // -> becomes
191     //
192     // loop_continue = i < 5
193     // has_exited = false
194     // ret_val = uninitialized(int)
195     // while loop_continue:
196     //    i += 1
197     //    if j == 4:
198     //      ret_val = 5
199     //      has_exited = True
200     //    else:
201     //      ret_val = uninitialized(int)
202     //      has_exited = False
203     //    if has_exited:
204     //      loop_continue = False
205     //    else:
206     //      loop_continue = i < 5
207 
208     // update loop continuation condition so that we exit if we hit an exit
209     WithInsertPoint insert(body);
210     auto new_if = graph_->insertNode(graph_->create(prim::If, 0));
211     new_if->addInput(exit_pair.hasExited());
212     new_if->addBlock()->registerOutput(false_val_);
213     new_if->addBlock()->registerOutput(loop.nextCond());
214     auto new_condition = new_if->addOutput()->setType(BoolType::get());
215     loop.bodyBlock()->eraseOutput(0);
216     loop.bodyBlock()->insertOutput(0, new_condition);
217 
218     // add hasExited() to loop outputs, we didn't exit if we didn't enter the
219     // loop
220     node->addInput(false_val_);
221     body->addInput()->setType(BoolType::get());
222     body->registerOutput(exit_pair.hasExited());
223     Value* new_has_exited = node->addOutput()->setType(BoolType::get());
224 
225     // add exit values
226     for (Value* exit_value : exit_pair.exitValues()) {
227       auto typ = exit_value->type();
228       node->addInput(getUnitValue(typ));
229       node->addOutput()->setType(typ);
230       body->addInput()->setType(typ);
231       body->registerOutput(exit_value);
232     }
233 
234     auto exit_vals = node->outputs().slice(
235         node->outputs().size() - exit_pair.exitValues().size());
236 
237     return ExitPair(new_has_exited, exit_vals);
238   }
239 
calcIfExitStatustorch::jit::ExitTransformer240   ExitStatus calcIfExitStatus(ExitStatus then_status, ExitStatus else_status) {
241     // if one branch throws, we can take the status of the other
242     if (then_status == ExitStatus::THROWS) {
243       return else_status;
244     } else if (else_status == ExitStatus::THROWS) {
245       return then_status;
246     }
247 
248     if (then_status == ExitStatus::WONT && else_status == ExitStatus::WONT) {
249       return ExitStatus::WONT;
250     }
251 
252     if (then_status == ExitStatus::WILL && else_status == ExitStatus::WILL) {
253       return ExitStatus::WILL;
254     }
255 
256     return ExitStatus::MIGHT;
257   }
258 
259   // Recursively transforms the if node
transformIftorch::jit::ExitTransformer260   ExitPair transformIf(Node* node) {
261     auto then_block = node->blocks().at(0);
262     auto else_block = node->blocks().at(1);
263 
264     auto then_pair = transformExits(then_block);
265     auto else_pair = transformExits(else_block);
266     auto then_status = getExitStatus(then_pair);
267     auto else_status = getExitStatus(else_pair);
268 
269     auto if_status = calcIfExitStatus(then_status, else_status);
270 
271     if (if_status == ExitStatus::THROWS) {
272       return constructThrowsExitPair();
273     }
274     if (if_status == ExitStatus::WONT) {
275       return constructWontExitPair();
276     }
277 
278     // The exit values of the block that is not exiting will not get
279     // used, so we create uninitialized values of the same type as the other
280     // block.
281     if (then_status == ExitStatus::WONT || then_status == ExitStatus::THROWS) {
282       std::vector<Value*> exit_vals =
283           matchValuesWithUnitialized(else_pair.exitValues());
284       then_pair = ExitPair(then_pair.hasExited(), exit_vals);
285     } else if (
286         else_status == ExitStatus::WONT || else_status == ExitStatus::THROWS) {
287       std::vector<Value*> exit_vals =
288           matchValuesWithUnitialized(then_pair.exitValues());
289       else_pair = ExitPair(else_pair.hasExited(), exit_vals);
290     }
291 
292     Value* has_exited = nullptr;
293     if (if_status == ExitStatus::WILL) {
294       // Need to maintain the invariant that if hasExited() == true_val_
295       // then we have exited.
296       has_exited = true_val_;
297     } else {
298       addIfOutputs(node, {then_pair.hasExited()}, {else_pair.hasExited()});
299       has_exited = node->outputs().at(node->outputs().size() - 1);
300     }
301     addIfOutputs(node, then_pair.exitValues(), else_pair.exitValues());
302     size_t num_exit_vals = then_pair.exitValues().size();
303     auto exit_vals =
304         node->outputs().slice(node->outputs().size() - num_exit_vals);
305     return ExitPair(has_exited, exit_vals);
306   }
307 
308   // Recursively transforms the With node.
transformWithtorch::jit::ExitTransformer309   ExitPair transformWith(Node* node) {
310     auto body_block = node->blocks().at(0);
311     auto body_pair = transformExits(body_block);
312     return body_pair;
313   }
314 
315   // Guards the remaining nodes in the block with an if node that takes
316   // the has exited value as its conditional
guardBlockNodestorch::jit::ExitTransformer317   ExitPair guardBlockNodes(
318       Block* block,
319       const ExitPair& exit_pair,
320       graph_node_list_iterator& iter) {
321     auto new_if = graph_->create(prim::If, 0)->insertBefore(*iter);
322     new_if->addInput(exit_pair.hasExited());
323 
324     auto exit_block = new_if->addBlock();
325     auto guard_block = new_if->addBlock();
326 
327     // Move all remaining nodes into the guard block
328     while (iter != block->nodes().end()) {
329       auto node = *iter++;
330       node->moveBefore(guard_block->return_node());
331     }
332 
333     std::vector<Value*> exit_block_vals;
334     // after an exit, the only values that will get used
335     // are the hasExited() and exitValues(), so we match the existing
336     // block outputs with unitialized
337     exit_block_vals = matchValuesWithUnitialized(block->outputs());
338 
339     // Set the new if to have the same outputs of the original block,
340     // then replace the original block outputs with new if's outputs
341     for (size_t i = 0; i < block->outputs().size(); ++i) {
342       exit_block->registerOutput(exit_block_vals.at(i));
343       guard_block->registerOutput(block->outputs().at(i));
344       new_if->addOutput()->setType(block->outputs().at(i)->type());
345     }
346 
347     while (!block->outputs().empty()) {
348       block->eraseOutput(0);
349     }
350     for (auto out : new_if->outputs()) {
351       block->registerOutput(out);
352     }
353 
354     graph_->create(current_exit_kind_, {exit_pair.exitValues()}, 0)
355         ->insertBefore(exit_block->return_node());
356     return transformIf(new_if);
357   }
358 
359   // these nodes my have uses,
360   // such as in the case:
361   // if i == 1:
362   //    break
363   //    j = j + 1
364   // where the j + 1 value will be a block output, but since they will
365   // never be used, it is safe to replace them with unitialized value
destroyNodeAfterExittorch::jit::ExitTransformer366   void destroyNodeAfterExit(Node* n) {
367     for (auto output : n->outputs()) {
368       if (!output->uses().empty()) {
369         output->replaceAllUsesWith(getUnitValue(output->type()));
370       }
371     }
372     n->destroy();
373   }
374 
deleteAfterExitNodestorch::jit::ExitTransformer375   void deleteAfterExitNodes(Block* block, graph_node_list_iterator& iter) {
376     if (iter == block->nodes().end()) {
377       return;
378     }
379     WithInsertPoint insert(*block->nodes().begin());
380     // need to destroy in reverse order so nodes have no uses when destroyed
381     for (auto it = block->nodes().reverse().begin(); it != iter;) {
382       Node* n = *it++;
383       if (*it != block->return_node()) {
384         destroyNodeAfterExit(n);
385       }
386     }
387     destroyNodeAfterExit(*iter);
388   }
389 
390   // if we're entering a Loop block & transforming LoopContinuations, or if
391   // we're entering a Closure/Graph block and we're transforming ReturnStmts,
392   // then we update target_block_ to be the new block.
393   // otherwise, target_block_ remains the same.
updateTargetBlocktorch::jit::ExitTransformer394   void updateTargetBlock(Block* block) {
395     if (owningNodeKind(block) == prim::Loop &&
396         // NOLINTNEXTLINE(bugprone-branch-clone)
397         current_exit_kind_ == prim::LoopContinuation) {
398       target_block_ = block;
399     } else if (
400         isGraphOrClosureBlock(block) &&
401         current_exit_kind_ == prim::ReturnStmt) {
402       target_block_ = block;
403     }
404   }
405 
transformExitstorch::jit::ExitTransformer406   ExitPair transformExits(Block* block) {
407     Block* prev_target_block = target_block_;
408     updateTargetBlock(block);
409     ExitPair exit_pair = constructWontExitPair();
410 
411     for (auto it = block->nodes().begin(); it != block->nodes().end();) {
412       Node* node = *it;
413       it++;
414       switch (node->kind()) {
415         case prim::RaiseException: {
416           exit_pair = constructThrowsExitPair();
417         } break;
418         case prim::ReturnStmt:
419         case prim::LoopContinuation: {
420           if (node->kind() == current_exit_kind_) {
421             exit_pair = constructWillExitPair(node->inputs());
422             node->destroy();
423           }
424         } break;
425         case prim::If: {
426           exit_pair = transformIf(node);
427         } break;
428         case prim::With: {
429           exit_pair = transformWith(node);
430         } break;
431         case prim::Closure: {
432           // exits of closure declaration stay local to the closure
433           transformExits(node->blocks().at(0));
434         } break;
435         case prim::Loop: {
436           exit_pair = transformLoop(node);
437         } break;
438       }
439 
440       // if we have hit a node that might exit, we need to conditionally execute
441       // all subsequent nodes in the block. if we've hit a node that will exit
442       // we can remove all subsequent nodes.
443       ExitStatus status = getExitStatus(exit_pair);
444       if (status == ExitStatus::WILL || status == ExitStatus::THROWS) {
445         deleteAfterExitNodes(block, it);
446         break;
447       }
448       if (status == ExitStatus::MIGHT) {
449         if (it != block->nodes().end()) {
450           exit_pair = guardBlockNodes(block, exit_pair, it);
451         }
452         break;
453       }
454     }
455 
456     // if we are targeting this block, update the output values to the
457     // exit values. since the exit does not extend outside this block,
458     // update returned exit to false. then, reset the target_block to whatever
459     // it was previously
460     if (target_block_ == block) {
461       // if we might have exited, use the new exit values if we did exit,
462       // otherwise use the existing block outputs.
463       if (getExitStatus(exit_pair) == ExitStatus::MIGHT) {
464         auto new_if =
465             graph_->create(prim::If, 0)->insertBefore(block->return_node());
466         new_if->addBlock();
467         new_if->addBlock();
468         new_if->addInput(exit_pair.hasExited());
469         addIfOutputs(new_if, exit_pair.exitValues(), block->outputs());
470         replaceBlockOutputs(block, new_if->outputs());
471       } else if (getExitStatus(exit_pair) == ExitStatus::WILL) {
472         replaceBlockOutputs(block, exit_pair.exitValues());
473       }
474 
475       // reset the exiting status. an exit should only reach its target block.
476       // e.g. a continue only affects most recent loop, return in closure
477       // does not affect enclosing graph.
478       // Exceptions do not propagate from Loops bc we might not enter the loop,
479       // and not from closures bc the Function node is a declaration and not
480       // an invocation.
481       exit_pair = constructWontExitPair();
482     }
483     target_block_ = prev_target_block;
484     return exit_pair;
485   }
486 
getUnitValuetorch::jit::ExitTransformer487   Value* getUnitValue(const TypePtr& type) {
488     auto maybe_val = unit_values_.find(type);
489     if (maybe_val != unit_values_.end()) {
490       return maybe_val->second;
491     }
492     auto unit = graph_->createUninitialized(type)
493                     ->insertAfter(graph_->param_node())
494                     ->output();
495     unit_values_[type] = unit;
496     return unit;
497   }
498 
499   // we create one uninitialized value per type, cache it here and reuse it
500   std::unordered_map<TypePtr, Value*> unit_values_;
501 
502   // can either be LoopContinuation/ReturnStmt
503   Symbol current_exit_kind_;
504   Value* true_val_;
505   Value* false_val_;
506   Value* throws_val_;
507 
508   // when we see current_exit_kind_, this is the block that the values are
509   // exiting to. For example when we are transforming LoopContinuations
510   // for i in range(5):
511   //   while i < 3:
512   //     continue
513   //   break
514   // when we transform the for loop block, target_block_ will be set the for
515   // block. then, when we enter the while loop, target_block_ will be the while
516   // loop block. when we are done transforming the while it will be set back to
517   // the for block.
518   Block* target_block_ = nullptr;
519   std::shared_ptr<Graph> graph_;
520 };
521 
inlineConsecutiveIfs(Node * node)522 static bool inlineConsecutiveIfs(Node* node) {
523   if (node->kind() != prim::If || node->next()->kind() != prim::If) {
524     return false;
525   }
526 
527   IfView first_if(node);
528   IfView second_if(node->next());
529 
530   // the second if must depend on a value outputted in the first if for us to
531   // inline the second if
532   if (second_if.cond()->node() != node) {
533     return false;
534   }
535 
536   // both blocks must output a constant value for us to inline, and those values
537   // must be different. if the values are the same, then the subsequent if node
538   // will get constant prop'd away, and inlining it into the first node would
539   // double code size
540 
541   auto input_offset = second_if.cond()->offset();
542   auto maybe_then_value = toIValue(first_if.thenOutputs().at(input_offset));
543   auto maybe_else_value = toIValue(first_if.elseOutputs().at(input_offset));
544   if (!maybe_then_value || !maybe_else_value ||
545       maybe_then_value->toBool() == maybe_else_value->toBool()) {
546     return false;
547   }
548 
549   bool then_value = maybe_then_value->toBool();
550   bool else_value = maybe_else_value->toBool();
551 
552   for (const auto i : c10::irange(2)) {
553     Block* first_if_block = nullptr;
554     Block* second_if_block = nullptr;
555 
556     if (i == 0) {
557       first_if_block = first_if.thenBlock();
558       second_if_block =
559           then_value ? second_if.thenBlock() : second_if.elseBlock();
560     } else {
561       first_if_block = first_if.elseBlock();
562       second_if_block =
563           else_value ? second_if.thenBlock() : second_if.elseBlock();
564       ;
565     }
566 
567     // we need to replace values that were used in the second if that were
568     // outputs of the first if with the equivalent value in the scope of the
569     // block we're copying into
570     auto value_map = [&](Value* v) {
571       if (v->node() != first_if.node()) {
572         return v;
573       }
574       auto offset = v->offset();
575       return first_if_block->outputs().at(offset);
576     };
577 
578     // clone from also copies block outputs from second_if_block onto
579     // first_if_block
580     first_if_block->cloneFrom(second_if_block, value_map);
581   }
582 
583   for (Value* output : second_if.outputs()) {
584     auto new_out = first_if.node()->addOutput()->copyMetadata(output);
585     output->replaceAllUsesWith(new_out);
586   }
587   second_if.node()->destroy();
588   return true;
589 }
590 
591 // After an early return, we conditionalize all further execution
592 // This means code like the following:
593 // if x:
594 //     return 1
595 // return 2
596 // Gets generated as one if statement checking `if x`, and then a second if
597 // statement that conditionalizes execution. We can rewrite cases like these
598 // into one if statement, so that the above examples gets rewritten to look
599 // like: if x:
600 //   return 1
601 // else:
602 //   return 2
inlineConsecutiveIfs(Block * block)603 static void inlineConsecutiveIfs(Block* block) {
604   for (auto it = block->nodes().begin(), end = block->nodes().end();
605        it != end;) {
606     for (Block* b : it->blocks()) {
607       inlineConsecutiveIfs(b);
608     }
609 
610     // if we fused two ifs, we need to check current node and new next node
611     if (!inlineConsecutiveIfs(*it)) {
612       it++;
613     }
614   }
615 }
616 
617 // Adds prim::With nodes to a graph to help handle early exits between
618 // prim::Enter and prim::Exit nodes. More specifically, it transforms
619 // IR that looks like this:
620 //
621 //   %a = prim::Enter(%b)
622 //   <code>
623 //   %c = prim::Exit(%b)
624 //
625 // to this:
626 //
627 //   %a = prim::Enter(%b)
628 //   = prim::With()
629 //     block0():
630 //       <code>
631 //     -> ()
632 //     block1():
633 //       %c = prim::Exit(%b)
634 //     -> ()
635 //
convertEnterExitNodesToWithBlocks(std::shared_ptr<Graph> & graph)636 static void convertEnterExitNodesToWithBlocks(std::shared_ptr<Graph>& graph) {
637   // First, find all Enter-Exit pairs up front to avoid iterator invalidation
638   // issues later when moving nodes around. Do this by iterating through the
639   // nodes of the graph while keeping a stack of encountered Enter nodes. Each
640   // time an Exit node is seen, its corresponding Enter node must be at the
641   // top of the stack. Pop it and record the pair.
642   std::vector<std::pair<Node*, Node*>> enter_exit_pairs;
643   std::vector<Node*> enter_node_stack;
644 
645   DepthFirstGraphNodeIterator it(graph);
646   Node* node = it.next();
647 
648   while (node) {
649     if (node->kind() == prim::Enter) {
650       enter_node_stack.emplace_back(node);
651     } else if (node->kind() == prim::Exit) {
652       // enter_node_stack should not be empty.
653       TORCH_INTERNAL_ASSERT(!enter_node_stack.empty());
654       // The input to this Exit node should be the same as that of the Enter
655       // node on the top of the enter_node_stack.
656       TORCH_INTERNAL_ASSERT(
657           enter_node_stack.back()->input(0) == node->input(0));
658       // Record the pair.
659       enter_exit_pairs.emplace_back(enter_node_stack.back(), node);
660       enter_node_stack.pop_back();
661     }
662 
663     node = it.next();
664   }
665 
666   // The stack should be empty; an Exit should have been found for every Enter.
667   TORCH_INTERNAL_ASSERT(enter_node_stack.empty());
668 
669   // Now, add a With block for each Enter-Exit pair. The innermost pairs were
670   // found first, so they will be converted first.
671   for (auto& pair : enter_exit_pairs) {
672     Node* enter = pair.first;
673     Node* exit = pair.second;
674 
675     auto* with = graph->create(prim::With, /*num_outputs=*/0);
676     auto* body_block = with->addBlock();
677     auto* exit_block = with->addBlock();
678 
679     // Insert the With after the Enter.
680     Node* cur = enter->next();
681     Node* insert_point = body_block->param_node();
682 
683     // Move all of the nodes between the Enter and Exit into the body block.
684     while (cur != exit) {
685       auto* next = cur->next();
686       cur->moveAfter(insert_point);
687       insert_point = insert_point->next();
688       cur = next;
689     }
690 
691     // Move the Exit node into the exit block.
692     exit->moveAfter(exit_block->param_node());
693     with->insertAfter(enter);
694   }
695 }
696 
697 // Removes prim::With nodes from a graph. More specifically, it transforms
698 // IR that looks like this:
699 //
700 //   %a = prim::Enter(%b)
701 //   = prim::With()
702 //     block0():
703 //       <code>
704 //     -> ()
705 //     block1():
706 //       %c = prim::Exit(%b)
707 //      ->()
708 //
709 // to this:
710 //   %a = prim::Enter(%b)
711 //   <code>
712 //   %c = prim::Exit(%b)
713 //
convertWithBlocksToEnterExitNodes(std::shared_ptr<Graph> & graph)714 static void convertWithBlocksToEnterExitNodes(std::shared_ptr<Graph>& graph) {
715   // First, find all With blocks to avoid iterator invalidation issues when
716   // moving nodes around later.
717   std::vector<Node*> with_nodes;
718 
719   DepthFirstGraphNodeIterator it(graph);
720   Node* node = it.next();
721 
722   while (node) {
723     if (node->kind() == prim::With) {
724       with_nodes.emplace_back(node);
725     }
726     node = it.next();
727   }
728 
729   // For each With node:
730   for (auto& node : with_nodes) {
731     auto* body_block = node->blocks().at(0);
732     auto* exit_block = node->blocks().at(1);
733 
734     std::vector<Node*> to_append;
735 
736     // Record all nodes that need to be appended after the Enter that precedes
737     // the With block to avoid iterator invalidation issues later when moving
738     // nodes around.
739     for (auto body_node : body_block->nodes()) {
740       to_append.emplace_back(body_node);
741     }
742 
743     for (auto exit_node : exit_block->nodes()) {
744       to_append.emplace_back(exit_node);
745     }
746 
747     Node* cur = node->prev();
748 
749     // Move all nodes inside the with block outside of it.
750     for (auto& node : to_append) {
751       node->moveAfter(cur);
752       cur = node;
753     }
754     node->destroy();
755   }
756 }
757 
758 // This pass takes in a graph where LoopContinuation & ReturnStmts exist in the
759 // graph and erases them in the graph, correctly setting block outputs.
760 // prim::LoopContinuation(*vals) means that the values are targeting the most
761 // recent loop block. prim::ReturnStmt(*vals) means that the values are
762 // targeting the most recent Closure or Graph Block. Once we hit an exit node,
763 // we do not execute any further instructions until the block exit reaches its
764 // destination. If we encounter a node that contains nested blocks that may
765 // have hit an exit node, such as an if statement that exits in one block
766 // and does not exit in the other, we use a boolean value to indicate if the
767 // exit has been hit or not. Then, we conditionalize further execution.
768 //
769 // Python example:
770 // while i < 5:
771 //   if i == 3:
772 //     i += 1
773 //     continue
774 //   i += 2
775 //
776 // -> transforms to:
777 //
778 // continue_loop = i < 5
779 // while continue_loop:
780 //   if i == 3:
781 //     i = i + 1
782 //     continue_loop = i < 5
783 //     did_exit = True
784 //   if did_exit:
785 //     pass
786 //   else:
787 //     i = i + 2
788 //     continue_loop = i < 5
789 // IR as it enters pass:
790 // %36 : bool = aten::lt(%i.1, %3)
791 // %i : int = prim::Loop(%1, %36, %i.1)
792 //   block0(%5 : int, %i.17 : int):
793 //     %8 : bool = aten::eq(%i.17, %7)
794 //     %i.16 : int = prim::If(%8)
795 //       block0():
796 //         %i.6 : int = aten::add(%i.17, %11)
797 //         %33 : bool = aten::lt(%i.6, %3)
798 //          = prim::LoopContinuation(%33, %i.6)
799 //         -> (%i.6)
800 //       block1():
801 //         -> (%i.17)
802 //     %i.13 : int = aten::add(%i.16, %19)
803 //     %4 : bool = aten::lt(%i.13, %3)
804 //     -> (%4, %i.13)
805 // return (%i)
806 //
807 //   -> transforms to
808 //
809 // %false_val : bool = prim::Constant[value=0]()
810 // %true_val : bool = prim::Constant[value=1]()
811 // %40 : int = prim::Uninitialized()
812 // %39 : bool = prim::Uninitialized()
813 // %36 : bool = aten::lt(%i.1, %3)
814 // %i : int = prim::Loop(%1, %36, %i.1)
815 //   block0(%5 : int, %i.17 : int):
816 //     %8 : bool = aten::eq(%i.17, %7)
817 //     %did_exit : bool, %continue_loop : bool, %43 : int, %i.16 : int =
818 //     prim::If(%8)
819 //       block0():
820 //         %i.6 : int = aten::add(%i.17, %11)
821 //         %33 : bool = aten::lt(%i.6, %3)
822 //         -> (%true_val, %33, %i.6, %i.6)
823 //       block1():
824 //         -> (%false_val, %39, %40, %i.17)
825 //     %44 : bool, %i : int = prim::If(%did_exit)
826 //       block0():
827 //         -> (%continue_loop, %43)
828 //       block1():
829 //         %i.13 : int = aten::add(%i.16, %19)
830 //         %4 : bool = aten::lt(%i.13, %3)
831 //         -> (%4, %i.13)
832 //     -> (%44, %i)
833 
TransformExits(std::shared_ptr<Graph> & graph)834 void TransformExits(std::shared_ptr<Graph>& graph) {
835   convertEnterExitNodesToWithBlocks(graph);
836   ExitTransformer e_loop(graph);
837   e_loop.transformLoopContinuations();
838   ExitTransformer e_ret(graph);
839   e_ret.transformReturnStmts();
840   inlineConsecutiveIfs(graph->block());
841   convertWithBlocksToEnterExitNodes(graph);
842 }
843 
844 } // namespace torch::jit
845