1 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
2
3 #include <torch/csrc/jit/ir/alias_analysis.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 #include <torch/csrc/jit/ir/node_hashing.h>
6 #include <torch/csrc/jit/jit_log.h>
7
8 #include <unordered_map>
9
10 namespace torch::jit {
11 namespace {
12
13 struct CommonSubexpressionEliminator {
CommonSubexpressionEliminatortorch::jit::__anon330bd7e70111::CommonSubexpressionEliminator14 CommonSubexpressionEliminator(std::shared_ptr<Graph> graph)
15 : graph_(std::move(graph)) {}
16
runtorch::jit::__anon330bd7e70111::CommonSubexpressionEliminator17 bool run(std::function<Node*(Node*)> parent_lookup_fn) {
18 return run(graph_->block(), std::move(parent_lookup_fn));
19 }
20
21 // The function implements common subexpression elimination.
22 // Since the nodes are visited in topological order, one pass is enough.
23 // returns true if CSE made changes to a graph
runtorch::jit::__anon330bd7e70111::CommonSubexpressionEliminator24 bool run(Block* block, std::function<Node*(Node*)> parent_lookup_fn) {
25 std::unordered_set<Node*, HashNode, EqualNode> subexprs;
26 bool changed = false;
27 for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
28 auto node = *it;
29
30 if (node->kind() == prim::profile) {
31 GRAPH_DEBUG(
32 "Profiled nodes shouldn't be CSE'ed there's a separate pass that does dedup and merging:\n",
33 *node);
34 continue;
35 }
36
37 if (node->hasSideEffects()) {
38 GRAPH_DEBUG("Node was skipped due to side effects:\n", *node);
39 continue;
40 }
41 if (node->isNondeterministic()) {
42 GRAPH_DEBUG("Node was skipped due to its non determinism:\n", *node);
43 continue;
44 }
45
46 if (!node->blocks().empty()) {
47 // Traverse sub-blocks.
48 for (auto block : node->blocks()) {
49 changed |= run(block, [&](Node* n) {
50 auto existing = subexprs.find(n);
51 if (existing != subexprs.end()) {
52 return *existing;
53 }
54
55 return parent_lookup_fn(n);
56 });
57 }
58
59 continue;
60 }
61
62 if (getOrCreateAliasDb().hasWriters(node)) {
63 GRAPH_DEBUG("Node was skipped due to alias analysis result:\n", *node);
64 // Do NOT have enough information to do CSE on these nodes.
65 continue;
66 }
67
68 // Check for CSE opportunities in the parent block.
69 auto parent_lookup = parent_lookup_fn(node);
70 auto g_out = node->owningGraph()->outputs();
71 if (parent_lookup != nullptr) {
72 if (!getOrCreateAliasDb().safeToChangeAliasingRelationship(
73 node->outputs(), parent_lookup->outputs())) {
74 continue;
75 }
76
77 GRAPH_UPDATE("Replacing\n", *node, "with\n", *parent_lookup);
78 changed = true;
79 node->replaceAllUsesWith(parent_lookup);
80 it.destroyCurrent();
81 continue;
82 }
83
84 // Check whether the same subexpression already exists.
85 auto subit = subexprs.insert(node);
86 if (!subit.second) {
87 // Subexpression exists, replace the uses of node, and destroy it.
88 auto existing = *subit.first;
89
90 // don't introduce new aliasing among graph outputs
91 if (getOrCreateAliasDb().mayContainAlias(
92 node->outputs(), node->owningGraph()->outputs()) &&
93 getOrCreateAliasDb().mayContainAlias(existing->outputs(), g_out)) {
94 continue;
95 }
96
97 GRAPH_UPDATE("Replacing\n", *node, "with\n", *existing);
98 changed = true;
99 node->replaceAllUsesWith(existing);
100 // Destroy the node.
101 it.destroyCurrent();
102 }
103 }
104
105 return changed;
106 }
107
getOrCreateAliasDbtorch::jit::__anon330bd7e70111::CommonSubexpressionEliminator108 AliasDb& getOrCreateAliasDb() {
109 if (!alias_db_) {
110 alias_db_ = std::make_unique<AliasDb>(graph_);
111 }
112
113 return *alias_db_;
114 }
115
116 private:
117 std::unique_ptr<AliasDb> alias_db_;
118 std::shared_ptr<Graph> graph_;
119 };
120
121 } // namespace
122
EliminateCommonSubexpression(const std::shared_ptr<Graph> & graph)123 bool EliminateCommonSubexpression(const std::shared_ptr<Graph>& graph) {
124 GRAPH_DUMP("Before CSE", graph);
125 CommonSubexpressionEliminator cse(graph);
126 return cse.run([](Node*) { return nullptr; });
127 }
128 } // namespace torch::jit
129