xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/graph_fuser.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/codegen/onednn/graph_helper.h>
4 #include <torch/csrc/jit/ir/ir.h>
5 
6 namespace torch {
7 namespace jit {
8 namespace fuser {
9 namespace onednn {
10 
11 struct WorkBlock : public std::pair<Node*, Node*> {
12   using pair::pair;
13 
beginWorkBlock14   Node* begin() {
15     return this->first;
16   }
endWorkBlock17   Node* end() {
18     return this->second;
19   }
20 };
21 
22 class GraphRewriter {
23  public:
GraphRewriter(Block * block,std::shared_ptr<Graph> graph,AliasDb & aliasDb)24   GraphRewriter(Block* block, std::shared_ptr<Graph> graph, AliasDb& aliasDb)
25       : block_(block),
26         graph_(std::move(graph)),
27         aliasDb_(aliasDb),
28         llgaHelper_(graph_) {}
29 
30   void cleanupSubgraphs();
31   void buildupSubgraphs();
32 
33  private:
34   Block* block_;
35   std::shared_ptr<Graph> graph_;
36   AliasDb& aliasDb_;
37   LlgaGraphHelper llgaHelper_;
38   std::vector<WorkBlock> buildWorkBlocks();
39   std::pair<graph_node_list::iterator, bool> scanNode(
40       Node* consumer,
41       graph_node_list::iterator workblock_begin);
42   std::optional<Node*> tryMerge(Node* consumer, Node* producer);
43 };
44 
45 // This pass creates the subgraphs for oneDNN Graph Fusion Nodes.
46 // Its code-structure has been vastly inspired from
47 // torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
48 void CreateLlgaSubgraphs(std::shared_ptr<Graph>& graph);
49 
50 } // namespace onednn
51 } // namespace fuser
52 } // namespace jit
53 } // namespace torch
54