1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/backends/xnnpack/runtime/XNNCompiler.h>
#include <executorch/backends/xnnpack/runtime/XNNHeader.h>
#include <executorch/backends/xnnpack/serialization/schema_generated.h>
#include <executorch/extension/threadpool/threadpool.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>

#include <cstring>
#include <limits>
#include <memory>
#include <unordered_map>
#include <vector>
15
16 #pragma clang diagnostic ignored "-Wmissing-prototypes"
17 #pragma clang diagnostic ignored "-Wglobal-constructors"
18
19 namespace executorch {
20 namespace backends {
21 namespace xnnpack {
22 namespace delegate {
23
24 using executorch::runtime::Error;
25 using executorch::runtime::MemoryAllocator;
26 using executorch::runtime::Result;
27
/*
 * Provides temporary allocations whose lifetime spans the whole
 * compilation process; everything is freed when the allocator dies.
 */
class CompileAllocator {
 public:
  /*
   * Allocate memory which will be automatically freed at the end
   * of the compilation process (when this allocator is destroyed).
   *
   * @param size number of bytes to allocate.
   * @return pointer to the new, uninitialized buffer (never shrinks or
   *         moves for the allocator's lifetime).
   */
  void* allocateTemporary(size_t size) {
    // Own the buffer in a unique_ptr *before* registering it, so that if
    // vector growth throws the allocation is still released (the original
    // raw `new` + emplace_back could leak on that path).
    auto mem = std::make_unique<uint8_t[]>(size);
    temporaries_.push_back(std::move(mem));
    return temporaries_.back().get();
  }

 private:
  // Buffers released when the compilation scope ends.
  std::vector<std::unique_ptr<uint8_t[]>> temporaries_;
};
46
47 // Flatbuffer types
48 using ValuePtr = const fb_xnnpack::XValue*;
49 using NodePtr = const fb_xnnpack::XNode*;
50 using GraphPtr = const fb_xnnpack::XNNGraph*;
51 using DataType = fb_xnnpack::XNNDatatype;
52
53 // Type for define node function. This is the function signature
54 // for any function that takes in a flatbuffer node and defines it
55 // into our xnn_subgraph
56 using DefineNodeFunc = Error (*)(
57 xnn_subgraph_t,
58 const std::unordered_map<uint32_t, uint32_t>&,
59 NodePtr,
60 const fb_xnnpack::XNNGraph*) noexcept;
61
/*
Convert a tensor from fp32 to bf16 by keeping the top 16 bits of each
float, after a small pre-scale that compensates for plain truncation.
*/
void convertF32TensorToBF16(
    const float* f32_data,
    uint16_t* bf16_data_out,
    size_t numel) {
  for (size_t idx = 0; idx < numel; ++idx) {
    // Pre-scale the f32 value such that it rounds properly after truncation.
    // Constant factor scales 1+2^-8 to 1+2e-7.
    const float adjusted = f32_data[idx] * 1.00389105f;
    uint32_t bits = 0;
    memcpy(&bits, &adjusted, sizeof(float));
    // Upper 16 bits = sign + exponent + top 7 mantissa bits (bf16 layout).
    bf16_data_out[idx] = static_cast<uint16_t>(bits >> 16);
  }
}
78
79 /*
80 Gets the output min and output max for a given node operator
81 */
getOutputMinMax(const NodePtr node)82 std::pair<float, float> getOutputMinMax(const NodePtr node) noexcept {
83 float output_min = -std::numeric_limits<float>::infinity();
84 float output_max = std::numeric_limits<float>::infinity();
85 auto output_min_max = node->output_min_max();
86 if (output_min_max != nullptr) {
87 output_min = output_min_max->output_min();
88 output_max = output_min_max->output_max();
89 }
90
91 return {output_min, output_max};
92 }
93
94 /*
95 Converts flatbuffer xnn data type to xnnpack data type
96 */
getDataType(const DataType & data_type)97 xnn_datatype getDataType(const DataType& data_type) {
98 switch (data_type) {
99 case DataType::xnn_datatype_fp32:
100 return xnn_datatype::xnn_datatype_fp32;
101 case DataType::xnn_datatype_fp16:
102 return xnn_datatype::xnn_datatype_fp16;
103 case DataType::xnn_datatype_qint8:
104 return xnn_datatype::xnn_datatype_qint8;
105 case DataType::xnn_datatype_quint8:
106 return xnn_datatype::xnn_datatype_quint8;
107 case DataType::xnn_datatype_qint32:
108 return xnn_datatype::xnn_datatype_qint32;
109 case DataType::xnn_datatype_qcint8:
110 return xnn_datatype::xnn_datatype_qcint8;
111 case DataType::xnn_datatype_qcint32:
112 return xnn_datatype::xnn_datatype_qcint32;
113 case DataType::xnn_datatype_qcint4:
114 return xnn_datatype::xnn_datatype_qcint4;
115 case DataType::xnn_datatype_qdint8:
116 return xnn_datatype::xnn_datatype_qdint8;
117 case DataType::xnn_datatype_qbint4:
118 return xnn_datatype::xnn_datatype_qbint4;
119 default:
120 return xnn_datatype::xnn_datatype_invalid;
121 }
122 }
123
isQuantizedDataType(const xnn_datatype data_type)124 bool isQuantizedDataType(const xnn_datatype data_type) {
125 switch (data_type) {
126 case xnn_datatype::xnn_datatype_qint8:
127 case xnn_datatype::xnn_datatype_quint8:
128 case xnn_datatype::xnn_datatype_qint32:
129 case xnn_datatype::xnn_datatype_qcint8:
130 case xnn_datatype::xnn_datatype_qcint32:
131 case xnn_datatype::xnn_datatype_qcint4:
132 case xnn_datatype::xnn_datatype_qdint8:
133 return true;
134 default:
135 return false;
136 }
137 }
138
139 /**
140 Converts dims from uint32 to size_t. Takes in a flatbuffer vector
141 of uint32_t and returns a std::vector of size_t. XNNPACK takes in
142 dims of size_t* but tensor shape is serialized in flatbuffer as
143 int32_t. As a result, we need to static cast the shapes to size_t
144 */
145 template <typename T = size_t>
flatbufferDimsToVector(const flatbuffers::Vector<uint32_t> * fb_dims)146 std::vector<T> flatbufferDimsToVector(
147 const flatbuffers::Vector<uint32_t>* fb_dims) {
148 std::vector<T> dims_data;
149 dims_data.reserve(fb_dims->size());
150 for (auto fb_dim : *fb_dims) {
151 dims_data.push_back(static_cast<T>(fb_dim));
152 }
153 return dims_data;
154 }
155
156 /**
157 Gets the constant data pointer associated with the given tensor value.
158 Obtaining the constant data pointer can either be from within the flatbuffer
159 payload (deprecated) or via offsets to the constant_data_ptr. If no constant
160 data associated with the tensor value, then returns nullptr.
161 */
getConstantDataPtr(const fb_xnnpack::XNNTensorValue * tensor_value,GraphPtr flatbuffer_graph,const uint8_t * constant_data_ptr)162 const uint8_t* getConstantDataPtr(
163 const fb_xnnpack::XNNTensorValue* tensor_value,
164 GraphPtr flatbuffer_graph,
165 const uint8_t* constant_data_ptr) {
166 auto buffer_idx = tensor_value->constant_buffer_idx();
167 if (buffer_idx) {
168 if (!constant_data_ptr) {
169 // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
170 // window
171 const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
172 return constant_buffer[buffer_idx]->storage()->data();
173 } else {
174 const auto& constant_data_offsets = *flatbuffer_graph->constant_data();
175 uint64_t constant_data_offset =
176 constant_data_offsets[buffer_idx]->offset();
177 return constant_data_ptr + constant_data_offset;
178 }
179 }
180
181 return nullptr;
182 }
183
184 /**
185 Define serialized tensor value into
186 the subgraph. While also keeping track of the remapped ids from
187 the serialized id to the newly generated id.
188 */
defineTensor(xnn_subgraph_t subgraph_ptr,std::unordered_map<uint32_t,uint32_t> & remapped_ids,ValuePtr value,GraphPtr flatbuffer_graph,const uint8_t * constant_data_ptr,std::vector<uint32_t> & input_ids,std::vector<uint32_t> & output_ids,CompileAllocator & allocator)189 Error defineTensor(
190 xnn_subgraph_t subgraph_ptr,
191 std::unordered_map<uint32_t, uint32_t>& remapped_ids,
192 ValuePtr value,
193 GraphPtr flatbuffer_graph,
194 const uint8_t* constant_data_ptr,
195 std::vector<uint32_t>& input_ids,
196 std::vector<uint32_t>& output_ids,
197 CompileAllocator& allocator) {
198 const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
199 const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;
200
201 switch (value->xvalue_union_type()) {
202 case fb_xnnpack::XValueUnion::XNNTensorValue: {
203 tensor_value = value->xvalue_union_as_XNNTensorValue();
204 break;
205 }
206 case fb_xnnpack::XValueUnion::XNNQuantizedTensorValue: {
207 qtensor_value = value->xvalue_union_as_XNNQuantizedTensorValue();
208 tensor_value = qtensor_value->tensor_value();
209 break;
210 }
211 default: {
212 ET_CHECK_OR_RETURN_ERROR(
213 false,
214 NotImplemented,
215 "Unhandled value type: %s",
216 fb_xnnpack::EnumNameXValueUnion(value->xvalue_union_type()));
217 }
218 }
219
220 ET_CHECK_OR_RETURN_ERROR(
221 tensor_value != nullptr,
222 Internal,
223 "Deserialized Tensor is Null, this should never happen");
224
225 // Get tensor dims, here we need to use a vector in order
226 // to properly convert the uint32_t* to size_t*
227 std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());
228
229 // XNNPACK Id
230 uint32_t id = XNN_INVALID_VALUE_ID;
231
232 // Get Pointer to constant data from flatbuffer, if its non-constant
233 // it is a nullptr
234 const uint8_t* buffer_ptr =
235 getConstantDataPtr(tensor_value, flatbuffer_graph, constant_data_ptr);
236
237 xnn_status status;
238 // The type we might have to convert to
239 auto dq_datatype = getDataType(tensor_value->dq_datatype());
240
241 if (dq_datatype != xnn_datatype::xnn_datatype_invalid) {
242 if (dq_datatype != xnn_datatype::xnn_datatype_qint8) {
243 ET_CHECK_OR_RETURN_ERROR(
244 false,
245 Internal,
246 "Only int8_t is supported for dq_datatype for now, got: %d",
247 dq_datatype);
248 } else {
249 ET_CHECK_OR_RETURN_ERROR(
250 (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT),
251 Internal,
252 "Dynamic quantization of tensor is only allowed for the external input tensor value for now! got flags: %u",
253 tensor_value->flags());
254 }
255 }
256
257 if (qtensor_value == nullptr) {
258 // FP32 tensor
259 if (!isQuantizedDataType(dq_datatype)) {
260 // Define non-quantied tensor
261 status = xnn_define_tensor_value(
262 /*subgraph=*/subgraph_ptr,
263 /*datatype=*/getDataType(tensor_value->datatype()),
264 /*num_dims=*/tensor_value->num_dims(),
265 /*dims=*/dims_data.data(),
266 /*data=*/buffer_ptr,
267 /*external_id=*/tensor_value->external_id(),
268 /*flags=*/tensor_value->flags(),
269 /*id_out=*/&id);
270 } else if (dq_datatype != xnn_datatype::xnn_datatype_invalid) {
271 ET_CHECK_OR_RETURN_ERROR(
272 isQuantizedDataType(dq_datatype),
273 Internal,
274 "Dynamic quantization can only produce supported quantized dtypes");
275 ET_CHECK_OR_RETURN_ERROR(
276 tensor_value->external_id() != XNN_INVALID_VALUE_ID,
277 Internal,
278 "Dynamic quantization can only work with external inputs for now, got an internal ID");
279 ET_CHECK_OR_RETURN_ERROR(
280 buffer_ptr == nullptr,
281 Internal,
282 "Dynamic quantization can only work with external inputs for now, got const data");
283
284 switch (dq_datatype) {
285 case xnn_datatype::xnn_datatype_qint8: {
286 // HACK TO Maintain FC/BC for ASR this will be removed after 01/2024
287
288 // When encountering a dynamically quantized tensor via dq_datatype,
289 // which is the old flow for serializing dynamically quantized linear.
290 // We replace the definition of a single tensor with a new dynamic
291 // Quantization pattern. We change the pattern from:
292 // serialized_qd_input
293 // to
294 // (fp32_input --> convert --> qdint8_input)
295
296 status = xnn_define_dynamically_quantized_tensor_value(
297 /*subgraph=*/subgraph_ptr,
298 /*datatype=*/xnn_datatype_qdint8,
299 /*num_dims=*/tensor_value->num_dims(),
300 /*num_nonbatch_dims=*/1, // always do per token quantization
301 /*dims=*/dims_data.data(),
302 /*external_id=*/XNN_INVALID_VALUE_ID, // always internal value id
303 /*flags=*/0, // this is netiher external input or output
304 /*id_out=*/&id);
305
306 // this is the FP16 or FP32 external value that is being dynamically
307 // quantized
308 uint32_t float_id;
309 enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype());
310 status = xnn_define_tensor_value(
311 /*subgraph=*/subgraph_ptr,
312 /*datatype=*/fp_datatype,
313 /*num_dims=*/tensor_value->num_dims(),
314 /*dims=*/dims_data.data(),
315 /*data=*/buffer_ptr,
316 /*external_id=*/tensor_value->external_id(),
317 /*flags=*/tensor_value->flags(),
318 /*id_out=*/&float_id);
319
320 // Define dynamic conversion from float to qdint8
321 status = xnn_define_convert(
322 /*subgraph=*/subgraph_ptr,
323 /*input_id=*/float_id,
324 /*output_id=*/id,
325 /*flags=*/0);
326 break;
327 }
328 default:
329 ET_CHECK_OR_RETURN_ERROR(
330 false,
331 NotImplemented,
332 "Unhandled Dyanmic Quantization dtype: %d",
333 dq_datatype);
334 }
335 } else {
336 ET_CHECK_OR_RETURN_ERROR(false, NotImplemented, "Unhandled fp32 tensor");
337 }
338 } else {
339 // define tensor for quantized
340 switch (qtensor_value->quant_params_type()) {
341 case fb_xnnpack::XNNQuantParams::PerTensorQuant: {
342 auto qparams = qtensor_value->quant_params_as_PerTensorQuant();
343 ET_LOG(
344 Debug,
345 "define quant tensor (per tensor): buffer_ptr: %p, scale: %f, zp: %u\n",
346 buffer_ptr,
347 qparams->scale(),
348 qparams->zero_point());
349 status = xnn_define_quantized_tensor_value(
350 /*subgraph=*/subgraph_ptr,
351 /*datatype=*/getDataType(tensor_value->datatype()),
352 /*zero_point=*/qparams->zero_point(),
353 /*scale=*/qparams->scale(),
354 /*num_dims=*/tensor_value->num_dims(),
355 /*dims=*/dims_data.data(),
356 /*data=*/buffer_ptr,
357 /*external_id=*/tensor_value->external_id(),
358 /*flags=*/tensor_value->flags(),
359 /*id_out=*/&id);
360 break;
361 }
362 case fb_xnnpack::XNNQuantParams::PerChannelQuant: {
363 auto qparams = qtensor_value->quant_params_as_PerChannelQuant();
364 enum xnn_datatype dtype = getDataType(tensor_value->datatype());
365 int32_t zero_point =
366 (dtype == xnn_datatype::xnn_datatype_qcint4 ? 8 : 0);
367
368 ET_LOG(
369 Debug,
370 "define quant tensor (per channel): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, dtype: %u, zero_point: %d\n",
371 buffer_ptr,
372 qparams->scale()->size(),
373 qparams->channel_dim(),
374 dtype,
375 zero_point);
376 status = xnn_define_channelwise_quantized_tensor_value_v2(
377 /*subgraph=*/subgraph_ptr,
378 /*datatype=*/dtype,
379 /*zero_point=*/zero_point,
380 /*scale=*/qparams->scale()->data(),
381 /*num_dims=*/tensor_value->num_dims(),
382 /*channel_dim*/ qparams->channel_dim(),
383 /*dims=*/dims_data.data(),
384 /*data=*/buffer_ptr,
385 /*external_id=*/tensor_value->external_id(),
386 /*flags=*/tensor_value->flags(),
387 /*id_out=*/&id);
388 break;
389 }
390 case fb_xnnpack::XNNQuantParams::PerChannelGroupQuant: {
391 xnn_datatype datatype = getDataType(tensor_value->datatype());
392 ET_CHECK_OR_RETURN_ERROR(
393 datatype == xnn_datatype::xnn_datatype_qbint4,
394 Internal,
395 "Unsupported datatype for per channel group quantization: %d",
396 datatype);
397 auto qparams = qtensor_value->quant_params_as_PerChannelGroupQuant();
398 size_t group_size = qparams->group_size();
399 size_t output_channels = tensor_value->dims()->Get(0);
400 size_t input_channels = tensor_value->dims()->Get(1);
401
402 const uint16_t* scale_data = nullptr;
403 uint32_t scale_numel = 0;
404
405 // Block scales are preferably serialized as bf16 but can also be
406 // serialized as fp32 for backwards compatability.
407 if (qparams->scale_bf16() != nullptr) {
408 scale_data =
409 static_cast<const uint16_t*>(qparams->scale_bf16()->data());
410 scale_numel = qparams->scale_bf16()->size();
411 } else {
412 // Read fp32 scales, convert to bf16.
413 auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(
414 qparams->scale()->size() * sizeof(uint16_t)));
415 scale_numel = qparams->scale()->size();
416 convertF32TensorToBF16(
417 qparams->scale()->data(), conv_buffer, scale_numel);
418 scale_data = conv_buffer;
419 }
420
421 ET_CHECK_OR_RETURN_ERROR(
422 scale_numel == output_channels * input_channels / group_size,
423 Internal,
424 "scale size %zu != output channels %zu * group size %zu",
425 static_cast<size_t>(scale_numel),
426 output_channels,
427 group_size);
428 int32_t zero_point =
429 (datatype == xnn_datatype::xnn_datatype_qbint4 ? 8 : 0);
430 ET_LOG(
431 Debug,
432 "define quant tensor (per channel group): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, grpup_size: %zu, output_channels: %zu, dtype: %u, zero_point: %d, datatype: %d\n",
433 buffer_ptr,
434 scale_numel,
435 qparams->channel_dim(),
436 group_size,
437 output_channels,
438 datatype,
439 zero_point,
440 datatype);
441
442 status = xnn_define_blockwise_quantized_tensor_value(
443 /*subgraph=*/subgraph_ptr,
444 /*datatype=*/datatype,
445 /*zero_point=*/zero_point,
446 /*scale=*/scale_data,
447 /*num_dims=*/tensor_value->num_dims(),
448 /*channel_dim=*/qparams->channel_dim(),
449 /*block_size=*/qparams->group_size(),
450 /*dims=*/dims_data.data(),
451 /*data=*/buffer_ptr,
452 /*external_id=*/tensor_value->external_id(),
453 /*flags=*/tensor_value->flags(),
454 /*id_out=*/&id);
455 break;
456 }
457 case fb_xnnpack::XNNQuantParams::PerTokenDynamicQuant: {
458 auto qparams = qtensor_value->quant_params_as_PerTokenDynamicQuant();
459 ET_LOG(
460 Debug,
461 "define quant tensor (dynamic): num_dims: %i, num_nonbatch_dims: %i\n",
462 tensor_value->num_dims(),
463 qparams->num_nonbatch_dims());
464 ET_CHECK_OR_RETURN_ERROR(
465 buffer_ptr == nullptr,
466 Internal,
467 "Dynamically quantized tensor should not have constant data but found non-nullptr");
468 // TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
469 ET_CHECK_OR_RETURN_ERROR(
470 qparams->num_nonbatch_dims() == 1,
471 Internal,
472 "Dynamically Quantized Tensors currently only support per token quantization");
473 status = xnn_define_dynamically_quantized_tensor_value(
474 /*subgraph=*/subgraph_ptr,
475 /*datatype=*/getDataType(tensor_value->datatype()),
476 /*num_dims=*/tensor_value->num_dims(),
477 /*num_nonbatch_dims*/ qparams->num_nonbatch_dims(),
478 /*dims=*/dims_data.data(),
479 /*external_id=*/tensor_value->external_id(),
480 /*flags=*/tensor_value->flags(),
481 /*id_out=*/&id);
482 break;
483 }
484 default: {
485 ET_CHECK_OR_RETURN_ERROR(
486 false,
487 NotImplemented,
488 "Unhandled Quantization Parameters: %s",
489 fb_xnnpack::EnumNameXNNQuantParams(
490 qtensor_value->quant_params_type()));
491 }
492 }
493 }
494
495 ET_CHECK_OR_RETURN_ERROR(
496 status == xnn_status_success,
497 Internal,
498 "Failed to define tensor %i with code: %s",
499 tensor_value->id_out(),
500 xnn_status_to_string(status));
501
502 // map serialized id to newly generated id
503 remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id));
504
505 // Add external ids to either list of input or output ids
506 if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT) {
507 input_ids.push_back(tensor_value->external_id());
508 }
509 if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) {
510 output_ids.push_back(tensor_value->external_id());
511 }
512
513 return Error::Ok;
514 };
515
516 #define MAYBE_UNUSED(x) (void)(x)
517
518 /*
519 Define serialized add node into the subgraph, using the remapped ids
520 to map the serialized ids, to the new ids generated when defining
521 the tensor value
522 */
defineAddNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)523 Error defineAddNode(
524 xnn_subgraph_t subgraph_ptr,
525 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
526 const NodePtr node,
527 const fb_xnnpack::XNNGraph* graph) noexcept {
528 MAYBE_UNUSED(graph);
529
530 std::pair<float, float> min_max = getOutputMinMax(node);
531 auto graph_node = node->xnode_union_as_XNNAdd();
532 xnn_status status = xnn_define_add2(
533 subgraph_ptr,
534 min_max.first,
535 min_max.second,
536 remapped_ids.at(graph_node->input1_id()),
537 remapped_ids.at(graph_node->input2_id()),
538 remapped_ids.at(graph_node->output_id()),
539 graph_node->flags());
540 ET_CHECK_OR_RETURN_ERROR(
541 status == xnn_status_success,
542 Internal,
543 "Failed to create add node %i with code: %s",
544 node->debug_handle(),
545 xnn_status_to_string(status));
546
547 return Error::Ok;
548 };
549
550 /*
551 Define Minimum operator Node into the subgraph
552 */
defineMinimumNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)553 Error defineMinimumNode(
554 xnn_subgraph_t subgraph_ptr,
555 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
556 const NodePtr node,
557 const fb_xnnpack::XNNGraph* graph) noexcept {
558 MAYBE_UNUSED(graph);
559
560 auto graph_node = node->xnode_union_as_XNNMinimum();
561 xnn_status status = xnn_define_minimum2(
562 subgraph_ptr,
563 remapped_ids.at(graph_node->input1_id()),
564 remapped_ids.at(graph_node->input2_id()),
565 remapped_ids.at(graph_node->output_id()),
566 graph_node->flags());
567
568 ET_CHECK_OR_RETURN_ERROR(
569 status == xnn_status_success,
570 Internal,
571 "Failed to create minumum node %i with code: %s",
572 node->debug_handle(),
573 xnn_status_to_string(status));
574
575 return Error::Ok;
576 };
577
578 /*
579 Define subtract operator Node into the subgraph
580 */
defineSubtractNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)581 Error defineSubtractNode(
582 xnn_subgraph_t subgraph_ptr,
583 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
584 const NodePtr node,
585 const fb_xnnpack::XNNGraph* graph) noexcept {
586 MAYBE_UNUSED(graph);
587
588 auto graph_node = node->xnode_union_as_XNNSubtract();
589 std::pair<float, float> min_max = getOutputMinMax(node);
590 xnn_status status = xnn_define_subtract(
591 subgraph_ptr,
592 min_max.first,
593 min_max.second,
594 remapped_ids.at(graph_node->input1_id()),
595 remapped_ids.at(graph_node->input2_id()),
596 remapped_ids.at(graph_node->output_id()),
597 graph_node->flags());
598
599 ET_CHECK_OR_RETURN_ERROR(
600 status == xnn_status_success,
601 Internal,
602 "Failed to create subtract node %i with code: %s",
603 node->debug_handle(),
604 xnn_status_to_string(status));
605
606 return Error::Ok;
607 };
608
609 /*
610 Define Multiply operator Node into the subgraph
611 */
defineMultiplyNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)612 Error defineMultiplyNode(
613 xnn_subgraph_t subgraph_ptr,
614 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
615 const NodePtr node,
616 const fb_xnnpack::XNNGraph* graph) noexcept {
617 MAYBE_UNUSED(graph);
618
619 auto graph_node = node->xnode_union_as_XNNMultiply();
620 std::pair<float, float> min_max = getOutputMinMax(node);
621 xnn_status status = xnn_define_multiply2(
622 subgraph_ptr,
623 min_max.first,
624 min_max.second,
625 remapped_ids.at(graph_node->input1_id()),
626 remapped_ids.at(graph_node->input2_id()),
627 remapped_ids.at(graph_node->output_id()),
628 graph_node->flags());
629
630 ET_CHECK_OR_RETURN_ERROR(
631 status == xnn_status_success,
632 Internal,
633 "Failed to create multiply node %i with code: %s",
634 node->debug_handle(),
635 xnn_status_to_string(status));
636
637 return Error::Ok;
638 };
639
#ifdef ENABLE_XNNPACK_KLEIDI
// Returns true when this Convert node produces a qdint8 activation that is
// consumed by at least one fully-connected node with a qbint4 filter — the
// pattern the QP8 packing path accelerates. Assumes one valid consumer is
// enough to enable QP8 for all linears fed by this convert output.
bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
  assert(node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNConvert);
  const auto convert_node = node->xnode_union_as_XNNConvert();
  const auto cvt_output_id = convert_node->output_id();

  // True iff the quantized tensor whose serialized id is `id` has `dtype`.
  auto value_has_dtype = [graph](uint32_t id, DataType dtype) -> bool {
    assert(
        dtype == DataType::xnn_datatype_qdint8 ||
        dtype == DataType::xnn_datatype_qbint4);
    for (auto xvalue : *graph->xvalues()) {
      if (xvalue->xvalue_union_type() !=
          fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
        continue;
      }
      const auto tensor =
          xvalue->xvalue_union_as_XNNQuantizedTensorValue()->tensor_value();
      if (tensor->id_out() == id) {
        return tensor->datatype() == dtype;
      }
    }
    return false;
  };

  // Bail early unless the convert produces a qdint8 tensor.
  if (!value_has_dtype(cvt_output_id, DataType::xnn_datatype_qdint8)) {
    return false;
  }

  // Scan for a fully-connected consumer with a qbint4 filter.
  // (Loop variable renamed from `node` to avoid shadowing the parameter.)
  for (auto consumer : *graph->xnodes()) {
    if (consumer->xnode_union_type() !=
        fb_xnnpack::XNodeUnion::XNNFullyConnected) {
      continue;
    }
    const auto linear_node = consumer->xnode_union_as_XNNFullyConnected();
    if (linear_node->input1_id() == cvt_output_id &&
        value_has_dtype(
            linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
      return true;
    }
  }
  return false;
}
#endif // ENABLE_XNNPACK_KLEIDI
686
687 /*
688 Define Convert operator Node into the subgraph
689 */
defineConvertNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * flatbuffer_graph)690 Error defineConvertNode(
691 xnn_subgraph_t subgraph_ptr,
692 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
693 const NodePtr node,
694 const fb_xnnpack::XNNGraph* flatbuffer_graph) noexcept {
695 MAYBE_UNUSED(flatbuffer_graph);
696 auto graph_node = node->xnode_union_as_XNNConvert();
697
698 int32_t flags = graph_node->flags();
699 #ifdef ENABLE_XNNPACK_KLEIDI
700 // This is not currently exposed at include/xnnpack.h yet once it is
701 // we can remove this runtime logic and do this ahead-of-time
702 #define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100;
703 if (isQP8(flatbuffer_graph, node)) {
704 flags |= XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
705 ET_LOG(
706 Debug,
707 "Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i",
708 node->debug_handle());
709 }
710 #endif
711
712 xnn_status status = xnn_define_convert(
713 subgraph_ptr,
714 remapped_ids.at(graph_node->input_id()),
715 remapped_ids.at(graph_node->output_id()),
716 flags);
717
718 ET_CHECK_OR_RETURN_ERROR(
719 status == xnn_status_success,
720 Internal,
721 "Failed to create convert node %i with code: %s",
722 node->debug_handle(),
723 xnn_status_to_string(status));
724
725 return Error::Ok;
726 };
727 /*
728 Define serialized linear(fully-connected) node into the subgraph using
729 the remapped ids to map the serialized ids, to the new ids generated
730 when defining the tensor values
731 */
defineFullyConnectedNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)732 Error defineFullyConnectedNode(
733 xnn_subgraph_t subgraph_ptr,
734 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
735 const NodePtr node,
736 const fb_xnnpack::XNNGraph* graph) noexcept {
737 MAYBE_UNUSED(graph);
738
739 auto graph_node = node->xnode_union_as_XNNFullyConnected();
740 std::pair<float, float> min_max = getOutputMinMax(node);
741 xnn_status status = xnn_define_fully_connected(
742 subgraph_ptr,
743 min_max.first,
744 min_max.second,
745 remapped_ids.at(graph_node->input1_id()),
746 remapped_ids.at(graph_node->filter_id()),
747 remapped_ids.at(graph_node->bias_id()),
748 remapped_ids.at(graph_node->output_id()),
749 graph_node->flags());
750 ET_CHECK_OR_RETURN_ERROR(
751 status == xnn_status_success,
752 Internal,
753 "Failed to create linear node %i, with code: %s",
754 node->debug_handle(),
755 xnn_status_to_string(status));
756
757 return Error::Ok;
758 };
759
760 /*
761 Define serialized clamp node into the subgraph, using the remapped ids
762 to map the serialized ids, to the new ids generated when defining
763 the tensor value
764 */
defineClampNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)765 Error defineClampNode(
766 xnn_subgraph_t subgraph_ptr,
767 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
768 const NodePtr node,
769 const fb_xnnpack::XNNGraph* graph) noexcept {
770 MAYBE_UNUSED(graph);
771
772 std::pair<float, float> min_max = getOutputMinMax(node);
773 auto graph_node = node->xnode_union_as_XNNClamp();
774 xnn_status status = xnn_define_clamp(
775 subgraph_ptr,
776 min_max.first,
777 min_max.second,
778 remapped_ids.at(graph_node->input_id()),
779 remapped_ids.at(graph_node->output_id()),
780 graph_node->flags());
781
782 ET_CHECK_OR_RETURN_ERROR(
783 status == xnn_status_success,
784 Internal,
785 "Failed to create hardtanh node %i with code: %s",
786 node->debug_handle(),
787 xnn_status_to_string(status));
788
789 return Error::Ok;
790 }
791
792 /*
793 Define serialized softmax node into the subgraph, using the remapped ids
794 to map the serialized ids, to the new ids generated when defining
795 the tensor value
796 */
defineSoftmaxNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)797 Error defineSoftmaxNode(
798 xnn_subgraph_t subgraph_ptr,
799 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
800 const NodePtr node,
801 const fb_xnnpack::XNNGraph* graph) noexcept {
802 MAYBE_UNUSED(graph);
803
804 auto graph_node = node->xnode_union_as_XNNSoftmax();
805 xnn_status status = xnn_define_softmax(
806 subgraph_ptr,
807 remapped_ids.at(graph_node->input_id()),
808 remapped_ids.at(graph_node->output_id()),
809 graph_node->flags());
810 ET_CHECK_OR_RETURN_ERROR(
811 status == xnn_status_success,
812 Internal,
813 "Failed to create softmax node %i with code: %s",
814 node->debug_handle(),
815 xnn_status_to_string(status));
816
817 return Error::Ok;
818 }
819
820 /*
821 Define serialized sigmoid node into the subgraph, using the remapped ids
822 to map the serialized ids, to the new ids generated when defining
823 the tensor value
824 */
defineSigmoidNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)825 Error defineSigmoidNode(
826 xnn_subgraph_t subgraph_ptr,
827 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
828 const NodePtr node,
829 const fb_xnnpack::XNNGraph* graph) noexcept {
830 MAYBE_UNUSED(graph);
831
832 auto graph_node = node->xnode_union_as_XNNSigmoid();
833 xnn_status status = xnn_define_sigmoid(
834 subgraph_ptr,
835 remapped_ids.at(graph_node->input_id()),
836 remapped_ids.at(graph_node->output_id()),
837 graph_node->flags());
838 ET_CHECK_OR_RETURN_ERROR(
839 status == xnn_status_success,
840 Internal,
841 "Failed to create sigmoid node %i with code: %s",
842 node->debug_handle(),
843 xnn_status_to_string(status));
844
845 return Error::Ok;
846 }
847
848 /*
849 Define serialized floor node into the subgraph, using the remapped ids
850 to map the serialized ids, to the new ids generated when defining
851 the tensor value
852 */
defineFloorNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)853 Error defineFloorNode(
854 xnn_subgraph_t subgraph_ptr,
855 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
856 const NodePtr node,
857 const fb_xnnpack::XNNGraph* graph) noexcept {
858 MAYBE_UNUSED(graph);
859
860 auto graph_node = node->xnode_union_as_XNNFloor();
861 xnn_status status = xnn_define_floor(
862 subgraph_ptr,
863 remapped_ids.at(graph_node->input_id()),
864 remapped_ids.at(graph_node->output_id()),
865 graph_node->flags());
866 ET_CHECK_OR_RETURN_ERROR(
867 status == xnn_status_success,
868 Internal,
869 "Failed to create floor node %i with code: %s",
870 node->debug_handle(),
871 xnn_status_to_string(status));
872
873 return Error::Ok;
874 }
875
defineGlobalAvgPooling2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)876 Error defineGlobalAvgPooling2dNode(
877 xnn_subgraph_t subgraph_ptr,
878 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
879 const NodePtr node,
880 const fb_xnnpack::XNNGraph* graph) noexcept {
881 MAYBE_UNUSED(graph);
882
883 auto graph_node = node->xnode_union_as_XNNGlobalAvgPooling2d();
884 std::pair<float, float> min_max = getOutputMinMax(node);
885 xnn_status status = xnn_define_global_average_pooling_2d(
886 subgraph_ptr,
887 min_max.first,
888 min_max.second,
889 remapped_ids.at(graph_node->input_id()),
890 remapped_ids.at(graph_node->output_id()),
891 graph_node->flags());
892 ET_CHECK_OR_RETURN_ERROR(
893 status == xnn_status_success,
894 Internal,
895 "Failed to create global average pooling node %i with code: %s",
896 node->debug_handle(),
897 xnn_status_to_string(status));
898
899 return Error::Ok;
900 }
901
defineAvgPooling2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)902 Error defineAvgPooling2dNode(
903 xnn_subgraph_t subgraph_ptr,
904 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
905 const NodePtr node,
906 const fb_xnnpack::XNNGraph* graph) noexcept {
907 MAYBE_UNUSED(graph);
908
909 auto graph_node = node->xnode_union_as_XNNAvgPooling2d();
910 std::pair<float, float> min_max = getOutputMinMax(node);
911 xnn_status status = xnn_define_average_pooling_2d(
912 subgraph_ptr,
913 graph_node->padding_top(),
914 graph_node->padding_right(),
915 graph_node->padding_bottom(),
916 graph_node->padding_left(),
917 graph_node->pooling_height(),
918 graph_node->pooling_width(),
919 graph_node->stride_height(),
920 graph_node->stride_width(),
921 min_max.first,
922 min_max.second,
923 remapped_ids.at(graph_node->input_id()),
924 remapped_ids.at(graph_node->output_id()),
925 graph_node->flags());
926 ET_CHECK_OR_RETURN_ERROR(
927 status == xnn_status_success,
928 Internal,
929 "Failed to create average pooling node %i with code: %s",
930 node->debug_handle(),
931 xnn_status_to_string(status));
932
933 return Error::Ok;
934 }
935
936 /*
937 Define serialized conv2d node into the subgraph, using the remapped ids
938 to map the serialized ids, to the new ids generated when defining the
939 tensor value
940 */
defineConv2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)941 Error defineConv2dNode(
942 xnn_subgraph_t subgraph_ptr,
943 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
944 const NodePtr node,
945 const fb_xnnpack::XNNGraph* graph) noexcept {
946 MAYBE_UNUSED(graph);
947
948 auto graph_node = node->xnode_union_as_XNNConv2d();
949 std::pair<float, float> min_max = getOutputMinMax(node);
950 xnn_status status = xnn_define_convolution_2d(
951 subgraph_ptr,
952 graph_node->padding_top(),
953 graph_node->padding_right(),
954 graph_node->padding_bottom(),
955 graph_node->padding_left(),
956 graph_node->kernel_height(),
957 graph_node->kernel_width(),
958 graph_node->subsampling_height(),
959 graph_node->subsampling_width(),
960 graph_node->dilation_height(),
961 graph_node->dilation_width(),
962 graph_node->groups(),
963 graph_node->group_input_channels(),
964 graph_node->group_output_channels(),
965 min_max.first,
966 min_max.second,
967 remapped_ids.at(graph_node->input1_id()),
968 remapped_ids.at(graph_node->filter_id()),
969 remapped_ids.at(graph_node->bias_id()),
970 remapped_ids.at(graph_node->output_id()),
971 graph_node->flags());
972 ET_CHECK_OR_RETURN_ERROR(
973 status == xnn_status_success,
974 Internal,
975 "Failed to create convolution node %i with code: %s",
976 node->debug_handle(),
977 xnn_status_to_string(status));
978
979 return Error::Ok;
980 }
981
982 /*
983 Define serialized maxpool2d node into the subgraph, using the remapped ids
984 to map the serialized ids, to the new ids generated when defining the
985 tensor value
986 */
defineMaxPooling2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)987 Error defineMaxPooling2dNode(
988 xnn_subgraph_t subgraph_ptr,
989 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
990 const NodePtr node,
991 const fb_xnnpack::XNNGraph* graph) noexcept {
992 MAYBE_UNUSED(graph);
993
994 auto graph_node = node->xnode_union_as_XNNMaxPooling2d();
995 std::pair<float, float> min_max = getOutputMinMax(node);
996 xnn_status status = xnn_define_max_pooling_2d(
997 subgraph_ptr,
998 graph_node->padding_top(),
999 graph_node->padding_right(),
1000 graph_node->padding_bottom(),
1001 graph_node->padding_left(),
1002 graph_node->pooling_height(),
1003 graph_node->pooling_width(),
1004 graph_node->stride_height(),
1005 graph_node->stride_width(),
1006 graph_node->dilation_height(),
1007 graph_node->dilation_width(),
1008 min_max.first,
1009 min_max.second,
1010 remapped_ids.at(graph_node->input_id()),
1011 remapped_ids.at(graph_node->output_id()),
1012 graph_node->flags());
1013
1014 ET_CHECK_OR_RETURN_ERROR(
1015 status == xnn_status_success,
1016 Internal,
1017 "Failed to create maxpool2d node %i with code: %s",
1018 node->debug_handle(),
1019 xnn_status_to_string(status));
1020
1021 return Error::Ok;
1022 }
1023
1024 /*
1025 Define serialized div node into the subgraph
1026 */
defineDivNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1027 Error defineDivNode(
1028 xnn_subgraph_t subgraph_ptr,
1029 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1030 const NodePtr node,
1031 const fb_xnnpack::XNNGraph* graph) noexcept {
1032 MAYBE_UNUSED(graph);
1033
1034 auto graph_node = node->xnode_union_as_XNNDiv();
1035 std::pair<float, float> min_max = getOutputMinMax(node);
1036 xnn_status status = xnn_define_divide(
1037 subgraph_ptr,
1038 min_max.first,
1039 min_max.second,
1040 remapped_ids.at(graph_node->input1_id()),
1041 remapped_ids.at(graph_node->input2_id()),
1042 remapped_ids.at(graph_node->output_id()),
1043 graph_node->flags());
1044 ET_CHECK_OR_RETURN_ERROR(
1045 status == xnn_status_success,
1046 Internal,
1047 "Failed to create div node %i with code: %s",
1048 node->debug_handle(),
1049 xnn_status_to_string(status));
1050
1051 return Error::Ok;
1052 }
1053
1054 /*
1055 Define serialized static transpose node into the subgraph, using the remapped
1056 ids to map the serialized ids, to the new ids generated when defining the
1057 tensor value
1058 */
defineStaticTransposeNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1059 Error defineStaticTransposeNode(
1060 xnn_subgraph_t subgraph_ptr,
1061 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1062 const NodePtr node,
1063 const fb_xnnpack::XNNGraph* graph) noexcept {
1064 MAYBE_UNUSED(graph);
1065
1066 auto graph_node = node->xnode_union_as_XNNStaticTranspose();
1067
1068 // Get tensor dims, we need to convert the uint32_t* to size_t*
1069 std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());
1070 xnn_status status = xnn_define_static_transpose(
1071 subgraph_ptr,
1072 graph_node->num_dims(),
1073 dims_data.data(),
1074 remapped_ids.at(graph_node->input_id()),
1075 remapped_ids.at(graph_node->output_id()),
1076 graph_node->flags());
1077 ET_CHECK_OR_RETURN_ERROR(
1078 status == xnn_status_success,
1079 Internal,
1080 "Failed to create sigmoid node %i with code: %s",
1081 node->debug_handle(),
1082 xnn_status_to_string(status));
1083
1084 return Error::Ok;
1085 }
1086
1087 /*
1088 Define serialized static resize bilinear 2d node into the subgraph, using the
1089 remapped ids to map the serialized ids, to the new ids generated when defining
1090 the tensor value
1091 */
defineStaticResizeBilinear2DNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1092 Error defineStaticResizeBilinear2DNode(
1093 xnn_subgraph_t subgraph_ptr,
1094 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1095 const NodePtr node,
1096 const fb_xnnpack::XNNGraph* graph) noexcept {
1097 MAYBE_UNUSED(graph);
1098
1099 const fb_xnnpack::XNNStaticResizeBilinear2D* graph_node =
1100 node->xnode_union_as_XNNStaticResizeBilinear2D();
1101
1102 xnn_status status = xnn_define_static_resize_bilinear_2d(
1103 subgraph_ptr,
1104 graph_node->new_height(),
1105 graph_node->new_width(),
1106 remapped_ids.at(graph_node->input_id()),
1107 remapped_ids.at(graph_node->output_id()),
1108 graph_node->flags());
1109 ET_CHECK_OR_RETURN_ERROR(
1110 status == xnn_status_success,
1111 Internal,
1112 "Failed to create StaticResizeBilinear2DNode node %i with code: %s",
1113 node->debug_handle(),
1114 xnn_status_to_string(status));
1115
1116 return Error::Ok;
1117 }
1118
1119 /*
1120 Define serialized static constant pad node into the subgraph, using the
1121 remapped ids to map the serialized ids, to the new ids generated when defining
1122 the tensor value
1123 */
defineStaticConstantPadNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1124 Error defineStaticConstantPadNode(
1125 xnn_subgraph_t subgraph_ptr,
1126 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1127 const NodePtr node,
1128 const fb_xnnpack::XNNGraph* graph) noexcept {
1129 MAYBE_UNUSED(graph);
1130
1131 const fb_xnnpack::XNNStaticConstantPad* graph_node =
1132 node->xnode_union_as_XNNStaticConstantPad();
1133
1134 std::vector<size_t> pre_paddings_dims =
1135 flatbufferDimsToVector(graph_node->pre_paddings());
1136 std::vector<size_t> post_paddings_dims =
1137 flatbufferDimsToVector(graph_node->post_paddings());
1138
1139 xnn_status status = xnn_define_static_constant_pad(
1140 subgraph_ptr,
1141 pre_paddings_dims.data(),
1142 post_paddings_dims.data(),
1143 graph_node->padding_value(),
1144 remapped_ids.at(graph_node->input_id()),
1145 remapped_ids.at(graph_node->output_id()),
1146 graph_node->flags());
1147 ET_CHECK_OR_RETURN_ERROR(
1148 status == xnn_status_success,
1149 Internal,
1150 "Failed to create StaticConstantPad node %i with code: %s",
1151 node->debug_handle(),
1152 xnn_status_to_string(status));
1153
1154 return Error::Ok;
1155 }
1156
1157 /*
1158 Define serialized depthwise conv2d node into the subgraph, using the remapped
1159 ids to map the serialized ids, to the new ids generated when defining the
1160 tensor value
1161 */
Error defineDepthwiseConv2dNode(
    xnn_subgraph_t subgraph_ptr,
    const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
    const NodePtr node,
    const fb_xnnpack::XNNGraph* graph) noexcept {
  MAYBE_UNUSED(graph);

  auto graph_node = node->xnode_union_as_XNNDepthwiseConv2d();
  // Output clamp range from any fused activation recorded on this node.
  std::pair<float, float> min_max = getOutputMinMax(node);
  // NOTE: XNNPACK's depthwise API takes (depth_multiplier, input_channels)
  // instead of the (groups, group_input/output_channels) triple used by
  // xnn_define_convolution_2d, so the serialized fields are remapped below.
  xnn_status status = xnn_define_depthwise_convolution_2d(
      subgraph_ptr,
      graph_node->padding_top(),
      graph_node->padding_right(),
      graph_node->padding_bottom(),
      graph_node->padding_left(),
      graph_node->kernel_height(),
      graph_node->kernel_width(),
      graph_node->subsampling_height(),
      graph_node->subsampling_width(),
      graph_node->dilation_height(),
      graph_node->dilation_width(),
      // assumes group_input_channels evenly divides group_output_channels
      // (and is non-zero) for a valid serialized depthwise conv — the ratio
      // is the per-channel multiplier. TODO(review): confirm the exporter
      // guarantees this.
      graph_node->group_output_channels() /
          graph_node->group_input_channels(), // depth_multiplier
      graph_node->groups(), // input_channels = groups for depthwise conv
      min_max.first,
      min_max.second,
      remapped_ids.at(graph_node->input1_id()),
      remapped_ids.at(graph_node->filter_id()),
      remapped_ids.at(graph_node->bias_id()),
      remapped_ids.at(graph_node->output_id()),
      graph_node->flags());

  ET_CHECK_OR_RETURN_ERROR(
      status == xnn_status_success,
      Internal,
      "Failed to create depthwise convolution node %i with code: %s",
      node->debug_handle(),
      xnn_status_to_string(status));

  return Error::Ok;
}
1203
defineStaticReshapeNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1204 Error defineStaticReshapeNode(
1205 xnn_subgraph_t subgraph_ptr,
1206 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1207 const NodePtr node,
1208 const fb_xnnpack::XNNGraph* graph) noexcept {
1209 MAYBE_UNUSED(graph);
1210
1211 auto graph_node = node->xnode_union_as_XNNStaticReshape();
1212
1213 // Get tensor dims, we need to convert the uint32_t* to size_t*
1214 std::vector<size_t> dims_data =
1215 flatbufferDimsToVector(graph_node->new_shape());
1216 xnn_status status = xnn_define_static_reshape(
1217 subgraph_ptr,
1218 graph_node->num_dims(),
1219 dims_data.data(),
1220 remapped_ids.at(graph_node->input_id()),
1221 remapped_ids.at(graph_node->output_id()),
1222 graph_node->flags());
1223 ET_CHECK_OR_RETURN_ERROR(
1224 status == xnn_status_success,
1225 Internal,
1226 "Failed to create squeeze node %i with code: %s",
1227 node->debug_handle(),
1228 xnn_status_to_string(status));
1229
1230 return Error::Ok;
1231 }
1232
/*
Define serialized argmax pooling 2d node into the subgraph, using the
remapped ids to map the serialized ids, to the new ids generated when
defining the tensor value
*/
defineArgMaxPooling2dNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1238 Error defineArgMaxPooling2dNode(
1239 xnn_subgraph_t subgraph_ptr,
1240 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1241 const NodePtr node,
1242 const fb_xnnpack::XNNGraph* graph) noexcept {
1243 MAYBE_UNUSED(graph);
1244
1245 auto graph_node = node->xnode_union_as_XNNArgMaxPooling2d();
1246
1247 xnn_status status = xnn_define_argmax_pooling_2d(
1248 subgraph_ptr,
1249 graph_node->padding_top(),
1250 graph_node->padding_right(),
1251 graph_node->padding_bottom(),
1252 graph_node->padding_left(),
1253 graph_node->pooling_height(),
1254 graph_node->pooling_width(),
1255 remapped_ids.at(graph_node->input_id()),
1256 remapped_ids.at(graph_node->output_value_id()),
1257 remapped_ids.at(graph_node->output_index_id()),
1258 graph_node->flags());
1259
1260 ET_CHECK_OR_RETURN_ERROR(
1261 status == xnn_status_success,
1262 Internal,
1263 "Failed to create argmaxpool2d node %i with code: %s",
1264 node->debug_handle(),
1265 xnn_status_to_string(status));
1266
1267 return Error::Ok;
1268 }
1269
1270 /*
1271 Define serialized square root node into the subgraph, using the remapped ids
1272 to map the serialized ids, to the new ids generated when defining the
1273 tensor value
1274 */
defineSquareRootNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1275 Error defineSquareRootNode(
1276 xnn_subgraph_t subgraph_ptr,
1277 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1278 const NodePtr node,
1279 const fb_xnnpack::XNNGraph* graph) noexcept {
1280 MAYBE_UNUSED(graph);
1281
1282 auto graph_node = node->xnode_union_as_XNNSquareRoot();
1283
1284 xnn_status status = xnn_define_square_root(
1285 subgraph_ptr,
1286 remapped_ids.at(graph_node->input_id()),
1287 remapped_ids.at(graph_node->output_id()),
1288 graph_node->flags());
1289
1290 ET_CHECK_OR_RETURN_ERROR(
1291 status == xnn_status_success,
1292 Internal,
1293 "Failed to create square root node %i with code: %s",
1294 node->debug_handle(),
1295 xnn_status_to_string(status));
1296
1297 return Error::Ok;
1298 }
1299
1300 /*
1301 Define serialized ceiling node into the subgraph, using the remapped ids
1302 to map the serialized ids, to the new ids generated when defining the
1303 tensor value
1304 */
defineCeilingNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1305 Error defineCeilingNode(
1306 xnn_subgraph_t subgraph_ptr,
1307 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1308 const NodePtr node,
1309 const fb_xnnpack::XNNGraph* graph) noexcept {
1310 MAYBE_UNUSED(graph);
1311
1312 auto graph_node = node->xnode_union_as_XNNCeiling();
1313
1314 xnn_status status = xnn_define_ceiling(
1315 subgraph_ptr,
1316 remapped_ids.at(graph_node->input_id()),
1317 remapped_ids.at(graph_node->output_id()),
1318 graph_node->flags());
1319
1320 ET_CHECK_OR_RETURN_ERROR(
1321 status == xnn_status_success,
1322 Internal,
1323 "Failed to create ceiling node %i with code: %s",
1324 node->debug_handle(),
1325 xnn_status_to_string(status));
1326
1327 return Error::Ok;
1328 }
1329
1330 /*
1331 Define serialized hardswish node into the subgraph, using the remapped ids
1332 to map the serialized ids, to the new ids generated when defining the
1333 tensor value
1334 */
defineHardswishNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1335 Error defineHardswishNode(
1336 xnn_subgraph_t subgraph_ptr,
1337 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1338 const NodePtr node,
1339 const fb_xnnpack::XNNGraph* graph) noexcept {
1340 MAYBE_UNUSED(graph);
1341
1342 auto graph_node = node->xnode_union_as_XNNHardswish();
1343
1344 xnn_status status = xnn_define_hardswish(
1345 subgraph_ptr,
1346 remapped_ids.at(graph_node->input_id()),
1347 remapped_ids.at(graph_node->output_id()),
1348 graph_node->flags());
1349
1350 ET_CHECK_OR_RETURN_ERROR(
1351 status == xnn_status_success,
1352 Internal,
1353 "Failed to create hardswish node %i with code: %s",
1354 node->debug_handle(),
1355 xnn_status_to_string(status));
1356
1357 return Error::Ok;
1358 }
1359
1360 /*
1361 Define serialized leaky relu node into the subgraph, using the remapped ids
1362 to map the serialized ids, to the new ids generated when defining the
1363 tensor value
1364 */
defineLeakyReLUNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1365 Error defineLeakyReLUNode(
1366 xnn_subgraph_t subgraph_ptr,
1367 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1368 const NodePtr node,
1369 const fb_xnnpack::XNNGraph* graph) noexcept {
1370 MAYBE_UNUSED(graph);
1371
1372 auto graph_node = node->xnode_union_as_XNNLeakyReLU();
1373
1374 xnn_status status = xnn_define_leaky_relu(
1375 subgraph_ptr,
1376 graph_node->negative_slope(),
1377 remapped_ids.at(graph_node->input_id()),
1378 remapped_ids.at(graph_node->output_id()),
1379 graph_node->flags());
1380
1381 ET_CHECK_OR_RETURN_ERROR(
1382 status == xnn_status_success,
1383 Internal,
1384 "Failed to create leaky relu node %i with code: %s",
1385 node->debug_handle(),
1386 xnn_status_to_string(status));
1387
1388 return Error::Ok;
1389 }
1390
1391 /*
1392 Define serialized maximum node into the subgraph, using the remapped ids
1393 to map the serialized ids, to the new ids generated when defining the
1394 tensor value
1395 */
defineMaximumNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1396 Error defineMaximumNode(
1397 xnn_subgraph_t subgraph_ptr,
1398 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1399 const NodePtr node,
1400 const fb_xnnpack::XNNGraph* graph) noexcept {
1401 MAYBE_UNUSED(graph);
1402
1403 auto graph_node = node->xnode_union_as_XNNMaximum();
1404
1405 xnn_status status = xnn_define_maximum2(
1406 subgraph_ptr,
1407 remapped_ids.at(graph_node->input1_id()),
1408 remapped_ids.at(graph_node->input2_id()),
1409 remapped_ids.at(graph_node->output_id()),
1410 graph_node->flags());
1411
1412 ET_CHECK_OR_RETURN_ERROR(
1413 status == xnn_status_success,
1414 Internal,
1415 "Failed to create maximum node %i with code: %s",
1416 node->debug_handle(),
1417 xnn_status_to_string(status));
1418
1419 return Error::Ok;
1420 }
1421
1422 /*
1423 Define Negate node into subgraph, using the remapped ids to map the
1424 serialized ids, to the new ids generated when defining the tensor value
1425 */
defineNegateNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1426 Error defineNegateNode(
1427 xnn_subgraph_t subgraph_ptr,
1428 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1429 const NodePtr node,
1430 const fb_xnnpack::XNNGraph* graph) noexcept {
1431 MAYBE_UNUSED(graph);
1432
1433 auto graph_node = node->xnode_union_as_XNNNegate();
1434
1435 xnn_status status = xnn_define_negate(
1436 subgraph_ptr,
1437 remapped_ids.at(graph_node->input_id()),
1438 remapped_ids.at(graph_node->output_id()),
1439 graph_node->flags());
1440
1441 ET_CHECK_OR_RETURN_ERROR(
1442 status == xnn_status_success,
1443 Internal,
1444 "Failed to create negate node %i with code: %s",
1445 node->debug_handle(),
1446 xnn_status_to_string(status));
1447
1448 return Error::Ok;
1449 }
1450
1451 /*
1452 Defines square node into subgraph using the remapped ids to map the
1453 serialized ids to the new ids generated when defining the tensor value
1454 */
defineSquareNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1455 Error defineSquareNode(
1456 xnn_subgraph_t subgraph_ptr,
1457 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1458 const NodePtr node,
1459 const fb_xnnpack::XNNGraph* graph) noexcept {
1460 MAYBE_UNUSED(graph);
1461
1462 auto graph_node = node->xnode_union_as_XNNSquare();
1463
1464 xnn_status status = xnn_define_square(
1465 subgraph_ptr,
1466 remapped_ids.at(graph_node->input_id()),
1467 remapped_ids.at(graph_node->output_id()),
1468 graph_node->flags());
1469
1470 ET_CHECK_OR_RETURN_ERROR(
1471 status == xnn_status_success,
1472 Internal,
1473 "Failed to create square node %i with code: %s",
1474 node->debug_handle(),
1475 xnn_status_to_string(status));
1476
1477 return Error::Ok;
1478 }
1479
/*
Defines ELU node into subgraph using the remapped ids to map the
serialized ids to the new ids generated when defining the tensor value
*/
defineELUNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1484 Error defineELUNode(
1485 xnn_subgraph_t subgraph_ptr,
1486 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1487 const NodePtr node,
1488 const fb_xnnpack::XNNGraph* graph) noexcept {
1489 MAYBE_UNUSED(graph);
1490
1491 auto graph_node = node->xnode_union_as_XNNELU();
1492
1493 xnn_status status = xnn_define_elu(
1494 subgraph_ptr,
1495 graph_node->alpha(),
1496 remapped_ids.at(graph_node->input_id()),
1497 remapped_ids.at(graph_node->output_id()),
1498 graph_node->flags());
1499
1500 ET_CHECK_OR_RETURN_ERROR(
1501 status == xnn_status_success,
1502 Internal,
1503 "Failed to create ELU node %i with code: %s",
1504 node->debug_handle(),
1505 xnn_status_to_string(status));
1506
1507 return Error::Ok;
1508 }
1509
1510 /*
1511 Defines absolute value node into subgraph using the remapped ids to map the
1512 serialized ids to the new ids generated when defining the tensor value
1513 */
defineAbsNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1514 Error defineAbsNode(
1515 xnn_subgraph_t subgraph_ptr,
1516 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1517 const NodePtr node,
1518 const fb_xnnpack::XNNGraph* graph) noexcept {
1519 MAYBE_UNUSED(graph);
1520
1521 auto graph_node = node->xnode_union_as_XNNAbs();
1522
1523 xnn_status status = xnn_define_abs(
1524 subgraph_ptr,
1525 remapped_ids.at(graph_node->input_id()),
1526 remapped_ids.at(graph_node->output_id()),
1527 graph_node->flags());
1528
1529 ET_CHECK_OR_RETURN_ERROR(
1530 status == xnn_status_success,
1531 Internal,
1532 "Failed to create abs node %i with code: %s",
1533 node->debug_handle(),
1534 xnn_status_to_string(status));
1535
1536 return Error::Ok;
1537 }
1538
1539 /*
1540 Defines serialized prelu node into the subgraph,
1541 using the remapped ids to map the serialized ids,
1542 to the new ids generated when defining the tensor value
1543 */
definePReLUNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1544 Error definePReLUNode(
1545 xnn_subgraph_t subgraph_ptr,
1546 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1547 const NodePtr node,
1548 const fb_xnnpack::XNNGraph* graph) noexcept {
1549 MAYBE_UNUSED(graph);
1550
1551 auto graph_node = node->xnode_union_as_XNNPReLU();
1552
1553 xnn_status status = xnn_define_prelu(
1554 subgraph_ptr,
1555 remapped_ids.at(graph_node->input1_id()),
1556 remapped_ids.at(graph_node->input2_id()),
1557 remapped_ids.at(graph_node->output_id()),
1558 graph_node->flags());
1559
1560 ET_CHECK_OR_RETURN_ERROR(
1561 status == xnn_status_success,
1562 Internal,
1563 "Failed to create prelu node %i with code: %s",
1564 node->debug_handle(),
1565 xnn_status_to_string(status));
1566
1567 return Error::Ok;
1568 }
1569
1570 /*
1571 Defines serialized concatenate2 node into the subgraph,
1572 using the remapped ids to map the serialized ids,
1573 to the new ids generated when defining the tensor value
1574 */
defineConcatenate2Node(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1575 Error defineConcatenate2Node(
1576 xnn_subgraph_t subgraph_ptr,
1577 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1578 const NodePtr node,
1579 const fb_xnnpack::XNNGraph* graph) noexcept {
1580 MAYBE_UNUSED(graph);
1581
1582 auto graph_node = node->xnode_union_as_XNNConcatenate2();
1583
1584 xnn_status status = xnn_define_concatenate2(
1585 subgraph_ptr,
1586 graph_node->axis(),
1587 remapped_ids.at(graph_node->input1_id()),
1588 remapped_ids.at(graph_node->input2_id()),
1589 remapped_ids.at(graph_node->output_id()),
1590 graph_node->flags());
1591
1592 ET_CHECK_OR_RETURN_ERROR(
1593 status == xnn_status_success,
1594 Internal,
1595 "Failed to create cat2 node %i with code: %s",
1596 node->debug_handle(),
1597 xnn_status_to_string(status));
1598
1599 return Error::Ok;
1600 }
1601
/*
Defines serialized concatenate3 node into the subgraph,
using the remapped ids to map the serialized ids,
to the new ids generated when defining the tensor value
*/
defineConcatenate3Node(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1607 Error defineConcatenate3Node(
1608 xnn_subgraph_t subgraph_ptr,
1609 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1610 const NodePtr node,
1611 const fb_xnnpack::XNNGraph* graph) noexcept {
1612 MAYBE_UNUSED(graph);
1613
1614 auto graph_node = node->xnode_union_as_XNNConcatenate3();
1615
1616 xnn_status status = xnn_define_concatenate3(
1617 subgraph_ptr,
1618 graph_node->axis(),
1619 remapped_ids.at(graph_node->input1_id()),
1620 remapped_ids.at(graph_node->input2_id()),
1621 remapped_ids.at(graph_node->input3_id()),
1622 remapped_ids.at(graph_node->output_id()),
1623 graph_node->flags());
1624
1625 ET_CHECK_OR_RETURN_ERROR(
1626 status == xnn_status_success,
1627 Internal,
1628 "Failed to create cat3 node %i with code: %s",
1629 node->debug_handle(),
1630 xnn_status_to_string(status));
1631
1632 return Error::Ok;
1633 }
1634
/*
Defines serialized concatenate4 node into the subgraph,
using the remapped ids to map the serialized ids,
to the new ids generated when defining the tensor value
*/
defineConcatenate4Node(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1640 Error defineConcatenate4Node(
1641 xnn_subgraph_t subgraph_ptr,
1642 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1643 const NodePtr node,
1644 const fb_xnnpack::XNNGraph* graph) noexcept {
1645 MAYBE_UNUSED(graph);
1646
1647 auto graph_node = node->xnode_union_as_XNNConcatenate4();
1648
1649 xnn_status status = xnn_define_concatenate4(
1650 subgraph_ptr,
1651 graph_node->axis(),
1652 remapped_ids.at(graph_node->input1_id()),
1653 remapped_ids.at(graph_node->input2_id()),
1654 remapped_ids.at(graph_node->input3_id()),
1655 remapped_ids.at(graph_node->input4_id()),
1656 remapped_ids.at(graph_node->output_id()),
1657 graph_node->flags());
1658
1659 ET_CHECK_OR_RETURN_ERROR(
1660 status == xnn_status_success,
1661 Internal,
1662 "Failed to create cat4 node %i with code: %s",
1663 node->debug_handle(),
1664 xnn_status_to_string(status));
1665
1666 return Error::Ok;
1667 }
1668
1669 /*
1670 Defines serialized static_slice node into the subgraph,
1671 using the remapped ids to map the serialized ids,
1672 to the new ids generated when defining the tensor value
1673 */
defineStaticSliceNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1674 Error defineStaticSliceNode(
1675 xnn_subgraph_t subgraph_ptr,
1676 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1677 const NodePtr node,
1678 const fb_xnnpack::XNNGraph* graph) noexcept {
1679 MAYBE_UNUSED(graph);
1680
1681 auto graph_node = node->xnode_union_as_XNNStaticSlice();
1682
1683 std::vector<size_t> offsets = flatbufferDimsToVector(graph_node->offsets());
1684 std::vector<size_t> sizes = flatbufferDimsToVector(graph_node->sizes());
1685
1686 xnn_status status = xnn_define_static_slice(
1687 subgraph_ptr,
1688 graph_node->num_dims(),
1689 offsets.data(),
1690 sizes.data(),
1691 remapped_ids.at(graph_node->input_id()),
1692 remapped_ids.at(graph_node->output_id()),
1693 graph_node->flags());
1694
1695 ET_CHECK_OR_RETURN_ERROR(
1696 status == xnn_status_success,
1697 Internal,
1698 "Failed to create static slice node %i with code: %s",
1699 node->debug_handle(),
1700 xnn_status_to_string(status));
1701
1702 return Error::Ok;
1703 }
1704
1705 /*
1706 Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
1707 using the remapped ids to map the serialized ids,
1708 to the new ids generated when defining the tensor value
1709 */
defineScaledDotProductAttentionNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1710 Error defineScaledDotProductAttentionNode(
1711 xnn_subgraph_t subgraph_ptr,
1712 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1713 const NodePtr node,
1714 const fb_xnnpack::XNNGraph* graph) noexcept {
1715 MAYBE_UNUSED(graph);
1716
1717 auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();
1718
1719 xnn_status status = xnn_define_scaled_dot_product_attention(
1720 subgraph_ptr,
1721 xnn_attention_logits_cap_type_none, // cap_type
1722 nullptr, // cap_value - not used
1723 remapped_ids.at(graph_node->query_id()),
1724 remapped_ids.at(graph_node->key_id()),
1725 remapped_ids.at(graph_node->value_id()),
1726 remapped_ids.at(graph_node->scale_id()),
1727 remapped_ids.at(graph_node->mask_id()),
1728 remapped_ids.at(graph_node->output_id()),
1729 graph_node->flags());
1730
1731 ET_CHECK_OR_RETURN_ERROR(
1732 status == xnn_status_success,
1733 Internal,
1734 "Failed to create SDPA node %i with code: %s",
1735 node->debug_handle(),
1736 xnn_status_to_string(status));
1737
1738 return Error::Ok;
1739 }
1740
1741 /*
1742 Defines batch matrix multiply node into the subgraph,
1743 using the remapped ids to map the serialized ids,
1744 to the new ids generated when defining the tensor value
1745 */
defineBatchMatrixMultiplyNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1746 Error defineBatchMatrixMultiplyNode(
1747 xnn_subgraph_t subgraph_ptr,
1748 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1749 const NodePtr node,
1750 const fb_xnnpack::XNNGraph* graph) noexcept {
1751 MAYBE_UNUSED(graph);
1752
1753 auto graph_node = node->xnode_union_as_XNNBatchMatrixMultiply();
1754
1755 xnn_status status = xnn_define_batch_matrix_multiply(
1756 subgraph_ptr,
1757 remapped_ids.at(graph_node->input1_id()),
1758 remapped_ids.at(graph_node->input2_id()),
1759 remapped_ids.at(graph_node->output_id()),
1760 graph_node->flags());
1761
1762 ET_CHECK_OR_RETURN_ERROR(
1763 status == xnn_status_success,
1764 Internal,
1765 "Failed to create BMM node %i with code: %s",
1766 node->debug_handle(),
1767 xnn_status_to_string(status));
1768
1769 return Error::Ok;
1770 }
1771
1772 /*
1773 Returns not Implemented Error code. This function is meant to be
1774 called when the compiler encountes a XNodeType from the flatbuffer
1775 that has not yet been implemented
1776 */
defineNotImplementedNode(xnn_subgraph_t subgraph_ptr,const std::unordered_map<uint32_t,uint32_t> & remapped_ids,const NodePtr node,const fb_xnnpack::XNNGraph * graph)1777 Error defineNotImplementedNode(
1778 xnn_subgraph_t subgraph_ptr,
1779 const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1780 const NodePtr node,
1781 const fb_xnnpack::XNNGraph* graph) noexcept {
1782 MAYBE_UNUSED(graph);
1783
1784 ET_CHECK_OR_RETURN_ERROR(
1785 false,
1786 NotImplemented,
1787 "Unhandled node type: %s",
1788 fb_xnnpack::EnumNameXNodeUnion(node->xnode_union_type()));
1789 }
1790
1791 /*
1792 Returns the pointer to the defineNode function that handles the given
1793 XNode type
1794 */
1795 #define _DEFINE(name) \
1796 case fb_xnnpack::XNodeUnion::XNN##name: \
1797 return &define##name##Node;
1798
getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType)1799 DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
1800 switch (nodeType) {
1801 _DEFINE(Add)
1802 _DEFINE(FullyConnected)
1803 _DEFINE(Softmax)
1804 _DEFINE(Sigmoid)
1805 _DEFINE(StaticTranspose)
1806 _DEFINE(Clamp)
1807 _DEFINE(Conv2d)
1808 _DEFINE(Div)
1809 _DEFINE(StaticResizeBilinear2D)
1810 _DEFINE(StaticConstantPad)
1811 _DEFINE(AvgPooling2d)
1812 _DEFINE(Minimum)
1813 _DEFINE(DepthwiseConv2d)
1814 _DEFINE(MaxPooling2d)
1815 _DEFINE(Multiply)
1816 _DEFINE(Subtract)
1817 _DEFINE(Floor)
1818 _DEFINE(Convert)
1819 _DEFINE(GlobalAvgPooling2d)
1820 _DEFINE(StaticReshape)
1821 _DEFINE(ArgMaxPooling2d)
1822 _DEFINE(SquareRoot)
1823 _DEFINE(Ceiling)
1824 _DEFINE(Hardswish)
1825 _DEFINE(LeakyReLU)
1826 _DEFINE(Maximum)
1827 _DEFINE(Negate)
1828 _DEFINE(Square)
1829 _DEFINE(ELU)
1830 _DEFINE(Abs)
1831 _DEFINE(PReLU)
1832 _DEFINE(Concatenate2)
1833 _DEFINE(Concatenate3)
1834 _DEFINE(Concatenate4)
1835 _DEFINE(StaticSlice)
1836 _DEFINE(ScaledDotProductAttention)
1837 _DEFINE(BatchMatrixMultiply)
1838 case fb_xnnpack::XNodeUnion::NONE:
1839 default: // Adding here as a catch all, just in case
1840 return &defineNotImplementedNode;
1841 }
1842 }
1843 #undef _DEFINE
1844
1845 /*
1846 Builds the xnnpack runtime object using the buffer pointer. The buffer pointer
1847 must be a valid pointer to the serialized xnnpack object. It also fills the
1848 XNNExecutor object with the built xnn_runtime and the input/output ids.
1849 */
compileModel(const void * buffer_pointer,size_t num_bytes,XNNExecutor * executor,MemoryAllocator * runtime_allocator,xnn_workspace_t workspace)1850 ET_NODISCARD Error XNNCompiler::compileModel(
1851 const void* buffer_pointer,
1852 size_t num_bytes,
1853 XNNExecutor* executor,
1854 MemoryAllocator* runtime_allocator,
1855 xnn_workspace_t workspace) {
1856 Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
1857 const uint8_t* flatbuffer_data = nullptr;
1858 const uint8_t* constant_data = nullptr;
1859 CompileAllocator compile_allocator;
1860
1861 // Header status can only either be Error::Ok or Error::NotFound
1862 if (header.ok()) {
1863 flatbuffer_data = reinterpret_cast<const uint8_t*>(buffer_pointer) +
1864 header->flatbuffer_offset;
1865 constant_data = reinterpret_cast<const uint8_t*>(buffer_pointer) +
1866 header->constant_data_offset;
1867 } else if (header.error() == Error::NotFound) {
1868 flatbuffer_data = reinterpret_cast<const uint8_t*>(buffer_pointer);
1869 } else {
1870 ET_LOG(Error, "XNNHeader may be corrupt");
1871 return header.error();
1872 }
1873
1874 // Temporarily support identifier XN00 and XN01
1875 bool is_supported_version =
1876 strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN00", 4) ==
1877 0 ||
1878 strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN01", 4) ==
1879 0;
1880 ET_CHECK_OR_RETURN_ERROR(
1881 is_supported_version,
1882 DelegateInvalidCompatibility,
1883 "XNNPACK Delegate Serialization Format version identifier '%.4s' != expected XN00 or XN01'",
1884 flatbuffers::GetBufferIdentifier(flatbuffer_data));
1885
1886 auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(flatbuffer_data);
1887 // initialize xnnpack
1888 xnn_status status = xnn_initialize(/*allocator =*/nullptr);
1889 ET_CHECK_OR_RETURN_ERROR(
1890 xnn_status_success == status,
1891 Internal,
1892 "XNN Initialize failed with code: %s",
1893 xnn_status_to_string(status));
1894
1895 // create xnnpack subgraph
1896 xnn_subgraph_t subgraph_ptr = nullptr;
1897 status = xnn_create_subgraph(
1898 /*external_value_ids=*/flatbuffer_graph->num_externs(),
1899 /*flags=*/0,
1900 &subgraph_ptr);
1901 ET_CHECK_OR_RETURN_ERROR(
1902 xnn_status_success == status,
1903 Internal,
1904 "XNN Subgraph creation failed with code: %s",
1905 xnn_status_to_string(status));
1906
1907 std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph(
1908 subgraph_ptr, &xnn_delete_subgraph);
1909
1910 // mapping from old ids to new created value ids
1911 // The old ids that were serialied were generated AoT, since
1912 // we are re-defining tensor values, the defined IDs could be
1913 // different from the ones generated AoT, as a result, we need
1914 // a new mapping from the old ids to the newly created ones
1915 std::unordered_map<uint32_t, uint32_t> remapped_ids;
1916 // Invalid ids do not need to be remapped
1917 remapped_ids.emplace(XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID);
1918
1919 // External Ids for inputs and outputs
1920 std::vector<uint32_t> input_ids;
1921 std::vector<uint32_t> output_ids;
1922 Error err = Error::Ok;
1923 for (auto value : *flatbuffer_graph->xvalues()) {
1924 err = defineTensor(
1925 subgraph.get(),
1926 remapped_ids,
1927 value,
1928 flatbuffer_graph,
1929 constant_data,
1930 input_ids,
1931 output_ids,
1932 compile_allocator);
1933
1934 if (err != Error::Ok) {
1935 return err;
1936 }
1937 }
1938
1939 for (auto node : *flatbuffer_graph->xnodes()) {
1940 err = getDefineNodeFunc(node->xnode_union_type())(
1941 subgraph.get(), remapped_ids, node, flatbuffer_graph);
1942 if (err != Error::Ok) {
1943 return err;
1944 }
1945 }
1946 uint32_t runtime_flags = 0;
1947
1948 #if defined(ENABLE_XNNPACK_PROFILING) || defined(ET_EVENT_TRACER_ENABLED)
1949 runtime_flags |= XNN_FLAG_BASIC_PROFILING;
1950 #endif
1951
1952 xnn_runtime_t runtime_ptr = nullptr;
1953
1954 #ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
1955 ET_CHECK_OR_RETURN_ERROR(
1956 workspace != nullptr, Internal, "Failed to initialize XNNPACK workspace");
1957 status = xnn_create_runtime_v4(
1958 subgraph.get(),
1959 /*weight_cache=*/nullptr, // TODO - support weight cache
1960 workspace,
1961 ::executorch::extension::threadpool::get_pthreadpool(),
1962 runtime_flags,
1963 &runtime_ptr);
1964 #else
1965 status = xnn_create_runtime_v3(
1966 subgraph.get(),
1967 /*weight_cache=*/nullptr, // TODO - support weight cache
1968 ::executorch::extension::threadpool::get_pthreadpool(),
1969 runtime_flags,
1970 &runtime_ptr);
1971 #endif
1972
1973 ET_CHECK_OR_RETURN_ERROR(
1974 xnn_status_success == status,
1975 Internal,
1976 "XNN Runtime creation failed with code: %s",
1977 xnn_status_to_string(status));
1978
1979 err = executor->initialize( // NOLINT: runtime_ptr is non-null
1980 runtime_ptr,
1981 std::move(input_ids),
1982 std::move(output_ids));
1983
1984 return err;
1985 };
1986
1987 } // namespace delegate
1988 } // namespace xnnpack
1989 } // namespace backends
1990 } // namespace executorch
1991