xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/graph_rewriter.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
2 #include <torch/csrc/jit/ir/alias_analysis.h>
3 #include <torch/csrc/jit/jit_log.h>
4 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
5 #include <torch/csrc/jit/passes/dead_code_elimination.h>
6 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
7 
8 namespace torch {
9 namespace jit {
10 namespace fuser {
11 namespace onednn {
12 
cleanupSubgraphs()13 void GraphRewriter::cleanupSubgraphs() {
14   auto curNode = *block_->nodes().rbegin();
15   while (curNode != *block_->nodes().rend()) {
16     // Save the previous node, since we might delete `curNode` in next block
17     auto prevNode = curNode->prev();
18     if (llgaHelper_.isLlgaSubgraph(curNode)) {
19       // Unmerge subgraph if we don't get every nodes of a partition
20       // into the subgraph due to failed alias check
21       llgaHelper_.unmergeIfAnyNodeIsMissing(curNode);
22     }
23     curNode = prevNode;
24   }
25   for (Node* n : block_->nodes()) {
26     for (Block* b : n->blocks()) {
27       GraphRewriter(b, graph_, aliasDb_).cleanupSubgraphs();
28     }
29   }
30 }
31 
buildupSubgraphs()32 void GraphRewriter::buildupSubgraphs() {
33   // We need to run the rewriter multiple times in order to get all merge
34   // opportunities. This is because moveBeforeTopologicalValid may reorder
35   // nodes to be AFTER the current iteration point. In order to properly
36   // consider those nodes for merging, we need run the pass until no changes
37   // have been made.
38   //
39   // Example:
40   //   c = f(a, b)
41   //   d = f(c)
42   //   e = f(d)  <- iter is here, moving upward
43   // After c.moveBeforeTopologicallyValid(e), we have:
44   //   c = f(a, b)
45   //   e = f(d)  <- iter still here
46   //   d = f(c)  <- this was node moved on the other side.
47   // see [workblocks]
48   auto workblocks = buildWorkBlocks();
49   for (auto& workblock : workblocks) {
50     bool any_changed = true;
51     while (any_changed) {
52       any_changed = false;
53       auto workblock_end = workblock.end()->reverseIterator();
54       auto workblock_begin = workblock.begin()->reverseIterator();
55       for (auto it = workblock_end; it != workblock_begin;) {
56         bool changed = false;
57         std::tie(it, changed) = scanNode(*it, workblock_begin);
58         any_changed |= changed;
59       }
60     }
61   }
62 
63   // Construct Subgraphs Recursively
64   for (Node* n : block_->nodes()) {
65     for (auto subBlock : n->blocks()) {
66       GraphRewriter(subBlock, graph_, aliasDb_).buildupSubgraphs();
67     }
68   }
69 }
70 
buildWorkBlocks()71 std::vector<WorkBlock> GraphRewriter::buildWorkBlocks() {
72   // [workblocks]
73   // the IR has many nodes which can never be reordered around, such as a
74   // prim::Bailout. if a node N is surrounded by two nodes which cannot be
75   // reordered, A and B, then a fusion group that is created from N
76   // can only contain nodes from (A, B) The nodes from A to B represent one
77   // work block for the subgraph rewriter to work on. By creating these up
78   // front, we avoid retraversing the whole graph block any time scanNode
79   // returns
80   Node* end_bound_node = block_->return_node();
81   Node* curr = end_bound_node->prev();
82   std::vector<WorkBlock> worklist;
83   while (curr != block_->param_node()) {
84     // cannot reorder around side effectful nodes
85     if (curr->hasSideEffects()) {
86       worklist.emplace_back(curr, end_bound_node);
87       end_bound_node = curr;
88     }
89     curr = curr->prev();
90   }
91   worklist.emplace_back(curr, end_bound_node);
92   return worklist;
93 }
94 
scanNode(Node * consumer,graph_node_list::iterator workblock_begin)95 std::pair<graph_node_list::iterator, bool> GraphRewriter::scanNode(
96     Node* consumer,
97     graph_node_list::iterator workblock_begin) {
98   GRAPH_DEBUG("Scanning ", consumer->kind().toQualString());
99   if (llgaHelper_.shouldConsiderForMerge(consumer)) {
100     if (!llgaHelper_.isLlgaSubgraph(consumer)) {
101       consumer = llgaHelper_.createSingletonSubgraph(consumer, aliasDb_);
102     }
103     // Iterate through the workblock to merge nodes of the
104     // same partition determined by LLGA graph helper.
105     // Nodes like B and C do not share a common input but belong to a
106     // same partition, and thus we cannot only scan the input nodes
107     // to find merging opportunities. Instead, we have to scan through
108     // the whole workblock, which might lead to O^2 accesses in worst case
109     //              A
110     //      + - - / - \ - - +
111     //      |    B     C    |
112     //      |    |     |    |
113     //      |    D     E    |
114     //      + - - \ - / - - +
115     //              F
116     auto prev = ++consumer->reverseIterator();
117     for (auto it = prev; it != workblock_begin; it++) {
118       if (auto group = tryMerge(consumer, *it)) {
119         // we successfully merged, so the new group's `inputs` may have
120         // changed. So rescan the new group for more merging opportunities.
121         return std::make_pair(group.value()->reverseIterator(), true);
122       }
123     }
124   }
125   return std::make_pair(++consumer->reverseIterator(), false);
126 }
127 
128 // Try to merge `producer` into `consumer`. If successful, this destroys
129 // `producer` and returns the `consumer` group.
tryMerge(Node * consumer,Node * producer)130 std::optional<Node*> GraphRewriter::tryMerge(Node* consumer, Node* producer) {
131   AT_ASSERT(llgaHelper_.isLlgaSubgraph(consumer));
132   bool canMerge = llgaHelper_.shouldMerge(producer, consumer) &&
133       aliasDb_.moveBeforeTopologicallyValid(producer, consumer);
134   if (!canMerge) {
135     return std::nullopt;
136   }
137   llgaHelper_.mergeNodeIntoSubgraph(producer, consumer, aliasDb_);
138   return consumer;
139 }
140 
141 } // namespace onednn
142 } // namespace fuser
143 } // namespace jit
144 } // namespace torch
145