xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/remove_exceptions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/constant_pooling.h>
2 #include <torch/csrc/jit/passes/constant_propagation.h>
3 #include <torch/csrc/jit/passes/remove_exceptions.h>
4 
5 #include <torch/csrc/jit/jit_log.h>
6 
7 namespace torch::jit {
8 
certainlyThrows(Block * block)9 static bool certainlyThrows(Block* block) {
10   for (Node* n : block->nodes()) {
11     if (n->kind() == prim::RaiseException) {
12       return true;
13     }
14   }
15   return false;
16 }
17 
EliminateExceptions(Block * block)18 static void EliminateExceptions(Block* block) {
19   auto graph = block->owningGraph();
20   Value* false_const = graph->insertConstant(IValue(false));
21   Value* true_const = graph->insertConstant(IValue(true));
22   for (Node* n : block->nodes()) {
23     if (n->kind() == prim::If) {
24       Block* true_block = n->blocks()[0];
25       Block* false_block = n->blocks()[1];
26       if (certainlyThrows(true_block)) {
27         n->input(0)->replaceAllUsesWith(false_const);
28       } else if (certainlyThrows(false_block)) {
29         n->input(0)->replaceAllUsesWith(true_const);
30       }
31     }
32 
33     for (Block* subblock : n->blocks()) {
34       EliminateExceptions(subblock);
35     }
36   }
37 }
38 
EliminateExceptions(std::shared_ptr<Graph> & graph)39 void EliminateExceptions(std::shared_ptr<Graph>& graph) {
40   GRAPH_DUMP("Before EliminateExceptions: ", graph);
41   EliminateExceptions(graph->block());
42   ConstantPropagation(graph);
43   ConstantPooling(graph);
44   GRAPH_DUMP("After EliminateExceptions: ", graph);
45 }
46 
47 } // namespace torch::jit
48