1 #include <torch/csrc/jit/codegen/onednn/graph_fuser.h> 2 #include <torch/csrc/jit/codegen/onednn/graph_helper.h> 3 #include <torch/csrc/jit/ir/alias_analysis.h> 4 #include <torch/csrc/jit/passes/common_subexpression_elimination.h> 5 #include <torch/csrc/jit/passes/dead_code_elimination.h> 6 #include <torch/csrc/jit/passes/utils/subgraph_utils.h> 7 8 namespace torch { 9 namespace jit { 10 namespace fuser { 11 namespace onednn { 12 CreateLlgaSubgraphs(std::shared_ptr<Graph> & graph)13void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph) { 14 AliasDb db(graph); 15 GraphRewriter graphRewriter(graph->block(), graph, db); 16 // We maintain alias db correctness in-place while building up the LLGA 17 // subgraphs, however it is difficult to preserve correctness when 18 // un-inlining autodiff subgraphs. We first recursively construct all 19 // subgraphs and then recursively cleanup & unmerge the small subgraphs 20 graphRewriter.buildupSubgraphs(); 21 graphRewriter.cleanupSubgraphs(); 22 // Run CSE globally onceto eliminate duplicates that may have occurred 23 // while inlining subgraphs. 24 EliminateCommonSubexpression(graph); 25 EliminateDeadCode(graph); 26 } 27 28 } // namespace onednn 29 } // namespace fuser 30 } // namespace jit 31 } // namespace torch 32