1 #include <torch/csrc/jit/passes/constant_pooling.h> 2 #include <torch/csrc/jit/passes/constant_propagation.h> 3 #include <torch/csrc/jit/passes/remove_exceptions.h> 4 5 #include <torch/csrc/jit/jit_log.h> 6 7 namespace torch::jit { 8 certainlyThrows(Block * block)9static bool certainlyThrows(Block* block) { 10 for (Node* n : block->nodes()) { 11 if (n->kind() == prim::RaiseException) { 12 return true; 13 } 14 } 15 return false; 16 } 17 EliminateExceptions(Block * block)18static void EliminateExceptions(Block* block) { 19 auto graph = block->owningGraph(); 20 Value* false_const = graph->insertConstant(IValue(false)); 21 Value* true_const = graph->insertConstant(IValue(true)); 22 for (Node* n : block->nodes()) { 23 if (n->kind() == prim::If) { 24 Block* true_block = n->blocks()[0]; 25 Block* false_block = n->blocks()[1]; 26 if (certainlyThrows(true_block)) { 27 n->input(0)->replaceAllUsesWith(false_const); 28 } else if (certainlyThrows(false_block)) { 29 n->input(0)->replaceAllUsesWith(true_const); 30 } 31 } 32 33 for (Block* subblock : n->blocks()) { 34 EliminateExceptions(subblock); 35 } 36 } 37 } 38 EliminateExceptions(std::shared_ptr<Graph> & graph)39void EliminateExceptions(std::shared_ptr<Graph>& graph) { 40 GRAPH_DUMP("Before EliminateExceptions: ", graph); 41 EliminateExceptions(graph->block()); 42 ConstantPropagation(graph); 43 ConstantPooling(graph); 44 GRAPH_DUMP("After EliminateExceptions: ", graph); 45 } 46 47 } // namespace torch::jit 48