xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/graph_helper.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <oneapi/dnnl/dnnl_graph.hpp>
4 #include <torch/csrc/jit/codegen/onednn/operator.h>
5 #include <torch/csrc/jit/ir/alias_analysis.h>
6 #include <torch/csrc/jit/ir/ir.h>
7 
8 namespace torch {
9 namespace jit {
10 namespace fuser {
11 namespace onednn {
12 
// Integer layout tags: STRIDED_LAYOUT == 0, OPAQUE_LAYOUT == 1.
// NOTE(review): these are object-like macros, so they are not scoped by the
// surrounding namespaces and are visible to every file that includes this
// header. Presumably they encode the layout that
// LlgaNodeWrapper::setOpaqueLayout / useOpaqueLayout record per offset —
// confirm against graph_helper.cpp before relying on that.
#define STRIDED_LAYOUT 0
#define OPAQUE_LAYOUT 1
15 
16 struct OpPartitionMap {
addOpPartitionMap17   void add(uint64_t opId, uint64_t partitionId) {
18     opmap_[opId] = partitionId;
19   }
addOpPartitionMap20   void add(Node* n, uint64_t partitionId) {
21     add(Operator::getId(n), partitionId);
22   }
hasOpPartitionMap23   bool has(uint64_t opId) {
24     return opmap_.count(opId) > 0;
25   }
hasOpPartitionMap26   bool has(Node* n) {
27     return has(Operator::getId(n));
28   }
getOpPartitionMap29   uint64_t get(uint64_t opId) {
30     return opmap_[opId];
31   }
getOpPartitionMap32   uint64_t get(Node* n) {
33     auto opId = Operator::getId(n);
34     TORCH_CHECK(
35         has(opId),
36         "Node ",
37         n->kind().toQualString(),
38         " does not belong to any LLGA partition");
39     return get(opId);
40   }
41 
42  private:
43   std::unordered_map<uint64_t, uint64_t> opmap_;
44 };
45 
46 class LlgaGraphHelper {
47  public:
48   LlgaGraphHelper(
49       const std::shared_ptr<Graph>& graph,
50       dnnl::graph::partition::policy policy =
51           dnnl::graph::partition::policy::fusion);
52 
53   bool shouldMerge(Node* toMerge, Node* subgraph);
54 
55   bool shouldConsiderForMerge(Node* node);
56 
57   bool checkForSingleOpPartition(Node* node);
58 
59   Node* createSingletonSubgraph(Node* n, AliasDb& db);
60 
61   void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode, AliasDb& db);
62 
63   void unmergeIfAnyNodeIsMissing(Node* subgraphNode);
64 
65   static bool isLlgaSubgraph(const Node* node);
66 
67   Operator makeEltwiseOp(Node* node, dnnl::graph::op::kind kind);
68 
69   Operator makeBinaryOp(Node* node, dnnl::graph::op::kind kind);
70 
71   std::vector<dnnl::graph::partition> getPartitions() const;
72 
73   std::map<size_t, Value*> getTensorIdToValue() const;
74 
75   Operator createOperator(Node* node);
76 
77  private:
78   size_t countSupportedOps(const std::shared_ptr<Graph>& graph) const;
79   std::unique_ptr<dnnl::graph::graph> dnnl_graph_ = nullptr;
80   std::unique_ptr<torch::jit::AliasDb> aliasDb_ = nullptr;
81   OpPartitionMap opToOwningPartition_;
82   std::vector<dnnl::graph::partition> partitions_;
83   std::map<size_t, Value*>
84       tensorIdToValue_; // map from tensorId to torch::jit::Value
85 };
86 
87 class LlgaNodeWrapper {
88  public:
89   LlgaNodeWrapper(const Node* node);
90 
91   void setOpaqueLayout(size_t offset);
92 
93   bool useOpaqueLayout(size_t offset) const;
94 
95   friend class LlgaGraphHelper;
96 
97  private:
98   Node* n;
99 };
100 
101 } // namespace onednn
102 } // namespace fuser
103 } // namespace jit
104 } // namespace torch
105