1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 #include <torch/csrc/jit/ir/irparser.h> 5 #include <torch/csrc/jit/ir/subgraph_matcher.h> 6 #include <torch/csrc/jit/passes/subgraph_rewrite.h> 7 8 namespace torch::jit::graph_rewrite_helper { 9 10 std::string getFuncName(Value* func_value); 11 Value* getValue( 12 const std::string& name, 13 const std::unordered_map<const Value*, Value*>& match_vmap, 14 const std::unordered_map<std::string, Value*>& vmap); 15 std::optional<IValue> getIValue( 16 const std::string& name, 17 const std::unordered_map<const Value*, Value*>& match_vmap, 18 const std::unordered_map<std::string, Value*>& vmap); 19 TORCH_API void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph); 20 21 bool isClampFusable( 22 const Match& match, 23 const std::unordered_map<std::string, Value*>& vmap); 24 25 // This struct contains a compiled IR patterns slated for use in the 26 // findPatternMatches function. The struct encapsulates the common 27 // information from parseIR that is used in conjunction with the 28 // pattern matching facility. A const instance of this struct can 29 // also be stored away to cache the compiled IR pattern and reduce 30 // runtime cost 31 struct PatternInfo { 32 std::string pattern_string; 33 std::unique_ptr<Graph> pattern_graph; 34 std::unordered_map<std::string, Value*> vmap; 35 std::vector<MatchFilter> filters; 36 37 static PatternInfo parse_from_str( 38 std::string pattern_string, 39 const std::vector<MatchFilter>& filters = {}) { 40 PatternInfo rv{ movePatternInfo41 std::move(pattern_string), 42 std::make_unique<Graph>(), 43 decltype(vmap){}, 44 filters}; 45 parseIR(rv.pattern_string, rv.pattern_graph.get(), rv.vmap); 46 return rv; 47 } 48 }; 49 50 } // namespace torch::jit::graph_rewrite_helper 51