1 #include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
2 #include <torch/csrc/jit/ir/alias_analysis.h>
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
5 #include <torch/csrc/jit/passes/dead_code_elimination.h>
6 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
7
8 namespace torch {
9 namespace jit {
10 namespace fuser {
11 namespace onednn {
12
cleanupSubgraphs()13 void GraphRewriter::cleanupSubgraphs() {
14 auto curNode = *block_->nodes().rbegin();
15 while (curNode != *block_->nodes().rend()) {
16 // Save the previous node, since we might delete `curNode` in next block
17 auto prevNode = curNode->prev();
18 if (llgaHelper_.isLlgaSubgraph(curNode)) {
19 // Unmerge subgraph if we don't get every nodes of a partition
20 // into the subgraph due to failed alias check
21 llgaHelper_.unmergeIfAnyNodeIsMissing(curNode);
22 }
23 curNode = prevNode;
24 }
25 for (Node* n : block_->nodes()) {
26 for (Block* b : n->blocks()) {
27 GraphRewriter(b, graph_, aliasDb_).cleanupSubgraphs();
28 }
29 }
30 }
31
buildupSubgraphs()32 void GraphRewriter::buildupSubgraphs() {
33 // We need to run the rewriter multiple times in order to get all merge
34 // opportunities. This is because moveBeforeTopologicalValid may reorder
35 // nodes to be AFTER the current iteration point. In order to properly
36 // consider those nodes for merging, we need run the pass until no changes
37 // have been made.
38 //
39 // Example:
40 // c = f(a, b)
41 // d = f(c)
42 // e = f(d) <- iter is here, moving upward
43 // After c.moveBeforeTopologicallyValid(e), we have:
44 // c = f(a, b)
45 // e = f(d) <- iter still here
46 // d = f(c) <- this was node moved on the other side.
47 // see [workblocks]
48 auto workblocks = buildWorkBlocks();
49 for (auto& workblock : workblocks) {
50 bool any_changed = true;
51 while (any_changed) {
52 any_changed = false;
53 auto workblock_end = workblock.end()->reverseIterator();
54 auto workblock_begin = workblock.begin()->reverseIterator();
55 for (auto it = workblock_end; it != workblock_begin;) {
56 bool changed = false;
57 std::tie(it, changed) = scanNode(*it, workblock_begin);
58 any_changed |= changed;
59 }
60 }
61 }
62
63 // Construct Subgraphs Recursively
64 for (Node* n : block_->nodes()) {
65 for (auto subBlock : n->blocks()) {
66 GraphRewriter(subBlock, graph_, aliasDb_).buildupSubgraphs();
67 }
68 }
69 }
70
buildWorkBlocks()71 std::vector<WorkBlock> GraphRewriter::buildWorkBlocks() {
72 // [workblocks]
73 // the IR has many nodes which can never be reordered around, such as a
74 // prim::Bailout. if a node N is surrounded by two nodes which cannot be
75 // reordered, A and B, then a fusion group that is created from N
76 // can only contain nodes from (A, B) The nodes from A to B represent one
77 // work block for the subgraph rewriter to work on. By creating these up
78 // front, we avoid retraversing the whole graph block any time scanNode
79 // returns
80 Node* end_bound_node = block_->return_node();
81 Node* curr = end_bound_node->prev();
82 std::vector<WorkBlock> worklist;
83 while (curr != block_->param_node()) {
84 // cannot reorder around side effectful nodes
85 if (curr->hasSideEffects()) {
86 worklist.emplace_back(curr, end_bound_node);
87 end_bound_node = curr;
88 }
89 curr = curr->prev();
90 }
91 worklist.emplace_back(curr, end_bound_node);
92 return worklist;
93 }
94
scanNode(Node * consumer,graph_node_list::iterator workblock_begin)95 std::pair<graph_node_list::iterator, bool> GraphRewriter::scanNode(
96 Node* consumer,
97 graph_node_list::iterator workblock_begin) {
98 GRAPH_DEBUG("Scanning ", consumer->kind().toQualString());
99 if (llgaHelper_.shouldConsiderForMerge(consumer)) {
100 if (!llgaHelper_.isLlgaSubgraph(consumer)) {
101 consumer = llgaHelper_.createSingletonSubgraph(consumer, aliasDb_);
102 }
103 // Iterate through the workblock to merge nodes of the
104 // same partition determined by LLGA graph helper.
105 // Nodes like B and C do not share a common input but belong to a
106 // same partition, and thus we cannot only scan the input nodes
107 // to find merging opportunities. Instead, we have to scan through
108 // the whole workblock, which might lead to O^2 accesses in worst case
109 // A
110 // + - - / - \ - - +
111 // | B C |
112 // | | | |
113 // | D E |
114 // + - - \ - / - - +
115 // F
116 auto prev = ++consumer->reverseIterator();
117 for (auto it = prev; it != workblock_begin; it++) {
118 if (auto group = tryMerge(consumer, *it)) {
119 // we successfully merged, so the new group's `inputs` may have
120 // changed. So rescan the new group for more merging opportunities.
121 return std::make_pair(group.value()->reverseIterator(), true);
122 }
123 }
124 }
125 return std::make_pair(++consumer->reverseIterator(), false);
126 }
127
128 // Try to merge `producer` into `consumer`. If successful, this destroys
129 // `producer` and returns the `consumer` group.
tryMerge(Node * consumer,Node * producer)130 std::optional<Node*> GraphRewriter::tryMerge(Node* consumer, Node* producer) {
131 AT_ASSERT(llgaHelper_.isLlgaSubgraph(consumer));
132 bool canMerge = llgaHelper_.shouldMerge(producer, consumer) &&
133 aliasDb_.moveBeforeTopologicallyValid(producer, consumer);
134 if (!canMerge) {
135 return std::nullopt;
136 }
137 llgaHelper_.mergeNodeIntoSubgraph(producer, consumer, aliasDb_);
138 return consumer;
139 }
140
141 } // namespace onednn
142 } // namespace fuser
143 } // namespace jit
144 } // namespace torch
145