xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <caffe2/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h>
7 #include <torch/csrc/jit/runtime/graph_iterator.h>
8 #include <xnnpack.h>
9 
10 // graph passes
11 #include <torch/csrc/jit/passes/constant_propagation.h>
12 #include <torch/csrc/jit/passes/dead_code_elimination.h>
13 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
14 #include <torch/csrc/jit/passes/lower_tuples.h>
15 #include <torch/csrc/jit/passes/remove_mutation.h>
16 #include <torch/csrc/jit/runtime/jit_trace.h>
17 #include <torch/csrc/jit/tensorexpr/graph_opt.h>
18 
19 namespace torch {
20 namespace jit {
21 namespace xnnpack {
22 namespace delegate {
23 
// Runs the canonical lowering pipeline over `graph` and then traces it with
// `example_inputs`, returning the traced (flattened, mutation-free) graph.
// NOTE: the pass order below is deliberate — mutation removal and tuple
// lowering must precede constant propagation and tracing; do not reorder.
//
// @param graph          frozen module graph to lower (modified in place by
//                       the passes before tracing).
// @param example_inputs concrete inputs used by TraceGraph to specialize the
//                       graph.
// @return the freshly traced graph (a new Graph object, not `graph`).
std::shared_ptr<torch::jit::Graph> XNNGraph::optimizeAndTraceGraph(
    std::shared_ptr<torch::jit::Graph> graph,
    std::vector<c10::IValue>& example_inputs) {
  OptimizeFrozenGraph(graph, true);
  RemoveListMutation(graph);
  RemoveTensorMutation(graph);
  LowerAllTuples(graph);
  ConstantPropagation(graph);
  // Tracing specializes the graph to the example inputs and returns a new
  // graph; reassign so callers see the traced version.
  graph = TraceGraph(graph, example_inputs);

  return graph;
}
36 
// End-to-end construction of the XNNPACK subgraph from a TorchScript graph:
// optimize/trace, validate delegatable ops, collect tensor values, create the
// xnn subgraph, then define all tensor values and nodes in it. Populates
// member state (_inputs, _outputs, _intermediate_tensors, _val_to_ids,
// _subgraph_ptr); step order matters.
//
// @param graph          graph to delegate; replaced in place by its traced
//                       version.
// @param example_inputs inputs used for tracing.
void XNNGraph::buildXNNGraph(
    std::shared_ptr<torch::jit::Graph>& graph,
    std::vector<c10::IValue> example_inputs) {
  graph = optimizeAndTraceGraph(graph, example_inputs);
  // Throws if the graph contains any op we cannot delegate.
  checkOpsToDelegate(graph);
  // Fills _inputs, _outputs and _intermediate_tensors from the traced graph.
  gatherTensorValues(graph);

  // count unique input/outputs (some inputs can be outputs)
  std::unordered_set<torch::jit::Value*> externals;
  for (auto inp : _inputs) {
    externals.insert(inp);
  }
  for (auto out : _outputs) {
    externals.insert(out);
  }

  // create subgraph — external_value_ids must equal the number of distinct
  // external (input/output) values we will define below.
  xnn_status status = xnn_create_subgraph(
      /*external_value_ids=*/externals.size(),
      /*flags=*/0,
      &_subgraph_ptr);
  TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph");

  // Tensor values must be defined before the nodes that reference them.
  defineAllTensorValues();
  defineAllNodes(graph);
  // at this point graph is complete, for the sake of testing preprocess at
  // this point we will do runtime setup and run with some default values
}
65 
runGraphOnInputs(std::vector<at::Tensor> tensor_inputs,std::vector<at::Tensor> tensor_outputs)66 void XNNGraph::runGraphOnInputs(
67     std::vector<at::Tensor> tensor_inputs,
68     std::vector<at::Tensor> tensor_outputs) {
69   TORCH_CHECK(
70       _subgraph_ptr != nullptr,
71       "run buildXNNGraph before running graph on inputs");
72   xnn_runtime_t runtime = nullptr;
73   xnn_status status =
74       xnn_create_runtime_v2(_subgraph_ptr, nullptr, /*flags=*/0, &runtime);
75   TORCH_CHECK(
76       xnn_status_success == status,
77       "failed to create runtime for running inputs");
78 
79   // smart pointer for runtime
80   std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(
81       runtime, xnn_delete_runtime);
82 
83   std::vector<xnn_external_value> external_values;
84   TORCH_CHECK(
85       tensor_inputs.size() == _inputs.size(),
86       "supplied inputs does not match expected inputs");
87   for (int i = 0; i < tensor_inputs.size(); i++) {
88     external_values.push_back(
89         {_val_to_ids[_inputs[i]], tensor_inputs[i].data_ptr<float>()});
90   }
91 
92   TORCH_CHECK(
93       tensor_outputs.size() == _outputs.size(),
94       "supplied outputs does not match expected outputs");
95   for (int i = 0; i < tensor_outputs.size(); i++) {
96     external_values.push_back(
97         {_val_to_ids[_outputs[i]], tensor_outputs[i].data_ptr<float>()});
98   }
99   status = xnn_setup_runtime(
100       auto_runtime.get(), external_values.size(), external_values.data());
101   TORCH_CHECK(xnn_status_success == status, "runtime not properly setup");
102 
103   TORCH_CHECK(xnn_status_success == xnn_invoke_runtime(auto_runtime.get()));
104 }
105 
checkOpsToDelegate(std::shared_ptr<torch::jit::Graph> & graph)106 void XNNGraph::checkOpsToDelegate(std::shared_ptr<torch::jit::Graph>& graph) {
107   std::unordered_set<string> unsupported_ops;
108   DepthFirstGraphNodeIterator it(graph);
109   Node* node = nullptr;
110   while ((node = it.next()) != nullptr) {
111     switch (node->kind()) {
112       case prim::Constant:
113       case aten::add: {
114         break;
115       }
116       default: {
117         unsupported_ops.insert(node->kind().toDisplayString());
118       }
119     }
120   }
121   std::stringstream error;
122   for (auto itr = unsupported_ops.begin(); itr != unsupported_ops.end();
123        itr++) {
124     error << *itr << std::endl;
125     ;
126   }
127   TORCH_CHECK(
128       unsupported_ops.empty(),
129       "the module contains the following unsupported ops:\n" + error.str());
130 }
131 
serializedXNNGraph()132 std::string XNNGraph::serializedXNNGraph() {
133   std::vector<uint32_t> input_ids;
134   std::vector<uint32_t> output_ids;
135   std::unordered_set<uint32_t> num_externs;
136 
137   for (auto val : _inputs) {
138     input_ids.push_back(_val_to_ids[val]);
139     num_externs.emplace(_val_to_ids[val]);
140   }
141 
142   for (auto val : _outputs) {
143     output_ids.push_back(_val_to_ids[val]);
144     num_externs.emplace(_val_to_ids[val]);
145   }
146 
147   return _serializer.finishAndSerialize(
148       input_ids, output_ids, num_externs.size());
149 }
150 
getGraphOutputShapes()151 std::vector<std::vector<long>> XNNGraph::getGraphOutputShapes() {
152   std::vector<std::vector<long>> output_shapes;
153   for (auto val : _outputs) {
154     auto tensor_ptr = val->type()->cast<TensorType>();
155     std::vector<long> sizes = tensor_ptr->sizes().concrete_sizes().value();
156     output_shapes.push_back(sizes);
157   }
158 
159   return output_shapes;
160 }
161 
defineAllNodes(std::shared_ptr<torch::jit::Graph> & graph)162 void XNNGraph::defineAllNodes(std::shared_ptr<torch::jit::Graph>& graph) {
163   DepthFirstGraphNodeIterator it(graph);
164   Node* node = nullptr;
165   while ((node = it.next()) != nullptr) {
166     switch (node->kind()) {
167       case prim::Constant: {
168         break;
169       }
170       case aten::add: {
171         // todo: handle alpha for aten::add
172         uint32_t input1_id = _val_to_ids[node->inputs()[0]];
173         uint32_t input2_id = _val_to_ids[node->inputs()[1]];
174         TORCH_CHECK(
175             node->inputs()[2]->type()->cast<IntType>() == 1,
176             "non-1 alpha values not supported");
177         uint32_t output_id = _val_to_ids[node->outputs()[0]];
178 
179         xnn_status status = xnn_define_add2(
180             _subgraph_ptr,
181             output_min,
182             output_max,
183             input1_id,
184             input2_id,
185             output_id,
186             /*flags=*/0);
187         _serializer.serializeAddNode(input1_id, input2_id, output_id, 0);
188         TORCH_CHECK(status == xnn_status_success, "failed to create add node");
189         break;
190       }
191       default: {
192         throw std::exception();
193         TORCH_CHECK(
194             false,
195             "The node of ",
196             node->kind().toQualString(),
197             " is not supported yet");
198         break;
199       }
200     }
201   }
202 }
203 
defineAllTensorValues()204 void XNNGraph::defineAllTensorValues() {
205   uint32_t external_id =
206       std::numeric_limits<decltype(XNN_INVALID_VALUE_ID)>::min();
207   for (auto val : _intermediate_tensors) {
208     if (_val_to_ids.find(val) == _val_to_ids.end()) {
209       uint32_t id = XNN_INVALID_VALUE_ID;
210 
211       // cast value to tensortype
212       auto tensor_ptr = val->type()->cast<TensorType>();
213       auto num_dims = tensor_ptr->dim().value();
214 
215       // create size_t* for tensor shape, casting must be done from long ->
216       // size_t
217       std::vector<long> sizes = tensor_ptr->sizes().concrete_sizes().value();
218       std::vector<size_t> tensor_shape;
219       tensor_shape.reserve(sizes.size());
220       for (auto dim : sizes) {
221         TORCH_CHECK(dim >= 0, "Input Dims should be unsigned");
222         tensor_shape.push_back(static_cast<size_t>(dim));
223       }
224 
225       // ext_id value
226       uint32_t ext_id = XNN_INVALID_VALUE_ID;
227 
228       // update flag for if tensor is either graph input/output
229       uint32_t flags = 0;
230 
231       // Check if value was produced by prim::Constant
232       void* value_data = nullptr;
233       size_t buffer_idx = 0;
234       size_t num_bytes = 0;
235       if (val->node()->kind() == prim::Constant) {
236         std::optional<IValue> constant = val->node()->t(attr::value);
237         auto const_val = constant->toIValue().toTensor();
238         // Need tensor data to be contiguous for serialization
239         auto cont_const_val = const_val.contiguous();
240         value_data = cont_const_val.data_ptr();
241 
242         num_bytes = const_val.storage().nbytes();
243         buffer_idx = _serializer.serializeData(
244             static_cast<const uint8_t*>(value_data), num_bytes);
245       }
246 
247       if (isGraphInput(val) || isGraphOutput(val)) {
248         if (isGraphInput(val)) {
249           flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
250         }
251         if (isGraphOutput(val)) {
252           flags |= XNN_VALUE_FLAG_EXTERNAL_OUTPUT;
253         }
254         ext_id = external_id++;
255       }
256       xnn_status status = xnn_define_tensor_value(
257           /*subgraph=*/_subgraph_ptr,
258           /*datatype=*/xnn_datatype_fp32,
259           /*num_dims=*/num_dims,
260           /*dims=*/tensor_shape.data(),
261           /*data=*/value_data,
262           /*external_id=*/ext_id,
263           /*flags=*/flags,
264           /*id_out=*/&id);
265       TORCH_CHECK(
266           status == xnn_status_success,
267           "failed to define xnn_tensor_id for: " + val->debugName());
268       _serializer.serializeTensorValue(
269           xnn_datatype_fp32,
270           num_dims,
271           tensor_shape,
272           buffer_idx,
273           ext_id,
274           flags,
275           id);
276       _val_to_ids.insert({val, id});
277     }
278   }
279 }
280 
gatherTensorValues(std::shared_ptr<torch::jit::Graph> & graph)281 void XNNGraph::gatherTensorValues(std::shared_ptr<torch::jit::Graph>& graph) {
282   for (auto input : graph->inputs()) {
283     if (input->isCompleteTensor()) {
284       _intermediate_tensors.insert(input);
285       _inputs.push_back(input);
286     }
287   }
288 
289   DepthFirstGraphNodeIterator it(graph);
290   Node* n = nullptr;
291   while ((n = it.next()) != nullptr) {
292     gatherNodeInputs(*n);
293   }
294 
295   for (auto output : graph->outputs()) {
296     if (output->isCompleteTensor()) {
297       _intermediate_tensors.insert(output);
298       _outputs.push_back(output);
299     }
300   }
301 }
302 
gatherNodeInputs(torch::jit::Node & node)303 void XNNGraph::gatherNodeInputs(torch::jit::Node& node) {
304   switch (node.kind()) {
305     case aten::add: {
306       // this case will support all ops with only two inputs i.e. sub, add,
307       for (auto value : node.inputs()) {
308         if (value->isCompleteTensor()) {
309           _intermediate_tensors.insert(value);
310         }
311       }
312     }
313   }
314 }
315 
isGraphInput(torch::jit::Value * val)316 bool XNNGraph::isGraphInput(torch::jit::Value* val) {
317   return std::count(_inputs.begin(), _inputs.end(), val) > 0;
318 };
319 
isGraphOutput(torch::jit::Value * val)320 bool XNNGraph::isGraphOutput(torch::jit::Value* val) {
321   return std::count(_outputs.begin(), _outputs.end(), val) > 0;
322 };
323 
324 } // namespace delegate
325 } // namespace xnnpack
326 } // namespace jit
327 } // namespace torch
328