1 #pragma once 2 3 #include <torch/csrc/jit/codegen/onednn/graph_helper.h> 4 #include <torch/csrc/jit/ir/ir.h> 5 6 namespace torch { 7 namespace jit { 8 namespace fuser { 9 namespace onednn { 10 11 struct WorkBlock : public std::pair<Node*, Node*> { 12 using pair::pair; 13 beginWorkBlock14 Node* begin() { 15 return this->first; 16 } endWorkBlock17 Node* end() { 18 return this->second; 19 } 20 }; 21 22 class GraphRewriter { 23 public: GraphRewriter(Block * block,std::shared_ptr<Graph> graph,AliasDb & aliasDb)24 GraphRewriter(Block* block, std::shared_ptr<Graph> graph, AliasDb& aliasDb) 25 : block_(block), 26 graph_(std::move(graph)), 27 aliasDb_(aliasDb), 28 llgaHelper_(graph_) {} 29 30 void cleanupSubgraphs(); 31 void buildupSubgraphs(); 32 33 private: 34 Block* block_; 35 std::shared_ptr<Graph> graph_; 36 AliasDb& aliasDb_; 37 LlgaGraphHelper llgaHelper_; 38 std::vector<WorkBlock> buildWorkBlocks(); 39 std::pair<graph_node_list::iterator, bool> scanNode( 40 Node* consumer, 41 graph_node_list::iterator workblock_begin); 42 std::optional<Node*> tryMerge(Node* consumer, Node* producer); 43 }; 44 45 // This pass creates the subgraphs for oneDNN Graph Fusion Nodes. 46 // Its code-structure has been vastly inspired from 47 // torch/csrc/jit/passes/create_autodiff_subgraphs.cpp 48 void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph); 49 50 } // namespace onednn 51 } // namespace fuser 52 } // namespace jit 53 } // namespace torch 54