xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/graph_fuser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
2 #include <torch/csrc/jit/codegen/onednn/graph_helper.h>
3 #include <torch/csrc/jit/ir/alias_analysis.h>
4 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
5 #include <torch/csrc/jit/passes/dead_code_elimination.h>
6 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
7 
8 namespace torch {
9 namespace jit {
10 namespace fuser {
11 namespace onednn {
12 
CreateLlgaSubgraphs(std::shared_ptr<Graph> & graph)13 void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph) {
14   AliasDb db(graph);
15   GraphRewriter graphRewriter(graph->block(), graph, db);
16   // We maintain alias db correctness in-place while building up the LLGA
17   // subgraphs, however it is difficult to preserve correctness when
18   // un-inlining autodiff subgraphs. We first recursively construct all
19   // subgraphs and then recursively cleanup & unmerge the small subgraphs
20   graphRewriter.buildupSubgraphs();
21   graphRewriter.cleanupSubgraphs();
22   // Run CSE globally onceto eliminate duplicates that may have occurred
23   // while inlining subgraphs.
24   EliminateCommonSubexpression(graph);
25   EliminateDeadCode(graph);
26 }
27 
28 } // namespace onednn
29 } // namespace fuser
30 } // namespace jit
31 } // namespace torch
32