1 // Copyright (c) Meta Platforms, Inc. and affiliates.
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <caffe2/torch/csrc/jit/backends/xnnpack/xnnpack_graph_builder.h>
7 #include <torch/csrc/jit/runtime/graph_iterator.h>
8 #include <xnnpack.h>
9
10 // graph passes
11 #include <torch/csrc/jit/passes/constant_propagation.h>
12 #include <torch/csrc/jit/passes/dead_code_elimination.h>
13 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
14 #include <torch/csrc/jit/passes/lower_tuples.h>
15 #include <torch/csrc/jit/passes/remove_mutation.h>
16 #include <torch/csrc/jit/runtime/jit_trace.h>
17 #include <torch/csrc/jit/tensorexpr/graph_opt.h>
18
19 namespace torch {
20 namespace jit {
21 namespace xnnpack {
22 namespace delegate {
23
optimizeAndTraceGraph(std::shared_ptr<torch::jit::Graph> graph,std::vector<c10::IValue> & example_inputs)24 std::shared_ptr<torch::jit::Graph> XNNGraph::optimizeAndTraceGraph(
25 std::shared_ptr<torch::jit::Graph> graph,
26 std::vector<c10::IValue>& example_inputs) {
27 OptimizeFrozenGraph(graph, true);
28 RemoveListMutation(graph);
29 RemoveTensorMutation(graph);
30 LowerAllTuples(graph);
31 ConstantPropagation(graph);
32 graph = TraceGraph(graph, example_inputs);
33
34 return graph;
35 }
36
buildXNNGraph(std::shared_ptr<torch::jit::Graph> & graph,std::vector<c10::IValue> example_inputs)37 void XNNGraph::buildXNNGraph(
38 std::shared_ptr<torch::jit::Graph>& graph,
39 std::vector<c10::IValue> example_inputs) {
40 graph = optimizeAndTraceGraph(graph, example_inputs);
41 checkOpsToDelegate(graph);
42 gatherTensorValues(graph);
43
44 // count unique input/outputs (some inputs can be outputs)
45 std::unordered_set<torch::jit::Value*> externals;
46 for (auto inp : _inputs) {
47 externals.insert(inp);
48 }
49 for (auto out : _outputs) {
50 externals.insert(out);
51 }
52
53 // create subgraph
54 xnn_status status = xnn_create_subgraph(
55 /*external_value_ids=*/externals.size(),
56 /*flags=*/0,
57 &_subgraph_ptr);
58 TORCH_CHECK(xnn_status_success == status, "Failed to create xnn subgraph");
59
60 defineAllTensorValues();
61 defineAllNodes(graph);
62 // at this point graph is complete, for the sake of testing preprocess at
63 // this point we will do runtime setup and run with some default values
64 }
65
runGraphOnInputs(std::vector<at::Tensor> tensor_inputs,std::vector<at::Tensor> tensor_outputs)66 void XNNGraph::runGraphOnInputs(
67 std::vector<at::Tensor> tensor_inputs,
68 std::vector<at::Tensor> tensor_outputs) {
69 TORCH_CHECK(
70 _subgraph_ptr != nullptr,
71 "run buildXNNGraph before running graph on inputs");
72 xnn_runtime_t runtime = nullptr;
73 xnn_status status =
74 xnn_create_runtime_v2(_subgraph_ptr, nullptr, /*flags=*/0, &runtime);
75 TORCH_CHECK(
76 xnn_status_success == status,
77 "failed to create runtime for running inputs");
78
79 // smart pointer for runtime
80 std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> auto_runtime(
81 runtime, xnn_delete_runtime);
82
83 std::vector<xnn_external_value> external_values;
84 TORCH_CHECK(
85 tensor_inputs.size() == _inputs.size(),
86 "supplied inputs does not match expected inputs");
87 for (int i = 0; i < tensor_inputs.size(); i++) {
88 external_values.push_back(
89 {_val_to_ids[_inputs[i]], tensor_inputs[i].data_ptr<float>()});
90 }
91
92 TORCH_CHECK(
93 tensor_outputs.size() == _outputs.size(),
94 "supplied outputs does not match expected outputs");
95 for (int i = 0; i < tensor_outputs.size(); i++) {
96 external_values.push_back(
97 {_val_to_ids[_outputs[i]], tensor_outputs[i].data_ptr<float>()});
98 }
99 status = xnn_setup_runtime(
100 auto_runtime.get(), external_values.size(), external_values.data());
101 TORCH_CHECK(xnn_status_success == status, "runtime not properly setup");
102
103 TORCH_CHECK(xnn_status_success == xnn_invoke_runtime(auto_runtime.get()));
104 }
105
checkOpsToDelegate(std::shared_ptr<torch::jit::Graph> & graph)106 void XNNGraph::checkOpsToDelegate(std::shared_ptr<torch::jit::Graph>& graph) {
107 std::unordered_set<string> unsupported_ops;
108 DepthFirstGraphNodeIterator it(graph);
109 Node* node = nullptr;
110 while ((node = it.next()) != nullptr) {
111 switch (node->kind()) {
112 case prim::Constant:
113 case aten::add: {
114 break;
115 }
116 default: {
117 unsupported_ops.insert(node->kind().toDisplayString());
118 }
119 }
120 }
121 std::stringstream error;
122 for (auto itr = unsupported_ops.begin(); itr != unsupported_ops.end();
123 itr++) {
124 error << *itr << std::endl;
125 ;
126 }
127 TORCH_CHECK(
128 unsupported_ops.empty(),
129 "the module contains the following unsupported ops:\n" + error.str());
130 }
131
serializedXNNGraph()132 std::string XNNGraph::serializedXNNGraph() {
133 std::vector<uint32_t> input_ids;
134 std::vector<uint32_t> output_ids;
135 std::unordered_set<uint32_t> num_externs;
136
137 for (auto val : _inputs) {
138 input_ids.push_back(_val_to_ids[val]);
139 num_externs.emplace(_val_to_ids[val]);
140 }
141
142 for (auto val : _outputs) {
143 output_ids.push_back(_val_to_ids[val]);
144 num_externs.emplace(_val_to_ids[val]);
145 }
146
147 return _serializer.finishAndSerialize(
148 input_ids, output_ids, num_externs.size());
149 }
150
getGraphOutputShapes()151 std::vector<std::vector<long>> XNNGraph::getGraphOutputShapes() {
152 std::vector<std::vector<long>> output_shapes;
153 for (auto val : _outputs) {
154 auto tensor_ptr = val->type()->cast<TensorType>();
155 std::vector<long> sizes = tensor_ptr->sizes().concrete_sizes().value();
156 output_shapes.push_back(sizes);
157 }
158
159 return output_shapes;
160 }
161
defineAllNodes(std::shared_ptr<torch::jit::Graph> & graph)162 void XNNGraph::defineAllNodes(std::shared_ptr<torch::jit::Graph>& graph) {
163 DepthFirstGraphNodeIterator it(graph);
164 Node* node = nullptr;
165 while ((node = it.next()) != nullptr) {
166 switch (node->kind()) {
167 case prim::Constant: {
168 break;
169 }
170 case aten::add: {
171 // todo: handle alpha for aten::add
172 uint32_t input1_id = _val_to_ids[node->inputs()[0]];
173 uint32_t input2_id = _val_to_ids[node->inputs()[1]];
174 TORCH_CHECK(
175 node->inputs()[2]->type()->cast<IntType>() == 1,
176 "non-1 alpha values not supported");
177 uint32_t output_id = _val_to_ids[node->outputs()[0]];
178
179 xnn_status status = xnn_define_add2(
180 _subgraph_ptr,
181 output_min,
182 output_max,
183 input1_id,
184 input2_id,
185 output_id,
186 /*flags=*/0);
187 _serializer.serializeAddNode(input1_id, input2_id, output_id, 0);
188 TORCH_CHECK(status == xnn_status_success, "failed to create add node");
189 break;
190 }
191 default: {
192 throw std::exception();
193 TORCH_CHECK(
194 false,
195 "The node of ",
196 node->kind().toQualString(),
197 " is not supported yet");
198 break;
199 }
200 }
201 }
202 }
203
defineAllTensorValues()204 void XNNGraph::defineAllTensorValues() {
205 uint32_t external_id =
206 std::numeric_limits<decltype(XNN_INVALID_VALUE_ID)>::min();
207 for (auto val : _intermediate_tensors) {
208 if (_val_to_ids.find(val) == _val_to_ids.end()) {
209 uint32_t id = XNN_INVALID_VALUE_ID;
210
211 // cast value to tensortype
212 auto tensor_ptr = val->type()->cast<TensorType>();
213 auto num_dims = tensor_ptr->dim().value();
214
215 // create size_t* for tensor shape, casting must be done from long ->
216 // size_t
217 std::vector<long> sizes = tensor_ptr->sizes().concrete_sizes().value();
218 std::vector<size_t> tensor_shape;
219 tensor_shape.reserve(sizes.size());
220 for (auto dim : sizes) {
221 TORCH_CHECK(dim >= 0, "Input Dims should be unsigned");
222 tensor_shape.push_back(static_cast<size_t>(dim));
223 }
224
225 // ext_id value
226 uint32_t ext_id = XNN_INVALID_VALUE_ID;
227
228 // update flag for if tensor is either graph input/output
229 uint32_t flags = 0;
230
231 // Check if value was produced by prim::Constant
232 void* value_data = nullptr;
233 size_t buffer_idx = 0;
234 size_t num_bytes = 0;
235 if (val->node()->kind() == prim::Constant) {
236 std::optional<IValue> constant = val->node()->t(attr::value);
237 auto const_val = constant->toIValue().toTensor();
238 // Need tensor data to be contiguous for serialization
239 auto cont_const_val = const_val.contiguous();
240 value_data = cont_const_val.data_ptr();
241
242 num_bytes = const_val.storage().nbytes();
243 buffer_idx = _serializer.serializeData(
244 static_cast<const uint8_t*>(value_data), num_bytes);
245 }
246
247 if (isGraphInput(val) || isGraphOutput(val)) {
248 if (isGraphInput(val)) {
249 flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
250 }
251 if (isGraphOutput(val)) {
252 flags |= XNN_VALUE_FLAG_EXTERNAL_OUTPUT;
253 }
254 ext_id = external_id++;
255 }
256 xnn_status status = xnn_define_tensor_value(
257 /*subgraph=*/_subgraph_ptr,
258 /*datatype=*/xnn_datatype_fp32,
259 /*num_dims=*/num_dims,
260 /*dims=*/tensor_shape.data(),
261 /*data=*/value_data,
262 /*external_id=*/ext_id,
263 /*flags=*/flags,
264 /*id_out=*/&id);
265 TORCH_CHECK(
266 status == xnn_status_success,
267 "failed to define xnn_tensor_id for: " + val->debugName());
268 _serializer.serializeTensorValue(
269 xnn_datatype_fp32,
270 num_dims,
271 tensor_shape,
272 buffer_idx,
273 ext_id,
274 flags,
275 id);
276 _val_to_ids.insert({val, id});
277 }
278 }
279 }
280
gatherTensorValues(std::shared_ptr<torch::jit::Graph> & graph)281 void XNNGraph::gatherTensorValues(std::shared_ptr<torch::jit::Graph>& graph) {
282 for (auto input : graph->inputs()) {
283 if (input->isCompleteTensor()) {
284 _intermediate_tensors.insert(input);
285 _inputs.push_back(input);
286 }
287 }
288
289 DepthFirstGraphNodeIterator it(graph);
290 Node* n = nullptr;
291 while ((n = it.next()) != nullptr) {
292 gatherNodeInputs(*n);
293 }
294
295 for (auto output : graph->outputs()) {
296 if (output->isCompleteTensor()) {
297 _intermediate_tensors.insert(output);
298 _outputs.push_back(output);
299 }
300 }
301 }
302
gatherNodeInputs(torch::jit::Node & node)303 void XNNGraph::gatherNodeInputs(torch::jit::Node& node) {
304 switch (node.kind()) {
305 case aten::add: {
306 // this case will support all ops with only two inputs i.e. sub, add,
307 for (auto value : node.inputs()) {
308 if (value->isCompleteTensor()) {
309 _intermediate_tensors.insert(value);
310 }
311 }
312 }
313 }
314 }
315
isGraphInput(torch::jit::Value * val)316 bool XNNGraph::isGraphInput(torch::jit::Value* val) {
317 return std::count(_inputs.begin(), _inputs.end(), val) > 0;
318 };
319
isGraphOutput(torch::jit::Value * val)320 bool XNNGraph::isGraphOutput(torch::jit::Value* val) {
321 return std::count(_outputs.begin(), _outputs.end(), val) > 0;
322 };
323
324 } // namespace delegate
325 } // namespace xnnpack
326 } // namespace jit
327 } // namespace torch
328