1 // Copyright (c) Meta Platforms, Inc. and affiliates. 2 // 3 // This source code is licensed under the BSD-style license found in the 4 // LICENSE file in the root directory of this source tree. 5 6 #include <ATen/Functions.h> 7 #include <ATen/Utils.h> 8 #include <torch/torch.h> 9 #include <xnnpack.h> 10 #include <unordered_set> 11 #include <vector> 12 13 #include <torch/csrc/jit/backends/xnnpack/serialization/serializer.h> 14 15 namespace torch { 16 namespace jit { 17 namespace xnnpack { 18 namespace delegate { 19 20 class XNNGraph { 21 private: 22 const float output_min = -std::numeric_limits<float>::infinity(); 23 const float output_max = std::numeric_limits<float>::infinity(); 24 25 // serializer class 26 XNNSerializer _serializer; 27 // xnn subgraph 28 xnn_subgraph_t _subgraph_ptr; 29 // Set of all the tensor values throughout the jit graph 30 std::unordered_set<torch::jit::Value*> _intermediate_tensors; 31 // Set of all the tensor values mapped to the xnnpack ids 32 std::unordered_map<torch::jit::Value*, uint32_t> _val_to_ids; 33 // Vector containing the torch valued inputs/outputs, 34 // must be ordered to preserve the order of input/outputs 35 std::vector<torch::jit::Value*> _inputs; 36 std::vector<torch::jit::Value*> _outputs; 37 38 // Graph passes for optimizing and tracing torchscript graph 39 // Essentially massaging the graph into a digestiable format for 40 // xnnpack graph lowering. 41 std::shared_ptr<torch::jit::Graph> optimizeAndTraceGraph( 42 std::shared_ptr<torch::jit::Graph> graph, 43 std::vector<c10::IValue>& example_inputs); 44 45 // Gather all the intermediate tensor values within a graph. This 46 // skips through all prim constants. The purpose of this is for defining 47 // the tensor values beforehand for the xnnpack subgraph. 48 void gatherTensorValues(std::shared_ptr<torch::jit::Graph>& graph); 49 50 // Gathers the tensor values in a give node 51 void gatherNodeInputs(torch::jit::Node& node); 52 53 // Helper function to determine if a jit value is a graph input 54 bool isGraphInput(torch::jit::Value* val); 55 56 // Helper function to determine if a jit value is a graph output 57 bool isGraphOutput(torch::jit::Value* val); 58 59 // Defines all xnnpack nodes for the nodes in the graph 60 void defineAllNodes(std::shared_ptr<torch::jit::Graph>& graph); 61 62 // Defines all xnn tensor values used throughout the graph 63 void defineAllTensorValues(); 64 65 // Makes a pass through the graph and throws if any ops are unsupported 66 void checkOpsToDelegate(std::shared_ptr<torch::jit::Graph>& graph); 67 68 public: XNNGraph()69 XNNGraph() : _serializer(), _subgraph_ptr(nullptr) { 70 xnn_status status = xnn_initialize(/*allocator =*/nullptr); 71 TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack"); 72 } 73 ~XNNGraph()74 ~XNNGraph() { 75 xnn_deinitialize(); 76 if (_subgraph_ptr != nullptr) { 77 xnn_delete_subgraph(_subgraph_ptr); 78 } 79 } 80 81 void buildXNNGraph( 82 std::shared_ptr<torch::jit::Graph>& graph, 83 std::vector<c10::IValue> example_inputs); 84 85 void runGraphOnInputs( 86 std::vector<at::Tensor> tensor_inputs, 87 std::vector<at::Tensor> tensor_outputs); 88 89 std::string serializedXNNGraph(); 90 91 std::vector<std::vector<long>> getGraphOutputShapes(); 92 }; 93 94 } // namespace delegate 95 } // namespace xnnpack 96 } // namespace jit 97 } // namespace torch 98