xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
#pragma once

#include <ATen/Functions.h>
#include <ATen/Utils.h>
#include <torch/torch.h>
#include <xnnpack.h>

#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <torch/csrc/jit/backends/xnnpack/serialization/serializer.h>
14 
15 namespace torch {
16 namespace jit {
17 namespace xnnpack {
18 namespace delegate {
19 
20 class XNNGraph {
21  private:
22   const float output_min = -std::numeric_limits<float>::infinity();
23   const float output_max = std::numeric_limits<float>::infinity();
24 
25   // serializer class
26   XNNSerializer _serializer;
27   // xnn subgraph
28   xnn_subgraph_t _subgraph_ptr;
29   // Set of all the tensor values throughout the jit graph
30   std::unordered_set<torch::jit::Value*> _intermediate_tensors;
31   // Set of all the tensor values mapped to the xnnpack ids
32   std::unordered_map<torch::jit::Value*, uint32_t> _val_to_ids;
33   // Vector containing the torch valued inputs/outputs,
34   // must be ordered to preserve the order of input/outputs
35   std::vector<torch::jit::Value*> _inputs;
36   std::vector<torch::jit::Value*> _outputs;
37 
38   // Graph passes for optimizing and tracing torchscript graph
39   // Essentially massaging the graph into a digestiable format for
40   // xnnpack graph lowering.
41   std::shared_ptr<torch::jit::Graph> optimizeAndTraceGraph(
42       std::shared_ptr<torch::jit::Graph> graph,
43       std::vector<c10::IValue>& example_inputs);
44 
45   // Gather all the intermediate tensor values within a graph. This
46   // skips through all prim constants. The purpose of this is for defining
47   // the tensor values beforehand for the xnnpack subgraph.
48   void gatherTensorValues(std::shared_ptr<torch::jit::Graph>& graph);
49 
50   // Gathers the tensor values in a give node
51   void gatherNodeInputs(torch::jit::Node& node);
52 
53   // Helper function to determine if a jit value is a graph input
54   bool isGraphInput(torch::jit::Value* val);
55 
56   // Helper function to determine if a jit value is a graph output
57   bool isGraphOutput(torch::jit::Value* val);
58 
59   // Defines all xnnpack nodes for the nodes in the graph
60   void defineAllNodes(std::shared_ptr<torch::jit::Graph>& graph);
61 
62   // Defines all xnn tensor values used throughout the graph
63   void defineAllTensorValues();
64 
65   // Makes a pass through the graph and throws if any ops are unsupported
66   void checkOpsToDelegate(std::shared_ptr<torch::jit::Graph>& graph);
67 
68  public:
XNNGraph()69   XNNGraph() : _serializer(), _subgraph_ptr(nullptr) {
70     xnn_status status = xnn_initialize(/*allocator =*/nullptr);
71     TORCH_CHECK(xnn_status_success == status, "Failed to initialize xnnpack");
72   }
73 
~XNNGraph()74   ~XNNGraph() {
75     xnn_deinitialize();
76     if (_subgraph_ptr != nullptr) {
77       xnn_delete_subgraph(_subgraph_ptr);
78     }
79   }
80 
81   void buildXNNGraph(
82       std::shared_ptr<torch::jit::Graph>& graph,
83       std::vector<c10::IValue> example_inputs);
84 
85   void runGraphOnInputs(
86       std::vector<at::Tensor> tensor_inputs,
87       std::vector<at::Tensor> tensor_outputs);
88 
89   std::string serializedXNNGraph();
90 
91   std::vector<std::vector<long>> getGraphOutputShapes();
92 };
93 
94 } // namespace delegate
95 } // namespace xnnpack
96 } // namespace jit
97 } // namespace torch
98