xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/common_subexpression_elimination.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/jit_log.h>

#include <functional>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
9 
10 namespace torch::jit {
11 namespace {
12 
13 struct CommonSubexpressionEliminator {
CommonSubexpressionEliminatortorch::jit::__anon330bd7e70111::CommonSubexpressionEliminator14   CommonSubexpressionEliminator(std::shared_ptr<Graph> graph)
15       : graph_(std::move(graph)) {}
16 
runtorch::jit::__anon330bd7e70111::CommonSubexpressionEliminator17   bool run(std::function<Node*(Node*)> parent_lookup_fn) {
18     return run(graph_->block(), std::move(parent_lookup_fn));
19   }
20 
21   // The function implements common subexpression elimination.
22   // Since the nodes are visited in topological order, one pass is enough.
23   // returns true if CSE made changes to a graph
runtorch::jit::__anon330bd7e70111::CommonSubexpressionEliminator24   bool run(Block* block, std::function<Node*(Node*)> parent_lookup_fn) {
25     std::unordered_set<Node*, HashNode, EqualNode> subexprs;
26     bool changed = false;
27     for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
28       auto node = *it;
29 
30       if (node->kind() == prim::profile) {
31         GRAPH_DEBUG(
32             "Profiled nodes shouldn't be CSE'ed there's a separate pass that does dedup and merging:\n",
33             *node);
34         continue;
35       }
36 
37       if (node->hasSideEffects()) {
38         GRAPH_DEBUG("Node was skipped due to side effects:\n", *node);
39         continue;
40       }
41       if (node->isNondeterministic()) {
42         GRAPH_DEBUG("Node was skipped due to its non determinism:\n", *node);
43         continue;
44       }
45 
46       if (!node->blocks().empty()) {
47         // Traverse sub-blocks.
48         for (auto block : node->blocks()) {
49           changed |= run(block, [&](Node* n) {
50             auto existing = subexprs.find(n);
51             if (existing != subexprs.end()) {
52               return *existing;
53             }
54 
55             return parent_lookup_fn(n);
56           });
57         }
58 
59         continue;
60       }
61 
62       if (getOrCreateAliasDb().hasWriters(node)) {
63         GRAPH_DEBUG("Node was skipped due to alias analysis result:\n", *node);
64         // Do NOT have enough information to do CSE on these nodes.
65         continue;
66       }
67 
68       // Check for CSE opportunities in the parent block.
69       auto parent_lookup = parent_lookup_fn(node);
70       auto g_out = node->owningGraph()->outputs();
71       if (parent_lookup != nullptr) {
72         if (!getOrCreateAliasDb().safeToChangeAliasingRelationship(
73                 node->outputs(), parent_lookup->outputs())) {
74           continue;
75         }
76 
77         GRAPH_UPDATE("Replacing\n", *node, "with\n", *parent_lookup);
78         changed = true;
79         node->replaceAllUsesWith(parent_lookup);
80         it.destroyCurrent();
81         continue;
82       }
83 
84       // Check whether the same subexpression already exists.
85       auto subit = subexprs.insert(node);
86       if (!subit.second) {
87         // Subexpression exists, replace the uses of node, and destroy it.
88         auto existing = *subit.first;
89 
90         // don't introduce new aliasing among graph outputs
91         if (getOrCreateAliasDb().mayContainAlias(
92                 node->outputs(), node->owningGraph()->outputs()) &&
93             getOrCreateAliasDb().mayContainAlias(existing->outputs(), g_out)) {
94           continue;
95         }
96 
97         GRAPH_UPDATE("Replacing\n", *node, "with\n", *existing);
98         changed = true;
99         node->replaceAllUsesWith(existing);
100         // Destroy the node.
101         it.destroyCurrent();
102       }
103     }
104 
105     return changed;
106   }
107 
getOrCreateAliasDbtorch::jit::__anon330bd7e70111::CommonSubexpressionEliminator108   AliasDb& getOrCreateAliasDb() {
109     if (!alias_db_) {
110       alias_db_ = std::make_unique<AliasDb>(graph_);
111     }
112 
113     return *alias_db_;
114   }
115 
116  private:
117   std::unique_ptr<AliasDb> alias_db_;
118   std::shared_ptr<Graph> graph_;
119 };
120 
121 } // namespace
122 
EliminateCommonSubexpression(const std::shared_ptr<Graph> & graph)123 bool EliminateCommonSubexpression(const std::shared_ptr<Graph>& graph) {
124   GRAPH_DUMP("Before CSE", graph);
125   CommonSubexpressionEliminator cse(graph);
126   return cse.run([](Node*) { return nullptr; });
127 }
128 } // namespace torch::jit
129