xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/subgraph_rewrite.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /** This file defines API for pattern-based subgraph rewrites.
2  *
3  * The API can be used for finding concrete patterns in the model and replacing
4  * the corresponding subgraphs with another subgraph. A special case of such
5  * rewrites is fusion, where the new subgraph consists of just a single node.
6  *
7  * There is a default set of the most common patterns that everyone could use.
8  * Alternatively, an arbitrary pattern can be registered.
9  */
10 #pragma once
11 
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
18 
19 namespace torch::jit {
20 
21 // Forward declarations.
22 struct RewritePatternDescr;
23 struct Match;
24 
25 using MatchFilter = std::function<
26     bool(const Match&, const std::unordered_map<std::string, Value*>&)>;
27 
28 /** Run pattern-based subgraph rewrites on all methods in the module.
29  *
30  * This pass will go through all methods in the module and try to replace all
31  * recognized patterns (see SubgraphRewriter::RegisterDefaultPatterns for the
32  * list of these patterns).
33  */
34 TORCH_API Module PatternBasedRewrite(const Module& module);
35 
36 /** A class implementing API for pattern-based subgraph rewrites.
37  *
38  * To perform pattern-based subgraph rewrites on a module using this API, one
39  * needs to create an object of such class, register rewrite patterns and run
40  * the transformation pass (`runOnModule`).
41  *
42  * To use standard patterns, one could use `RegisterDefaultPatterns`.
43  *
44  * To enable rewrites of custom patterns, the custom patterns must be registered
45  * with `RegisterRewritePattern`.
46  */
47 class TORCH_API SubgraphRewriter {
48  public:
49   // Run pattern-based subgraph rewrite pass on the module.
50   Module runOnModule(const Module& module);
51 
52   // Run pattern-based subgraph rewrite pass on the graph (used in testing).
53   // `filter` is a function that does extra filtering on the match. If it
54   // returns false for a given Match, we'll skip the Match. The filter
55   // function's arguments consist of a Match and a value map from parsing the
56   // pattern graph. Both the Match and the value map are necessary because we
57   // need to 1) do extra filtering on the matched result as well as 2) refer to
58   // the values in the matched result through the values in the pattern graph.
59   void runOnGraph(
60       std::shared_ptr<Graph>& graph,
61       const std::vector<MatchFilter>& filters);
62 
63   void runOnGraph(
64       std::shared_ptr<Graph>& graph,
65       const MatchFilter& filter =
66           [](const Match&, const std::unordered_map<std::string, Value*>&) {
67             return true;
68           }) {
69     runOnGraph(graph, std::vector<MatchFilter>({filter}));
70   }
71 
72   // Register standard rewrite patterns.
73   void RegisterDefaultPatterns();
74 
75   /** Register a custom rewrite pattern.
76    *
77    * The method takes two parameters specifying the pattern:
78    * \p PATTERN - IR string representing the pattern subgraph.
79    * \p REPLACEMENT - IR string representing the replacement subgraph.
80    * \p value name map - vector of pairs mapping values in the replacement graph
81    * to the values in the pattern graph. Used for preserving source range info
82    * across graph rewrite.
83    *
84    * See examples of pattern registering in `RegisterDefaultPatterns`.
85    */
86   void RegisterRewritePattern(
87       const std::string& pattern,
88       const std::string& replacement,
89       const std::vector<std::pair<std::string, std::string>>& value_name_pair =
90           {});
91 
92  private:
93   std::vector<RewritePatternDescr> patterns_;
94   std::unordered_set<Node*> nodes_to_delete_;
95 
96   void rewriteSinglePatternOnGraph(
97       std::shared_ptr<Graph>& graph,
98       const RewritePatternDescr& pattern,
99       const std::vector<MatchFilter>& filters);
100 
101   bool overlapsWithPreviousMatches(const Match* match);
102 };
103 
/** Descriptor for one registered pattern/replacement pair.
 *
 * Implementation detail of `SubgraphRewriter`; not meant to be used from
 * outside that implementation.
 */
struct RewritePatternDescr {
  std::string pattern{};      // IR text of the subgraph to look for.
  std::string replacement{};  // IR text of the subgraph to substitute.
  // Value names in the replacement graph -> value names in the pattern graph.
  // NOTE(review): presumably built from RegisterRewritePattern's
  // `value_name_pair` argument; confirm in the implementation.
  std::unordered_map<std::string, std::string> value_name_map{};
};
114 
115 } // namespace torch::jit
116