1 /** This file defines API for pattern-based subgraph rewrites. 2 * 3 * The API can be used for finding concrete patterns in the model and replacing 4 * the corresponding subgraphs with another subgraph. A special case of such 5 * rewrites is fusion, where the new subgraph consists of just a single node. 6 * 7 * There is a default set of the most common patterns that everyone could use. 8 * Alternatively, an arbitrary pattern can be registered. 9 */ 10 #pragma once 11 12 #include <torch/csrc/jit/api/module.h> 13 #include <torch/csrc/jit/ir/ir.h> 14 15 #include <functional> 16 #include <unordered_set> 17 #include <vector> 18 19 namespace torch::jit { 20 21 // Forward declarations. 22 struct RewritePatternDescr; 23 struct Match; 24 25 using MatchFilter = std::function< 26 bool(const Match&, const std::unordered_map<std::string, Value*>&)>; 27 28 /** Run pattern-based subgraph rewrites on all methods in the module. 29 * 30 * This pass will go through all methods in the module and try to replace all 31 * recognized patterns (see SubgraphRewriter::RegisterDefaultPatterns for the 32 * list of these patterns). 33 */ 34 TORCH_API Module PatternBasedRewrite(const Module& module); 35 36 /** A class implementing API for pattern-based subgraph rewrites. 37 * 38 * To perform pattern-based subgraph rewrites on a module using this API, one 39 * needs to create an object of such class, register rewrite patterns and run 40 * the transformation pass (`runOnModule`). 41 * 42 * To use standard patterns, one could use `RegisterDefaultPatterns`. 43 * 44 * To enable rewrites of custom patterns, the custom patterns must be registered 45 * with `RegisterRewritePattern`. 46 */ 47 class TORCH_API SubgraphRewriter { 48 public: 49 // Run pattern-based subgraph rewrite pass on the module. 50 Module runOnModule(const Module& module); 51 52 // Run pattern-based subgraph rewrite pass on the graph (used in testing). 53 // `filter` is a function that does extra filtering on the match. If it 54 // returns false for a given Match, we'll skip the Match. The filter 55 // function's arguments consist of a Match and a value map from parsing the 56 // pattern graph. Both the Match and the value map are necessary because we 57 // need to 1) do extra filtering on the matched result as well as 2) refer to 58 // the values in the matched result through the values in the pattern graph. 59 void runOnGraph( 60 std::shared_ptr<Graph>& graph, 61 const std::vector<MatchFilter>& filters); 62 63 void runOnGraph( 64 std::shared_ptr<Graph>& graph, 65 const MatchFilter& filter = 66 [](const Match&, const std::unordered_map<std::string, Value*>&) { 67 return true; 68 }) { 69 runOnGraph(graph, std::vector<MatchFilter>({filter})); 70 } 71 72 // Register standard rewrite patterns. 73 void RegisterDefaultPatterns(); 74 75 /** Register a custom rewrite pattern. 76 * 77 * The method takes two parameters specifying the pattern: 78 * \p PATTERN - IR string representing the pattern subgraph. 79 * \p REPLACEMENT - IR string representing the replacement subgraph. 80 * \p value name map - vector of pairs mapping values in the replacement graph 81 * to the values in the pattern graph. Used for preserving source range info 82 * across graph rewrite. 83 * 84 * See examples of pattern registering in `RegisterDefaultPatterns`. 85 */ 86 void RegisterRewritePattern( 87 const std::string& pattern, 88 const std::string& replacement, 89 const std::vector<std::pair<std::string, std::string>>& value_name_pair = 90 {}); 91 92 private: 93 std::vector<RewritePatternDescr> patterns_; 94 std::unordered_set<Node*> nodes_to_delete_; 95 96 void rewriteSinglePatternOnGraph( 97 std::shared_ptr<Graph>& graph, 98 const RewritePatternDescr& pattern, 99 const std::vector<MatchFilter>& filters); 100 101 bool overlapsWithPreviousMatches(const Match* match); 102 }; 103 104 /** Rewrite pattern descriptor. 105 * 106 * This structure is used in the implementation of `SubgraphRewriter` and 107 * is not supposed to be used externally. 108 */ 109 struct RewritePatternDescr { 110 std::string pattern; 111 std::string replacement; 112 std::unordered_map<std::string, std::string> value_name_map; 113 }; 114 115 } // namespace torch::jit 116