1 #pragma once 2 3 #include <oneapi/dnnl/dnnl_graph.hpp> 4 #include <torch/csrc/jit/codegen/onednn/operator.h> 5 #include <torch/csrc/jit/ir/alias_analysis.h> 6 #include <torch/csrc/jit/ir/ir.h> 7 8 namespace torch { 9 namespace jit { 10 namespace fuser { 11 namespace onednn { 12 13 #define STRIDED_LAYOUT 0 14 #define OPAQUE_LAYOUT 1 15 16 struct OpPartitionMap { addOpPartitionMap17 void add(uint64_t opId, uint64_t partitionId) { 18 opmap_[opId] = partitionId; 19 } addOpPartitionMap20 void add(Node* n, uint64_t partitionId) { 21 add(Operator::getId(n), partitionId); 22 } hasOpPartitionMap23 bool has(uint64_t opId) { 24 return opmap_.count(opId) > 0; 25 } hasOpPartitionMap26 bool has(Node* n) { 27 return has(Operator::getId(n)); 28 } getOpPartitionMap29 uint64_t get(uint64_t opId) { 30 return opmap_[opId]; 31 } getOpPartitionMap32 uint64_t get(Node* n) { 33 auto opId = Operator::getId(n); 34 TORCH_CHECK( 35 has(opId), 36 "Node ", 37 n->kind().toQualString(), 38 " does not belong to any LLGA partition"); 39 return get(opId); 40 } 41 42 private: 43 std::unordered_map<uint64_t, uint64_t> opmap_; 44 }; 45 46 class LlgaGraphHelper { 47 public: 48 LlgaGraphHelper( 49 const std::shared_ptr<Graph>& graph, 50 dnnl::graph::partition::policy policy = 51 dnnl::graph::partition::policy::fusion); 52 53 bool shouldMerge(Node* toMerge, Node* subgraph); 54 55 bool shouldConsiderForMerge(Node* node); 56 57 bool checkForSingleOpPartition(Node* node); 58 59 Node* createSingletonSubgraph(Node* n, AliasDb& db); 60 61 void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode, AliasDb& db); 62 63 void unmergeIfAnyNodeIsMissing(Node* subgraphNode); 64 65 static bool isLlgaSubgraph(const Node* node); 66 67 Operator makeEltwiseOp(Node* node, dnnl::graph::op::kind kind); 68 69 Operator makeBinaryOp(Node* node, dnnl::graph::op::kind kind); 70 71 std::vector<dnnl::graph::partition> getPartitions() const; 72 73 std::map<size_t, Value*> getTensorIdToValue() const; 74 75 Operator createOperator(Node* node); 76 77 private: 78 size_t countSupportedOps(const std::shared_ptr<Graph>& graph) const; 79 std::unique_ptr<dnnl::graph::graph> dnnl_graph_ = nullptr; 80 std::unique_ptr<torch::jit::AliasDb> aliasDb_ = nullptr; 81 OpPartitionMap opToOwningPartition_; 82 std::vector<dnnl::graph::partition> partitions_; 83 std::map<size_t, Value*> 84 tensorIdToValue_; // map from tensorId to torch::jit::Value 85 }; 86 87 class LlgaNodeWrapper { 88 public: 89 LlgaNodeWrapper(const Node* node); 90 91 void setOpaqueLayout(size_t offset); 92 93 bool useOpaqueLayout(size_t offset) const; 94 95 friend class LlgaGraphHelper; 96 97 private: 98 Node* n; 99 }; 100 101 } // namespace onednn 102 } // namespace fuser 103 } // namespace jit 104 } // namespace torch 105