xref: /aosp_15_r20/external/executorch/backends/xnnpack/runtime/XNNCompiler.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/xnnpack/runtime/XNNCompiler.h>
10 #include <executorch/backends/xnnpack/runtime/XNNHeader.h>
11 #include <executorch/backends/xnnpack/serialization/schema_generated.h>
12 #include <executorch/extension/threadpool/threadpool.h>
13 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
14 #include <unordered_map>
15 
16 #pragma clang diagnostic ignored "-Wmissing-prototypes"
17 #pragma clang diagnostic ignored "-Wglobal-constructors"
18 
19 namespace executorch {
20 namespace backends {
21 namespace xnnpack {
22 namespace delegate {
23 
24 using executorch::runtime::Error;
25 using executorch::runtime::MemoryAllocator;
26 using executorch::runtime::Result;
27 
28 /*
29  * Provide compile-time allocation.
30  */
31 class CompileAllocator {
32  public:
33   /*
34    * Allocate memory which will be automatically freed at the end
35    * of the compilation process.
36    */
37   void* allocateTemporary(size_t size) {
38     auto mem = new uint8_t[size];
39     temporaries_.emplace_back(mem);
40     return mem;
41   }
42 
43  private:
44   std::vector<std::unique_ptr<uint8_t[]>> temporaries_;
45 };
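// Illustrative usage sketch (not part of the original file): temporaries
// handed out by the allocator stay alive until the allocator itself is
// destroyed at the end of compilation, e.g.
//
//   CompileAllocator allocator;
//   auto* scratch = static_cast<uint16_t*>(
//       allocator.allocateTemporary(numel * sizeof(uint16_t)));
//   // `scratch` remains valid for the rest of the compile pass.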
46 
47 // Flatbuffer types
48 using ValuePtr = const fb_xnnpack::XValue*;
49 using NodePtr = const fb_xnnpack::XNode*;
50 using GraphPtr = const fb_xnnpack::XNNGraph*;
51 using DataType = fb_xnnpack::XNNDatatype;
52 
53 // Type for define node function. This is the function signature
54 // for any function that takes in a flatbuffer node and defines it
55 // into our xnn_subgraph
56 using DefineNodeFunc = Error (*)(
57     xnn_subgraph_t,
58     const std::unordered_map<uint32_t, uint32_t>&,
59     NodePtr,
60     const fb_xnnpack::XNNGraph*) noexcept;
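// Every defineXXXNode function below (e.g. defineAddNode) conforms to this
// signature so node definitions can be looked up and dispatched uniformly.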
61 
62 /*
63 Convert a tensor from fp32 to bf16.
64 */
65 void convertF32TensorToBF16(
66     const float* f32_data,
67     uint16_t* bf16_data_out,
68     size_t numel) {
69   for (auto i = 0u; i < numel; i++) {
70     // Adjust the f32 value such that it rounds properly after truncation.
71     // Constant factor scales 1+2^-8 to 1+2^-7.
72     float f32_adjusted = f32_data[i] * 1.00389105f;
73     uint32_t f32_bits;
74     memcpy(&f32_bits, &f32_adjusted, sizeof(float));
75     bf16_data_out[i] = static_cast<uint16_t>(f32_bits >> 16);
76   }
77 }
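// Worked example (illustrative): 1.0f has bits 0x3F800000; after multiplying
// by 1.00389105f the upper 16 bits are still 0x3F80, so the truncated bf16
// value decodes back to 1.0, while values near a bf16 rounding boundary are
// nudged upward so that truncation approximates round-to-nearest.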
78 
79 /*
80 Gets the output min and output max for a given node operator
81 */
82 std::pair<float, float> getOutputMinMax(const NodePtr node) noexcept {
83   float output_min = -std::numeric_limits<float>::infinity();
84   float output_max = std::numeric_limits<float>::infinity();
85   auto output_min_max = node->output_min_max();
86   if (output_min_max != nullptr) {
87     output_min = output_min_max->output_min();
88     output_max = output_min_max->output_max();
89   }
90 
91   return {output_min, output_max};
92 }
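// Note: when output_min_max is absent, the (-inf, +inf) defaults leave the
// output unclamped; when present, the bounds typically encode a fused
// activation clamp (e.g. relu/hardtanh), depending on how the graph was
// exported.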
93 
94 /*
95 Converts flatbuffer xnn data type to xnnpack data type
96 */
97 xnn_datatype getDataType(const DataType& data_type) {
98   switch (data_type) {
99     case DataType::xnn_datatype_fp32:
100       return xnn_datatype::xnn_datatype_fp32;
101     case DataType::xnn_datatype_fp16:
102       return xnn_datatype::xnn_datatype_fp16;
103     case DataType::xnn_datatype_qint8:
104       return xnn_datatype::xnn_datatype_qint8;
105     case DataType::xnn_datatype_quint8:
106       return xnn_datatype::xnn_datatype_quint8;
107     case DataType::xnn_datatype_qint32:
108       return xnn_datatype::xnn_datatype_qint32;
109     case DataType::xnn_datatype_qcint8:
110       return xnn_datatype::xnn_datatype_qcint8;
111     case DataType::xnn_datatype_qcint32:
112       return xnn_datatype::xnn_datatype_qcint32;
113     case DataType::xnn_datatype_qcint4:
114       return xnn_datatype::xnn_datatype_qcint4;
115     case DataType::xnn_datatype_qdint8:
116       return xnn_datatype::xnn_datatype_qdint8;
117     case DataType::xnn_datatype_qbint4:
118       return xnn_datatype::xnn_datatype_qbint4;
119     default:
120       return xnn_datatype::xnn_datatype_invalid;
121   }
122 }
123 
124 bool isQuantizedDataType(const xnn_datatype data_type) {
125   switch (data_type) {
126     case xnn_datatype::xnn_datatype_qint8:
127     case xnn_datatype::xnn_datatype_quint8:
128     case xnn_datatype::xnn_datatype_qint32:
129     case xnn_datatype::xnn_datatype_qcint8:
130     case xnn_datatype::xnn_datatype_qcint32:
131     case xnn_datatype::xnn_datatype_qcint4:
132     case xnn_datatype::xnn_datatype_qdint8:
133       return true;
134     default:
135       return false;
136   }
137 }
138 
139 /**
140 Converts dims from uint32 to size_t. Takes in a flatbuffer vector
141 of uint32_t and returns a std::vector of size_t. XNNPACK takes in
142 dims of size_t* but tensor shape is serialized in flatbuffer as
143 uint32_t. As a result, we need to static cast the shapes to size_t.
144 */
145 template <typename T = size_t>
146 std::vector<T> flatbufferDimsToVector(
147     const flatbuffers::Vector<uint32_t>* fb_dims) {
148   std::vector<T> dims_data;
149   dims_data.reserve(fb_dims->size());
150   for (auto fb_dim : *fb_dims) {
151     dims_data.push_back(static_cast<T>(fb_dim));
152   }
153   return dims_data;
154 }
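// Example (illustrative): a serialized shape {1, 3, 224, 224} stored as
// uint32_t comes back as std::vector<size_t>{1, 3, 224, 224}, which can be
// passed via .data() as the size_t* dims argument of the xnn_define_* calls.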
155 
156 /**
157 Gets the constant data pointer associated with the given tensor value.
158 The constant data pointer can come either from within the flatbuffer
159 payload (deprecated) or via an offset into constant_data_ptr. If there is
160 no constant data associated with the tensor value, returns nullptr.
161 */
162 const uint8_t* getConstantDataPtr(
163     const fb_xnnpack::XNNTensorValue* tensor_value,
164     GraphPtr flatbuffer_graph,
165     const uint8_t* constant_data_ptr) {
166   auto buffer_idx = tensor_value->constant_buffer_idx();
167   if (buffer_idx) {
168     if (!constant_data_ptr) {
169       // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
170       // window
171       const auto& constant_buffer = *flatbuffer_graph->constant_buffer();
172       return constant_buffer[buffer_idx]->storage()->data();
173     } else {
174       const auto& constant_data_offsets = *flatbuffer_graph->constant_data();
175       uint64_t constant_data_offset =
176           constant_data_offsets[buffer_idx]->offset();
177       return constant_data_ptr + constant_data_offset;
178     }
179   }
180 
181   return nullptr;
182 }
183 
184 /**
185 Define serialized tensor value into
186 the subgraph, while also keeping track of the remapping from
187 the serialized id to the newly generated id.
188 */
189 Error defineTensor(
190     xnn_subgraph_t subgraph_ptr,
191     std::unordered_map<uint32_t, uint32_t>& remapped_ids,
192     ValuePtr value,
193     GraphPtr flatbuffer_graph,
194     const uint8_t* constant_data_ptr,
195     std::vector<uint32_t>& input_ids,
196     std::vector<uint32_t>& output_ids,
197     CompileAllocator& allocator) {
198   const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
199   const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;
200 
201   switch (value->xvalue_union_type()) {
202     case fb_xnnpack::XValueUnion::XNNTensorValue: {
203       tensor_value = value->xvalue_union_as_XNNTensorValue();
204       break;
205     }
206     case fb_xnnpack::XValueUnion::XNNQuantizedTensorValue: {
207       qtensor_value = value->xvalue_union_as_XNNQuantizedTensorValue();
208       tensor_value = qtensor_value->tensor_value();
209       break;
210     }
211     default: {
212       ET_CHECK_OR_RETURN_ERROR(
213           false,
214           NotImplemented,
215           "Unhandled value type: %s",
216           fb_xnnpack::EnumNameXValueUnion(value->xvalue_union_type()));
217     }
218   }
219 
220   ET_CHECK_OR_RETURN_ERROR(
221       tensor_value != nullptr,
222       Internal,
223       "Deserialized Tensor is Null, this should never happen");
224 
225   // Get tensor dims, here we need to use a vector in order
226   // to properly convert the uint32_t* to size_t*
227   std::vector<size_t> dims_data = flatbufferDimsToVector(tensor_value->dims());
228 
229   // XNNPACK Id
230   uint32_t id = XNN_INVALID_VALUE_ID;
231 
232   // Get pointer to constant data from flatbuffer; if the tensor is
233   // non-constant, this is a nullptr
234   const uint8_t* buffer_ptr =
235       getConstantDataPtr(tensor_value, flatbuffer_graph, constant_data_ptr);
236 
237   xnn_status status;
238   // The type we might have to convert to
239   auto dq_datatype = getDataType(tensor_value->dq_datatype());
240 
241   if (dq_datatype != xnn_datatype::xnn_datatype_invalid) {
242     if (dq_datatype != xnn_datatype::xnn_datatype_qint8) {
243       ET_CHECK_OR_RETURN_ERROR(
244           false,
245           Internal,
246           "Only int8_t is supported for dq_datatype for now, got: %d",
247           dq_datatype);
248     } else {
249       ET_CHECK_OR_RETURN_ERROR(
250           (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT),
251           Internal,
252           "Dynamic quantization of tensor is only allowed for the external input tensor value for now! got flags: %u",
253           tensor_value->flags());
254     }
255   }
256 
257   if (qtensor_value == nullptr) {
258     // FP32 tensor
259     if (!isQuantizedDataType(dq_datatype)) {
260       // Define non-quantized tensor
261       status = xnn_define_tensor_value(
262           /*subgraph=*/subgraph_ptr,
263           /*datatype=*/getDataType(tensor_value->datatype()),
264           /*num_dims=*/tensor_value->num_dims(),
265           /*dims=*/dims_data.data(),
266           /*data=*/buffer_ptr,
267           /*external_id=*/tensor_value->external_id(),
268           /*flags=*/tensor_value->flags(),
269           /*id_out=*/&id);
270     } else if (dq_datatype != xnn_datatype::xnn_datatype_invalid) {
271       ET_CHECK_OR_RETURN_ERROR(
272           isQuantizedDataType(dq_datatype),
273           Internal,
274           "Dynamic quantization can only produce supported quantized dtypes");
275       ET_CHECK_OR_RETURN_ERROR(
276           tensor_value->external_id() != XNN_INVALID_VALUE_ID,
277           Internal,
278           "Dynamic quantization can only work with external inputs for now, got an internal ID");
279       ET_CHECK_OR_RETURN_ERROR(
280           buffer_ptr == nullptr,
281           Internal,
282           "Dynamic quantization can only work with external inputs for now, got const data");
283 
284       switch (dq_datatype) {
285         case xnn_datatype::xnn_datatype_qint8: {
286           // HACK to maintain FC/BC for ASR; this will be removed after 01/2024
287 
288           // When we encounter a dynamically quantized tensor via dq_datatype
289           // (the old flow for serializing dynamically quantized linear), we
290           // replace the definition of the single tensor with a dynamic
291           // quantization pattern. The pattern changes from:
292           //     serialized_qd_input
293           //           to
294           // (fp32_input --> convert --> qdint8_input)
295 
296           status = xnn_define_dynamically_quantized_tensor_value(
297               /*subgraph=*/subgraph_ptr,
298               /*datatype=*/xnn_datatype_qdint8,
299               /*num_dims=*/tensor_value->num_dims(),
300               /*num_nonbatch_dims=*/1, // always do per token quantization
301               /*dims=*/dims_data.data(),
302               /*external_id=*/XNN_INVALID_VALUE_ID, // always internal value id
303               /*flags=*/0, // this is neither an external input nor an output
304               /*id_out=*/&id);
305 
306           // this is the FP16 or FP32 external value that is being dynamically
307           // quantized
308           uint32_t float_id;
309           enum xnn_datatype fp_datatype = getDataType(tensor_value->datatype());
310           status = xnn_define_tensor_value(
311               /*subgraph=*/subgraph_ptr,
312               /*datatype=*/fp_datatype,
313               /*num_dims=*/tensor_value->num_dims(),
314               /*dims=*/dims_data.data(),
315               /*data=*/buffer_ptr,
316               /*external_id=*/tensor_value->external_id(),
317               /*flags=*/tensor_value->flags(),
318               /*id_out=*/&float_id);
319 
320           // Define dynamic conversion from float to qdint8
321           status = xnn_define_convert(
322               /*subgraph=*/subgraph_ptr,
323               /*input_id=*/float_id,
324               /*output_id=*/id,
325               /*flags=*/0);
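          // Note: `status` is overwritten by each of the three definitions
          // above; only the result of the final xnn_define_convert call
          // reaches the status check at the end of defineTensor.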
326           break;
327         }
328         default:
329           ET_CHECK_OR_RETURN_ERROR(
330               false,
331               NotImplemented,
332               "Unhandled Dyanmic Quantization dtype: %d",
333               dq_datatype);
334       }
335     } else {
336       ET_CHECK_OR_RETURN_ERROR(false, NotImplemented, "Unhandled fp32 tensor");
337     }
338   } else {
339     // define tensor for quantized
340     switch (qtensor_value->quant_params_type()) {
341       case fb_xnnpack::XNNQuantParams::PerTensorQuant: {
342         auto qparams = qtensor_value->quant_params_as_PerTensorQuant();
343         ET_LOG(
344             Debug,
345             "define quant tensor (per tensor): buffer_ptr: %p, scale: %f, zp: %u\n",
346             buffer_ptr,
347             qparams->scale(),
348             qparams->zero_point());
349         status = xnn_define_quantized_tensor_value(
350             /*subgraph=*/subgraph_ptr,
351             /*datatype=*/getDataType(tensor_value->datatype()),
352             /*zero_point=*/qparams->zero_point(),
353             /*scale=*/qparams->scale(),
354             /*num_dims=*/tensor_value->num_dims(),
355             /*dims=*/dims_data.data(),
356             /*data=*/buffer_ptr,
357             /*external_id=*/tensor_value->external_id(),
358             /*flags=*/tensor_value->flags(),
359             /*id_out=*/&id);
360         break;
361       }
362       case fb_xnnpack::XNNQuantParams::PerChannelQuant: {
363         auto qparams = qtensor_value->quant_params_as_PerChannelQuant();
364         enum xnn_datatype dtype = getDataType(tensor_value->datatype());
365         int32_t zero_point =
366             (dtype == xnn_datatype::xnn_datatype_qcint4 ? 8 : 0);
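        // The zero point of 8 is presumably because 4-bit channelwise weights
        // are stored as unsigned nibbles, recentring them onto the signed
        // range [-8, 7]; the other channelwise dtypes here are symmetric.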
367 
368         ET_LOG(
369             Debug,
370             "define quant tensor (per channel): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, dtype: %u, zero_point: %d\n",
371             buffer_ptr,
372             qparams->scale()->size(),
373             qparams->channel_dim(),
374             dtype,
375             zero_point);
376         status = xnn_define_channelwise_quantized_tensor_value_v2(
377             /*subgraph=*/subgraph_ptr,
378             /*datatype=*/dtype,
379             /*zero_point=*/zero_point,
380             /*scale=*/qparams->scale()->data(),
381             /*num_dims=*/tensor_value->num_dims(),
382             /*channel_dim*/ qparams->channel_dim(),
383             /*dims=*/dims_data.data(),
384             /*data=*/buffer_ptr,
385             /*external_id=*/tensor_value->external_id(),
386             /*flags=*/tensor_value->flags(),
387             /*id_out=*/&id);
388         break;
389       }
390       case fb_xnnpack::XNNQuantParams::PerChannelGroupQuant: {
391         xnn_datatype datatype = getDataType(tensor_value->datatype());
392         ET_CHECK_OR_RETURN_ERROR(
393             datatype == xnn_datatype::xnn_datatype_qbint4,
394             Internal,
395             "Unsupported datatype for per channel group quantization: %d",
396             datatype);
397         auto qparams = qtensor_value->quant_params_as_PerChannelGroupQuant();
398         size_t group_size = qparams->group_size();
399         size_t output_channels = tensor_value->dims()->Get(0);
400         size_t input_channels = tensor_value->dims()->Get(1);
401 
402         const uint16_t* scale_data = nullptr;
403         uint32_t scale_numel = 0;
404 
405         // Block scales are preferably serialized as bf16 but can also be
406         // serialized as fp32 for backwards compatibility.
407         if (qparams->scale_bf16() != nullptr) {
408           scale_data =
409               static_cast<const uint16_t*>(qparams->scale_bf16()->data());
410           scale_numel = qparams->scale_bf16()->size();
411         } else {
412           // Read fp32 scales, convert to bf16.
413           auto conv_buffer = static_cast<uint16_t*>(allocator.allocateTemporary(
414               qparams->scale()->size() * sizeof(uint16_t)));
415           scale_numel = qparams->scale()->size();
416           convertF32TensorToBF16(
417               qparams->scale()->data(), conv_buffer, scale_numel);
418           scale_data = conv_buffer;
419         }
420 
421         ET_CHECK_OR_RETURN_ERROR(
422             scale_numel == output_channels * input_channels / group_size,
423             Internal,
424             "scale size %zu != output channels %zu * group size %zu",
425             static_cast<size_t>(scale_numel),
426             output_channels,
427             group_size);
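        // Illustrative check of the expected scale count (assumed example
        // values): with output_channels = 64, input_channels = 128 and
        // group_size = 32, we expect 64 * 128 / 32 = 256 block scales.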
428         int32_t zero_point =
429             (datatype == xnn_datatype::xnn_datatype_qbint4 ? 8 : 0);
430         ET_LOG(
431             Debug,
432             "define quant tensor (per channel group): buffer_ptr: %p, scale.numel(): %u, channel_dim: %u, grpup_size: %zu, output_channels: %zu, dtype: %u, zero_point: %d, datatype: %d\n",
433             buffer_ptr,
434             scale_numel,
435             qparams->channel_dim(),
436             group_size,
437             output_channels,
438             datatype,
439             zero_point,
440             datatype);
441 
442         status = xnn_define_blockwise_quantized_tensor_value(
443             /*subgraph=*/subgraph_ptr,
444             /*datatype=*/datatype,
445             /*zero_point=*/zero_point,
446             /*scale=*/scale_data,
447             /*num_dims=*/tensor_value->num_dims(),
448             /*channel_dim=*/qparams->channel_dim(),
449             /*block_size=*/qparams->group_size(),
450             /*dims=*/dims_data.data(),
451             /*data=*/buffer_ptr,
452             /*external_id=*/tensor_value->external_id(),
453             /*flags=*/tensor_value->flags(),
454             /*id_out=*/&id);
455         break;
456       }
457       case fb_xnnpack::XNNQuantParams::PerTokenDynamicQuant: {
458         auto qparams = qtensor_value->quant_params_as_PerTokenDynamicQuant();
459         ET_LOG(
460             Debug,
461             "define quant tensor (dynamic): num_dims: %i, num_nonbatch_dims: %i\n",
462             tensor_value->num_dims(),
463             qparams->num_nonbatch_dims());
464         ET_CHECK_OR_RETURN_ERROR(
465             buffer_ptr == nullptr,
466             Internal,
467             "Dynamically quantized tensor should not have constant data but found non-nullptr");
468         // TODO(T179441835): Dynamic Quantization with num_nonbatch_dims > 1
469         ET_CHECK_OR_RETURN_ERROR(
470             qparams->num_nonbatch_dims() == 1,
471             Internal,
472             "Dynamically Quantized Tensors currently only support per token quantization");
473         status = xnn_define_dynamically_quantized_tensor_value(
474             /*subgraph=*/subgraph_ptr,
475             /*datatype=*/getDataType(tensor_value->datatype()),
476             /*num_dims=*/tensor_value->num_dims(),
477             /*num_nonbatch_dims*/ qparams->num_nonbatch_dims(),
478             /*dims=*/dims_data.data(),
479             /*external_id=*/tensor_value->external_id(),
480             /*flags=*/tensor_value->flags(),
481             /*id_out=*/&id);
482         break;
483       }
484       default: {
485         ET_CHECK_OR_RETURN_ERROR(
486             false,
487             NotImplemented,
488             "Unhandled Quantization Parameters: %s",
489             fb_xnnpack::EnumNameXNNQuantParams(
490                 qtensor_value->quant_params_type()));
491       }
492     }
493   }
494 
495   ET_CHECK_OR_RETURN_ERROR(
496       status == xnn_status_success,
497       Internal,
498       "Failed to define tensor %i with code: %s",
499       tensor_value->id_out(),
500       xnn_status_to_string(status));
501 
502   // map serialized id to newly generated id
503   remapped_ids.emplace(std::make_pair(tensor_value->id_out(), id));
504 
505   // Add external ids to either list of input or output ids
506   if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_INPUT) {
507     input_ids.push_back(tensor_value->external_id());
508   }
509   if (tensor_value->flags() & XNN_VALUE_FLAG_EXTERNAL_OUTPUT) {
510     output_ids.push_back(tensor_value->external_id());
511   }
512 
513   return Error::Ok;
514 };
515 
516 #define MAYBE_UNUSED(x) (void)(x)
517 
518 /*
519 Define serialized add node into the subgraph, using the remapped ids
520 to map the serialized ids, to the new ids generated when defining
521 the tensor value
522 */
523 Error defineAddNode(
524     xnn_subgraph_t subgraph_ptr,
525     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
526     const NodePtr node,
527     const fb_xnnpack::XNNGraph* graph) noexcept {
528   MAYBE_UNUSED(graph);
529 
530   std::pair<float, float> min_max = getOutputMinMax(node);
531   auto graph_node = node->xnode_union_as_XNNAdd();
532   xnn_status status = xnn_define_add2(
533       subgraph_ptr,
534       min_max.first,
535       min_max.second,
536       remapped_ids.at(graph_node->input1_id()),
537       remapped_ids.at(graph_node->input2_id()),
538       remapped_ids.at(graph_node->output_id()),
539       graph_node->flags());
540   ET_CHECK_OR_RETURN_ERROR(
541       status == xnn_status_success,
542       Internal,
543       "Failed to create add node %i with code: %s",
544       node->debug_handle(),
545       xnn_status_to_string(status));
546 
547   return Error::Ok;
548 };
549 
550 /*
551 Define Minimum operator Node into the subgraph
552 */
553 Error defineMinimumNode(
554     xnn_subgraph_t subgraph_ptr,
555     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
556     const NodePtr node,
557     const fb_xnnpack::XNNGraph* graph) noexcept {
558   MAYBE_UNUSED(graph);
559 
560   auto graph_node = node->xnode_union_as_XNNMinimum();
561   xnn_status status = xnn_define_minimum2(
562       subgraph_ptr,
563       remapped_ids.at(graph_node->input1_id()),
564       remapped_ids.at(graph_node->input2_id()),
565       remapped_ids.at(graph_node->output_id()),
566       graph_node->flags());
567 
568   ET_CHECK_OR_RETURN_ERROR(
569       status == xnn_status_success,
570       Internal,
571       "Failed to create minumum node %i with code: %s",
572       node->debug_handle(),
573       xnn_status_to_string(status));
574 
575   return Error::Ok;
576 };
577 
578 /*
579 Define subtract operator Node into the subgraph
580 */
581 Error defineSubtractNode(
582     xnn_subgraph_t subgraph_ptr,
583     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
584     const NodePtr node,
585     const fb_xnnpack::XNNGraph* graph) noexcept {
586   MAYBE_UNUSED(graph);
587 
588   auto graph_node = node->xnode_union_as_XNNSubtract();
589   std::pair<float, float> min_max = getOutputMinMax(node);
590   xnn_status status = xnn_define_subtract(
591       subgraph_ptr,
592       min_max.first,
593       min_max.second,
594       remapped_ids.at(graph_node->input1_id()),
595       remapped_ids.at(graph_node->input2_id()),
596       remapped_ids.at(graph_node->output_id()),
597       graph_node->flags());
598 
599   ET_CHECK_OR_RETURN_ERROR(
600       status == xnn_status_success,
601       Internal,
602       "Failed to create subtract node %i with code: %s",
603       node->debug_handle(),
604       xnn_status_to_string(status));
605 
606   return Error::Ok;
607 };
608 
609 /*
610 Define Multiply operator Node into the subgraph
611 */
612 Error defineMultiplyNode(
613     xnn_subgraph_t subgraph_ptr,
614     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
615     const NodePtr node,
616     const fb_xnnpack::XNNGraph* graph) noexcept {
617   MAYBE_UNUSED(graph);
618 
619   auto graph_node = node->xnode_union_as_XNNMultiply();
620   std::pair<float, float> min_max = getOutputMinMax(node);
621   xnn_status status = xnn_define_multiply2(
622       subgraph_ptr,
623       min_max.first,
624       min_max.second,
625       remapped_ids.at(graph_node->input1_id()),
626       remapped_ids.at(graph_node->input2_id()),
627       remapped_ids.at(graph_node->output_id()),
628       graph_node->flags());
629 
630   ET_CHECK_OR_RETURN_ERROR(
631       status == xnn_status_success,
632       Internal,
633       "Failed to create multiply node %i with code: %s",
634       node->debug_handle(),
635       xnn_status_to_string(status));
636 
637   return Error::Ok;
638 };
639 
640 #ifdef ENABLE_XNNPACK_KLEIDI
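/*
Heuristic for the KleidiAI QP8 path: returns true when this convert node
produces a qdint8 tensor that feeds at least one fully connected node whose
filter is qbint4, i.e. the dynamically quantized int4 linear pattern.
*/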
641 bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
642   assert(node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNConvert);
643   auto graph_node = node->xnode_union_as_XNNConvert();
644   auto cvt_output_id = graph_node->output_id();
645 
646   auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
647     assert(
648         dtype == DataType::xnn_datatype_qdint8 ||
649         dtype == DataType::xnn_datatype_qbint4);
650     for (auto value : *graph->xvalues()) {
651       if (value->xvalue_union_type() !=
652           fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
653         continue;
654       }
655       auto tensor =
656           value->xvalue_union_as_XNNQuantizedTensorValue()->tensor_value();
657       if (tensor->id_out() == id) {
658         return tensor->datatype() == dtype;
659       }
660     }
661     return false;
662   };
663 
664   // Check if the convert output tensor is qdint8, else bail early.
665   if (!check_dtype(cvt_output_id, DataType::xnn_datatype_qdint8)) {
666     return false;
667   }
668 
669   // Find if the convert output is going to the right linear node.
670   // Assuming if we can find one valid linear node, then we can use QP8
671   // for all the linear nodes consuming this convert output.
672   for (auto node : *graph->xnodes()) {
673     if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
674       auto linear_node = node->xnode_union_as_XNNFullyConnected();
675       if (linear_node->input1_id() == cvt_output_id) {
676         if (check_dtype(
677                 linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
678           return true;
679         }
680       }
681     }
682   }
683   return false;
684 }
685 #endif // ENABLE_XNNPACK_KLEIDI
686 
687 /*
688 Define Convert operator Node into the subgraph
689 */
690 Error defineConvertNode(
691     xnn_subgraph_t subgraph_ptr,
692     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
693     const NodePtr node,
694     const fb_xnnpack::XNNGraph* flatbuffer_graph) noexcept {
695   MAYBE_UNUSED(flatbuffer_graph);
696   auto graph_node = node->xnode_union_as_XNNConvert();
697 
698   int32_t flags = graph_node->flags();
699 #ifdef ENABLE_XNNPACK_KLEIDI
700 // This flag is not currently exposed in include/xnnpack.h; once it is, we
701 // can remove this runtime logic and set the flag ahead-of-time.
702 #define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100
703   if (isQP8(flatbuffer_graph, node)) {
704     flags |= XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
705     ET_LOG(
706         Debug,
707         "Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i",
708         node->debug_handle());
709   }
710 #endif
711 
712   xnn_status status = xnn_define_convert(
713       subgraph_ptr,
714       remapped_ids.at(graph_node->input_id()),
715       remapped_ids.at(graph_node->output_id()),
716       flags);
717 
718   ET_CHECK_OR_RETURN_ERROR(
719       status == xnn_status_success,
720       Internal,
721       "Failed to create convert node %i with code: %s",
722       node->debug_handle(),
723       xnn_status_to_string(status));
724 
725   return Error::Ok;
726 };
727 /*
728 Define serialized linear(fully-connected) node into the subgraph using
729 the remapped ids to map the serialized ids, to the new ids generated
730 when defining the tensor values
731 */
732 Error defineFullyConnectedNode(
733     xnn_subgraph_t subgraph_ptr,
734     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
735     const NodePtr node,
736     const fb_xnnpack::XNNGraph* graph) noexcept {
737   MAYBE_UNUSED(graph);
738 
739   auto graph_node = node->xnode_union_as_XNNFullyConnected();
740   std::pair<float, float> min_max = getOutputMinMax(node);
741   xnn_status status = xnn_define_fully_connected(
742       subgraph_ptr,
743       min_max.first,
744       min_max.second,
745       remapped_ids.at(graph_node->input1_id()),
746       remapped_ids.at(graph_node->filter_id()),
747       remapped_ids.at(graph_node->bias_id()),
748       remapped_ids.at(graph_node->output_id()),
749       graph_node->flags());
750   ET_CHECK_OR_RETURN_ERROR(
751       status == xnn_status_success,
752       Internal,
753       "Failed to create linear node %i, with code: %s",
754       node->debug_handle(),
755       xnn_status_to_string(status));
756 
757   return Error::Ok;
758 };
759 
760 /*
761 Define serialized clamp node into the subgraph, using the remapped ids
762 to map the serialized ids, to the new ids generated when defining
763 the tensor value
764 */
765 Error defineClampNode(
766     xnn_subgraph_t subgraph_ptr,
767     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
768     const NodePtr node,
769     const fb_xnnpack::XNNGraph* graph) noexcept {
770   MAYBE_UNUSED(graph);
771 
772   std::pair<float, float> min_max = getOutputMinMax(node);
773   auto graph_node = node->xnode_union_as_XNNClamp();
774   xnn_status status = xnn_define_clamp(
775       subgraph_ptr,
776       min_max.first,
777       min_max.second,
778       remapped_ids.at(graph_node->input_id()),
779       remapped_ids.at(graph_node->output_id()),
780       graph_node->flags());
781 
782   ET_CHECK_OR_RETURN_ERROR(
783       status == xnn_status_success,
784       Internal,
785       "Failed to create hardtanh node %i with code: %s",
786       node->debug_handle(),
787       xnn_status_to_string(status));
788 
789   return Error::Ok;
790 }
791 
792 /*
793 Define serialized softmax node into the subgraph, using the remapped ids
794 to map the serialized ids, to the new ids generated when defining
795 the tensor value
796 */
797 Error defineSoftmaxNode(
798     xnn_subgraph_t subgraph_ptr,
799     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
800     const NodePtr node,
801     const fb_xnnpack::XNNGraph* graph) noexcept {
802   MAYBE_UNUSED(graph);
803 
804   auto graph_node = node->xnode_union_as_XNNSoftmax();
805   xnn_status status = xnn_define_softmax(
806       subgraph_ptr,
807       remapped_ids.at(graph_node->input_id()),
808       remapped_ids.at(graph_node->output_id()),
809       graph_node->flags());
810   ET_CHECK_OR_RETURN_ERROR(
811       status == xnn_status_success,
812       Internal,
813       "Failed to create softmax node %i with code: %s",
814       node->debug_handle(),
815       xnn_status_to_string(status));
816 
817   return Error::Ok;
818 }
819 
820 /*
821 Define serialized sigmoid node into the subgraph, using the remapped ids
822 to map the serialized ids, to the new ids generated when defining
823 the tensor value
824 */
825 Error defineSigmoidNode(
826     xnn_subgraph_t subgraph_ptr,
827     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
828     const NodePtr node,
829     const fb_xnnpack::XNNGraph* graph) noexcept {
830   MAYBE_UNUSED(graph);
831 
832   auto graph_node = node->xnode_union_as_XNNSigmoid();
833   xnn_status status = xnn_define_sigmoid(
834       subgraph_ptr,
835       remapped_ids.at(graph_node->input_id()),
836       remapped_ids.at(graph_node->output_id()),
837       graph_node->flags());
838   ET_CHECK_OR_RETURN_ERROR(
839       status == xnn_status_success,
840       Internal,
841       "Failed to create sigmoid node %i with code: %s",
842       node->debug_handle(),
843       xnn_status_to_string(status));
844 
845   return Error::Ok;
846 }
847 
848 /*
849 Define serialized floor node into the subgraph, using the remapped ids
850 to map the serialized ids, to the new ids generated when defining
851 the tensor value
852 */
853 Error defineFloorNode(
854     xnn_subgraph_t subgraph_ptr,
855     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
856     const NodePtr node,
857     const fb_xnnpack::XNNGraph* graph) noexcept {
858   MAYBE_UNUSED(graph);
859 
860   auto graph_node = node->xnode_union_as_XNNFloor();
861   xnn_status status = xnn_define_floor(
862       subgraph_ptr,
863       remapped_ids.at(graph_node->input_id()),
864       remapped_ids.at(graph_node->output_id()),
865       graph_node->flags());
866   ET_CHECK_OR_RETURN_ERROR(
867       status == xnn_status_success,
868       Internal,
869       "Failed to create floor node %i with code: %s",
870       node->debug_handle(),
871       xnn_status_to_string(status));
872 
873   return Error::Ok;
874 }
875 
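/*
Define serialized global average pooling 2d node into the subgraph, using the
remapped ids to map the serialized ids to the new ids generated when defining
the tensor value
*/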
876 Error defineGlobalAvgPooling2dNode(
877     xnn_subgraph_t subgraph_ptr,
878     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
879     const NodePtr node,
880     const fb_xnnpack::XNNGraph* graph) noexcept {
881   MAYBE_UNUSED(graph);
882 
883   auto graph_node = node->xnode_union_as_XNNGlobalAvgPooling2d();
884   std::pair<float, float> min_max = getOutputMinMax(node);
885   xnn_status status = xnn_define_global_average_pooling_2d(
886       subgraph_ptr,
887       min_max.first,
888       min_max.second,
889       remapped_ids.at(graph_node->input_id()),
890       remapped_ids.at(graph_node->output_id()),
891       graph_node->flags());
892   ET_CHECK_OR_RETURN_ERROR(
893       status == xnn_status_success,
894       Internal,
895       "Failed to create global average pooling node %i with code: %s",
896       node->debug_handle(),
897       xnn_status_to_string(status));
898 
899   return Error::Ok;
900 }
901 
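/*
Define serialized average pooling 2d node into the subgraph, using the
remapped ids to map the serialized ids to the new ids generated when defining
the tensor value
*/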
902 Error defineAvgPooling2dNode(
903     xnn_subgraph_t subgraph_ptr,
904     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
905     const NodePtr node,
906     const fb_xnnpack::XNNGraph* graph) noexcept {
907   MAYBE_UNUSED(graph);
908 
909   auto graph_node = node->xnode_union_as_XNNAvgPooling2d();
910   std::pair<float, float> min_max = getOutputMinMax(node);
911   xnn_status status = xnn_define_average_pooling_2d(
912       subgraph_ptr,
913       graph_node->padding_top(),
914       graph_node->padding_right(),
915       graph_node->padding_bottom(),
916       graph_node->padding_left(),
917       graph_node->pooling_height(),
918       graph_node->pooling_width(),
919       graph_node->stride_height(),
920       graph_node->stride_width(),
921       min_max.first,
922       min_max.second,
923       remapped_ids.at(graph_node->input_id()),
924       remapped_ids.at(graph_node->output_id()),
925       graph_node->flags());
926   ET_CHECK_OR_RETURN_ERROR(
927       status == xnn_status_success,
928       Internal,
929       "Failed to create average pooling node %i with code: %s",
930       node->debug_handle(),
931       xnn_status_to_string(status));
932 
933   return Error::Ok;
934 }
935 
936 /*
937 Define serialized conv2d node into the subgraph, using the remapped ids
938 to map the serialized ids, to the new ids generated when defining the
939 tensor value
940 */
941 Error defineConv2dNode(
942     xnn_subgraph_t subgraph_ptr,
943     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
944     const NodePtr node,
945     const fb_xnnpack::XNNGraph* graph) noexcept {
946   MAYBE_UNUSED(graph);
947 
948   auto graph_node = node->xnode_union_as_XNNConv2d();
949   std::pair<float, float> min_max = getOutputMinMax(node);
950   xnn_status status = xnn_define_convolution_2d(
951       subgraph_ptr,
952       graph_node->padding_top(),
953       graph_node->padding_right(),
954       graph_node->padding_bottom(),
955       graph_node->padding_left(),
956       graph_node->kernel_height(),
957       graph_node->kernel_width(),
958       graph_node->subsampling_height(),
959       graph_node->subsampling_width(),
960       graph_node->dilation_height(),
961       graph_node->dilation_width(),
962       graph_node->groups(),
963       graph_node->group_input_channels(),
964       graph_node->group_output_channels(),
965       min_max.first,
966       min_max.second,
967       remapped_ids.at(graph_node->input1_id()),
968       remapped_ids.at(graph_node->filter_id()),
969       remapped_ids.at(graph_node->bias_id()),
970       remapped_ids.at(graph_node->output_id()),
971       graph_node->flags());
972   ET_CHECK_OR_RETURN_ERROR(
973       status == xnn_status_success,
974       Internal,
975       "Failed to create convolution node %i with code: %s",
976       node->debug_handle(),
977       xnn_status_to_string(status));
978 
979   return Error::Ok;
980 }
981 
982 /*
983 Define serialized maxpool2d node into the subgraph, using the remapped ids
984 to map the serialized ids, to the new ids generated when defining the
985 tensor value
986 */
987 Error defineMaxPooling2dNode(
988     xnn_subgraph_t subgraph_ptr,
989     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
990     const NodePtr node,
991     const fb_xnnpack::XNNGraph* graph) noexcept {
992   MAYBE_UNUSED(graph);
993 
994   auto graph_node = node->xnode_union_as_XNNMaxPooling2d();
995   std::pair<float, float> min_max = getOutputMinMax(node);
996   xnn_status status = xnn_define_max_pooling_2d(
997       subgraph_ptr,
998       graph_node->padding_top(),
999       graph_node->padding_right(),
1000       graph_node->padding_bottom(),
1001       graph_node->padding_left(),
1002       graph_node->pooling_height(),
1003       graph_node->pooling_width(),
1004       graph_node->stride_height(),
1005       graph_node->stride_width(),
1006       graph_node->dilation_height(),
1007       graph_node->dilation_width(),
1008       min_max.first,
1009       min_max.second,
1010       remapped_ids.at(graph_node->input_id()),
1011       remapped_ids.at(graph_node->output_id()),
1012       graph_node->flags());
1013 
1014   ET_CHECK_OR_RETURN_ERROR(
1015       status == xnn_status_success,
1016       Internal,
1017       "Failed to create maxpool2d node %i with code: %s",
1018       node->debug_handle(),
1019       xnn_status_to_string(status));
1020 
1021   return Error::Ok;
1022 }
1023 
1024 /*
1025 Define serialized div node into the subgraph
1026 */
1027 Error defineDivNode(
1028     xnn_subgraph_t subgraph_ptr,
1029     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1030     const NodePtr node,
1031     const fb_xnnpack::XNNGraph* graph) noexcept {
1032   MAYBE_UNUSED(graph);
1033 
1034   auto graph_node = node->xnode_union_as_XNNDiv();
1035   std::pair<float, float> min_max = getOutputMinMax(node);
1036   xnn_status status = xnn_define_divide(
1037       subgraph_ptr,
1038       min_max.first,
1039       min_max.second,
1040       remapped_ids.at(graph_node->input1_id()),
1041       remapped_ids.at(graph_node->input2_id()),
1042       remapped_ids.at(graph_node->output_id()),
1043       graph_node->flags());
1044   ET_CHECK_OR_RETURN_ERROR(
1045       status == xnn_status_success,
1046       Internal,
1047       "Failed to create div node %i with code: %s",
1048       node->debug_handle(),
1049       xnn_status_to_string(status));
1050 
1051   return Error::Ok;
1052 }
1053 
1054 /*
1055 Define serialized static transpose node into the subgraph, using the remapped
1056 ids to map the serialized ids, to the new ids generated when defining the
1057 tensor value
1058 */
1059 Error defineStaticTransposeNode(
1060     xnn_subgraph_t subgraph_ptr,
1061     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1062     const NodePtr node,
1063     const fb_xnnpack::XNNGraph* graph) noexcept {
1064   MAYBE_UNUSED(graph);
1065 
1066   auto graph_node = node->xnode_union_as_XNNStaticTranspose();
1067 
1068   // Get tensor dims, we need to convert the uint32_t* to size_t*
1069   std::vector<size_t> dims_data = flatbufferDimsToVector(graph_node->perm());
1070   xnn_status status = xnn_define_static_transpose(
1071       subgraph_ptr,
1072       graph_node->num_dims(),
1073       dims_data.data(),
1074       remapped_ids.at(graph_node->input_id()),
1075       remapped_ids.at(graph_node->output_id()),
1076       graph_node->flags());
1077   ET_CHECK_OR_RETURN_ERROR(
1078       status == xnn_status_success,
1079       Internal,
1080       "Failed to create sigmoid node %i with code: %s",
1081       node->debug_handle(),
1082       xnn_status_to_string(status));
1083 
1084   return Error::Ok;
1085 }
1086 
1087 /*
1088 Define serialized static resize bilinear 2d node into the subgraph, using the
1089 remapped ids to map the serialized ids, to the new ids generated when defining
1090 the tensor value
1091 */
1092 Error defineStaticResizeBilinear2DNode(
1093     xnn_subgraph_t subgraph_ptr,
1094     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1095     const NodePtr node,
1096     const fb_xnnpack::XNNGraph* graph) noexcept {
1097   MAYBE_UNUSED(graph);
1098 
1099   const fb_xnnpack::XNNStaticResizeBilinear2D* graph_node =
1100       node->xnode_union_as_XNNStaticResizeBilinear2D();
1101 
1102   xnn_status status = xnn_define_static_resize_bilinear_2d(
1103       subgraph_ptr,
1104       graph_node->new_height(),
1105       graph_node->new_width(),
1106       remapped_ids.at(graph_node->input_id()),
1107       remapped_ids.at(graph_node->output_id()),
1108       graph_node->flags());
1109   ET_CHECK_OR_RETURN_ERROR(
1110       status == xnn_status_success,
1111       Internal,
1112       "Failed to create StaticResizeBilinear2DNode node %i with code: %s",
1113       node->debug_handle(),
1114       xnn_status_to_string(status));
1115 
1116   return Error::Ok;
1117 }
1118 
1119 /*
1120 Define serialized static constant pad node into the subgraph, using the
1121 remapped ids to map the serialized ids, to the new ids generated when defining
1122 the tensor value
1123 */
1124 Error defineStaticConstantPadNode(
1125     xnn_subgraph_t subgraph_ptr,
1126     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1127     const NodePtr node,
1128     const fb_xnnpack::XNNGraph* graph) noexcept {
1129   MAYBE_UNUSED(graph);
1130 
1131   const fb_xnnpack::XNNStaticConstantPad* graph_node =
1132       node->xnode_union_as_XNNStaticConstantPad();
1133 
1134   std::vector<size_t> pre_paddings_dims =
1135       flatbufferDimsToVector(graph_node->pre_paddings());
1136   std::vector<size_t> post_paddings_dims =
1137       flatbufferDimsToVector(graph_node->post_paddings());
1138 
1139   xnn_status status = xnn_define_static_constant_pad(
1140       subgraph_ptr,
1141       pre_paddings_dims.data(),
1142       post_paddings_dims.data(),
1143       graph_node->padding_value(),
1144       remapped_ids.at(graph_node->input_id()),
1145       remapped_ids.at(graph_node->output_id()),
1146       graph_node->flags());
1147   ET_CHECK_OR_RETURN_ERROR(
1148       status == xnn_status_success,
1149       Internal,
1150       "Failed to create StaticConstantPad node %i with code: %s",
1151       node->debug_handle(),
1152       xnn_status_to_string(status));
1153 
1154   return Error::Ok;
1155 }
1156 
1157 /*
1158 Define serialized depthwise conv2d node into the subgraph, using the remapped
1159 ids to map the serialized ids, to the new ids generated when defining the
1160 tensor value
1161 */
1162 Error defineDepthwiseConv2dNode(
1163     xnn_subgraph_t subgraph_ptr,
1164     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1165     const NodePtr node,
1166     const fb_xnnpack::XNNGraph* graph) noexcept {
1167   MAYBE_UNUSED(graph);
1168 
1169   auto graph_node = node->xnode_union_as_XNNDepthwiseConv2d();
1170   std::pair<float, float> min_max = getOutputMinMax(node);
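  // Illustrative parameter mapping (assumed example values): for a depthwise
  // convolution with groups = 32, group_input_channels = 1 and
  // group_output_channels = 2, the depth multiplier below is 2 / 1 = 2 and
  // the input channel count passed to XNNPACK is 32.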
1171   xnn_status status = xnn_define_depthwise_convolution_2d(
1172       subgraph_ptr,
1173       graph_node->padding_top(),
1174       graph_node->padding_right(),
1175       graph_node->padding_bottom(),
1176       graph_node->padding_left(),
1177       graph_node->kernel_height(),
1178       graph_node->kernel_width(),
1179       graph_node->subsampling_height(),
1180       graph_node->subsampling_width(),
1181       graph_node->dilation_height(),
1182       graph_node->dilation_width(),
1183       graph_node->group_output_channels() /
1184           graph_node->group_input_channels(), // depth_multiplier
1185       graph_node->groups(), // input_channels = groups for depthwise conv
1186       min_max.first,
1187       min_max.second,
1188       remapped_ids.at(graph_node->input1_id()),
1189       remapped_ids.at(graph_node->filter_id()),
1190       remapped_ids.at(graph_node->bias_id()),
1191       remapped_ids.at(graph_node->output_id()),
1192       graph_node->flags());
1193 
1194   ET_CHECK_OR_RETURN_ERROR(
1195       status == xnn_status_success,
1196       Internal,
1197       "Failed to create depthwise convolution node %i with code: %s",
1198       node->debug_handle(),
1199       xnn_status_to_string(status));
1200 
1201   return Error::Ok;
1202 }
1203 
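/*
Define serialized static reshape node into the subgraph, using the remapped
ids to map the serialized ids to the new ids generated when defining the
tensor value
*/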
1204 Error defineStaticReshapeNode(
1205     xnn_subgraph_t subgraph_ptr,
1206     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1207     const NodePtr node,
1208     const fb_xnnpack::XNNGraph* graph) noexcept {
1209   MAYBE_UNUSED(graph);
1210 
1211   auto graph_node = node->xnode_union_as_XNNStaticReshape();
1212 
1213   // Get tensor dims, we need to convert the uint32_t* to size_t*
1214   std::vector<size_t> dims_data =
1215       flatbufferDimsToVector(graph_node->new_shape());
1216   xnn_status status = xnn_define_static_reshape(
1217       subgraph_ptr,
1218       graph_node->num_dims(),
1219       dims_data.data(),
1220       remapped_ids.at(graph_node->input_id()),
1221       remapped_ids.at(graph_node->output_id()),
1222       graph_node->flags());
1223   ET_CHECK_OR_RETURN_ERROR(
1224       status == xnn_status_success,
1225       Internal,
1226       "Failed to create squeeze node %i with code: %s",
1227       node->debug_handle(),
1228       xnn_status_to_string(status));
1229 
1230   return Error::Ok;
1231 }
1232 
1233 /*
1234 Define serialized argmax pooling 2d node into the subgraph, using the remapped ids
1235 to map the serialized ids, to the new ids generated when defining the
1236 tensor value
1237 */
1238 Error defineArgMaxPooling2dNode(
1239     xnn_subgraph_t subgraph_ptr,
1240     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1241     const NodePtr node,
1242     const fb_xnnpack::XNNGraph* graph) noexcept {
1243   MAYBE_UNUSED(graph);
1244 
1245   auto graph_node = node->xnode_union_as_XNNArgMaxPooling2d();
1246 
1247   xnn_status status = xnn_define_argmax_pooling_2d(
1248       subgraph_ptr,
1249       graph_node->padding_top(),
1250       graph_node->padding_right(),
1251       graph_node->padding_bottom(),
1252       graph_node->padding_left(),
1253       graph_node->pooling_height(),
1254       graph_node->pooling_width(),
1255       remapped_ids.at(graph_node->input_id()),
1256       remapped_ids.at(graph_node->output_value_id()),
1257       remapped_ids.at(graph_node->output_index_id()),
1258       graph_node->flags());
1259 
1260   ET_CHECK_OR_RETURN_ERROR(
1261       status == xnn_status_success,
1262       Internal,
1263       "Failed to create argmaxpool2d node %i with code: %s",
1264       node->debug_handle(),
1265       xnn_status_to_string(status));
1266 
1267   return Error::Ok;
1268 }
1269 
1270 /*
1271 Define serialized square root node into the subgraph, using the remapped ids
1272 to map the serialized ids, to the new ids generated when defining the
1273 tensor value
1274 */
1275 Error defineSquareRootNode(
1276     xnn_subgraph_t subgraph_ptr,
1277     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1278     const NodePtr node,
1279     const fb_xnnpack::XNNGraph* graph) noexcept {
1280   MAYBE_UNUSED(graph);
1281 
1282   auto graph_node = node->xnode_union_as_XNNSquareRoot();
1283 
1284   xnn_status status = xnn_define_square_root(
1285       subgraph_ptr,
1286       remapped_ids.at(graph_node->input_id()),
1287       remapped_ids.at(graph_node->output_id()),
1288       graph_node->flags());
1289 
1290   ET_CHECK_OR_RETURN_ERROR(
1291       status == xnn_status_success,
1292       Internal,
1293       "Failed to create square root node %i with code: %s",
1294       node->debug_handle(),
1295       xnn_status_to_string(status));
1296 
1297   return Error::Ok;
1298 }
1299 
1300 /*
1301 Define serialized ceiling node into the subgraph, using the remapped ids
1302 to map the serialized ids, to the new ids generated when defining the
1303 tensor value
1304 */
1305 Error defineCeilingNode(
1306     xnn_subgraph_t subgraph_ptr,
1307     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1308     const NodePtr node,
1309     const fb_xnnpack::XNNGraph* graph) noexcept {
1310   MAYBE_UNUSED(graph);
1311 
1312   auto graph_node = node->xnode_union_as_XNNCeiling();
1313 
1314   xnn_status status = xnn_define_ceiling(
1315       subgraph_ptr,
1316       remapped_ids.at(graph_node->input_id()),
1317       remapped_ids.at(graph_node->output_id()),
1318       graph_node->flags());
1319 
1320   ET_CHECK_OR_RETURN_ERROR(
1321       status == xnn_status_success,
1322       Internal,
1323       "Failed to create ceiling node %i with code: %s",
1324       node->debug_handle(),
1325       xnn_status_to_string(status));
1326 
1327   return Error::Ok;
1328 }
1329 
1330 /*
1331 Define serialized hardswish node into the subgraph, using the remapped ids
1332 to map the serialized ids, to the new ids generated when defining the
1333 tensor value
1334 */
1335 Error defineHardswishNode(
1336     xnn_subgraph_t subgraph_ptr,
1337     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1338     const NodePtr node,
1339     const fb_xnnpack::XNNGraph* graph) noexcept {
1340   MAYBE_UNUSED(graph);
1341 
1342   auto graph_node = node->xnode_union_as_XNNHardswish();
1343 
1344   xnn_status status = xnn_define_hardswish(
1345       subgraph_ptr,
1346       remapped_ids.at(graph_node->input_id()),
1347       remapped_ids.at(graph_node->output_id()),
1348       graph_node->flags());
1349 
1350   ET_CHECK_OR_RETURN_ERROR(
1351       status == xnn_status_success,
1352       Internal,
1353       "Failed to create hardswish node %i with code: %s",
1354       node->debug_handle(),
1355       xnn_status_to_string(status));
1356 
1357   return Error::Ok;
1358 }
1359 
1360 /*
1361 Define serialized leaky relu node into the subgraph, using the remapped ids
1362 to map the serialized ids, to the new ids generated when defining the
1363 tensor value
1364 */
1365 Error defineLeakyReLUNode(
1366     xnn_subgraph_t subgraph_ptr,
1367     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1368     const NodePtr node,
1369     const fb_xnnpack::XNNGraph* graph) noexcept {
1370   MAYBE_UNUSED(graph);
1371 
1372   auto graph_node = node->xnode_union_as_XNNLeakyReLU();
1373 
1374   xnn_status status = xnn_define_leaky_relu(
1375       subgraph_ptr,
1376       graph_node->negative_slope(),
1377       remapped_ids.at(graph_node->input_id()),
1378       remapped_ids.at(graph_node->output_id()),
1379       graph_node->flags());
1380 
1381   ET_CHECK_OR_RETURN_ERROR(
1382       status == xnn_status_success,
1383       Internal,
1384       "Failed to create leaky relu node %i with code: %s",
1385       node->debug_handle(),
1386       xnn_status_to_string(status));
1387 
1388   return Error::Ok;
1389 }
1390 
1391 /*
1392 Define serialized maximum node into the subgraph, using the remapped ids
1393 to map the serialized ids, to the new ids generated when defining the
1394 tensor value
1395 */
1396 Error defineMaximumNode(
1397     xnn_subgraph_t subgraph_ptr,
1398     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1399     const NodePtr node,
1400     const fb_xnnpack::XNNGraph* graph) noexcept {
1401   MAYBE_UNUSED(graph);
1402 
1403   auto graph_node = node->xnode_union_as_XNNMaximum();
1404 
1405   xnn_status status = xnn_define_maximum2(
1406       subgraph_ptr,
1407       remapped_ids.at(graph_node->input1_id()),
1408       remapped_ids.at(graph_node->input2_id()),
1409       remapped_ids.at(graph_node->output_id()),
1410       graph_node->flags());
1411 
1412   ET_CHECK_OR_RETURN_ERROR(
1413       status == xnn_status_success,
1414       Internal,
1415       "Failed to create maximum node %i with code: %s",
1416       node->debug_handle(),
1417       xnn_status_to_string(status));
1418 
1419   return Error::Ok;
1420 }
1421 
1422 /*
1423 Define Negate node into subgraph, using the remapped ids to map the
1424 serialized ids, to the new ids generated when defining the tensor value
1425 */
1426 Error defineNegateNode(
1427     xnn_subgraph_t subgraph_ptr,
1428     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1429     const NodePtr node,
1430     const fb_xnnpack::XNNGraph* graph) noexcept {
1431   MAYBE_UNUSED(graph);
1432 
1433   auto graph_node = node->xnode_union_as_XNNNegate();
1434 
1435   xnn_status status = xnn_define_negate(
1436       subgraph_ptr,
1437       remapped_ids.at(graph_node->input_id()),
1438       remapped_ids.at(graph_node->output_id()),
1439       graph_node->flags());
1440 
1441   ET_CHECK_OR_RETURN_ERROR(
1442       status == xnn_status_success,
1443       Internal,
1444       "Failed to create negate node %i with code: %s",
1445       node->debug_handle(),
1446       xnn_status_to_string(status));
1447 
1448   return Error::Ok;
1449 }
1450 
1451 /*
1452 Defines square node into subgraph using the remapped ids to map the
1453 serialized ids to the new ids generated when defining the tensor value
1454 */
1455 Error defineSquareNode(
1456     xnn_subgraph_t subgraph_ptr,
1457     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1458     const NodePtr node,
1459     const fb_xnnpack::XNNGraph* graph) noexcept {
1460   MAYBE_UNUSED(graph);
1461 
1462   auto graph_node = node->xnode_union_as_XNNSquare();
1463 
1464   xnn_status status = xnn_define_square(
1465       subgraph_ptr,
1466       remapped_ids.at(graph_node->input_id()),
1467       remapped_ids.at(graph_node->output_id()),
1468       graph_node->flags());
1469 
1470   ET_CHECK_OR_RETURN_ERROR(
1471       status == xnn_status_success,
1472       Internal,
1473       "Failed to create square node %i with code: %s",
1474       node->debug_handle(),
1475       xnn_status_to_string(status));
1476 
1477   return Error::Ok;
1478 }
1479 
1480 /*
1481 Defines serialized ELU node into subgraph using the remapped ids to map the
1482 serialized ids to the new ids generated when defining the tensor value
1483 */
1484 Error defineELUNode(
1485     xnn_subgraph_t subgraph_ptr,
1486     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1487     const NodePtr node,
1488     const fb_xnnpack::XNNGraph* graph) noexcept {
1489   MAYBE_UNUSED(graph);
1490 
1491   auto graph_node = node->xnode_union_as_XNNELU();
1492 
1493   xnn_status status = xnn_define_elu(
1494       subgraph_ptr,
1495       graph_node->alpha(),
1496       remapped_ids.at(graph_node->input_id()),
1497       remapped_ids.at(graph_node->output_id()),
1498       graph_node->flags());
1499 
1500   ET_CHECK_OR_RETURN_ERROR(
1501       status == xnn_status_success,
1502       Internal,
1503       "Failed to create ELU node %i with code: %s",
1504       node->debug_handle(),
1505       xnn_status_to_string(status));
1506 
1507   return Error::Ok;
1508 }
1509 
1510 /*
1511 Defines absolute value node into subgraph using the remapped ids to map the
1512 serialized ids to the new ids generated when defining the tensor value
1513 */
1514 Error defineAbsNode(
1515     xnn_subgraph_t subgraph_ptr,
1516     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1517     const NodePtr node,
1518     const fb_xnnpack::XNNGraph* graph) noexcept {
1519   MAYBE_UNUSED(graph);
1520 
1521   auto graph_node = node->xnode_union_as_XNNAbs();
1522 
1523   xnn_status status = xnn_define_abs(
1524       subgraph_ptr,
1525       remapped_ids.at(graph_node->input_id()),
1526       remapped_ids.at(graph_node->output_id()),
1527       graph_node->flags());
1528 
1529   ET_CHECK_OR_RETURN_ERROR(
1530       status == xnn_status_success,
1531       Internal,
1532       "Failed to create abs node %i with code: %s",
1533       node->debug_handle(),
1534       xnn_status_to_string(status));
1535 
1536   return Error::Ok;
1537 }
1538 
1539 /*
1540 Defines serialized prelu node into the subgraph,
1541 using the remapped ids to map the serialized ids,
1542 to the new ids generated when defining the tensor value
1543 */
1544 Error definePReLUNode(
1545     xnn_subgraph_t subgraph_ptr,
1546     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1547     const NodePtr node,
1548     const fb_xnnpack::XNNGraph* graph) noexcept {
1549   MAYBE_UNUSED(graph);
1550 
1551   auto graph_node = node->xnode_union_as_XNNPReLU();
1552 
1553   xnn_status status = xnn_define_prelu(
1554       subgraph_ptr,
1555       remapped_ids.at(graph_node->input1_id()),
1556       remapped_ids.at(graph_node->input2_id()),
1557       remapped_ids.at(graph_node->output_id()),
1558       graph_node->flags());
1559 
1560   ET_CHECK_OR_RETURN_ERROR(
1561       status == xnn_status_success,
1562       Internal,
1563       "Failed to create prelu node %i with code: %s",
1564       node->debug_handle(),
1565       xnn_status_to_string(status));
1566 
1567   return Error::Ok;
1568 }
1569 
1570 /*
1571 Defines serialized concatenate2 node into the subgraph,
1572 using the remapped ids to map the serialized ids,
1573 to the new ids generated when defining the tensor value
1574 */
1575 Error defineConcatenate2Node(
1576     xnn_subgraph_t subgraph_ptr,
1577     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1578     const NodePtr node,
1579     const fb_xnnpack::XNNGraph* graph) noexcept {
1580   MAYBE_UNUSED(graph);
1581 
1582   auto graph_node = node->xnode_union_as_XNNConcatenate2();
1583 
1584   xnn_status status = xnn_define_concatenate2(
1585       subgraph_ptr,
1586       graph_node->axis(),
1587       remapped_ids.at(graph_node->input1_id()),
1588       remapped_ids.at(graph_node->input2_id()),
1589       remapped_ids.at(graph_node->output_id()),
1590       graph_node->flags());
1591 
1592   ET_CHECK_OR_RETURN_ERROR(
1593       status == xnn_status_success,
1594       Internal,
1595       "Failed to create cat2 node %i with code: %s",
1596       node->debug_handle(),
1597       xnn_status_to_string(status));
1598 
1599   return Error::Ok;
1600 }
1601 
1602 /*
1603 Defines serialized concatenate3 node into the subgraph,
1604 using the remapped ids to map the serialized ids,
1605 to the new ids generated when defining the tensor value
1606 */
1607 Error defineConcatenate3Node(
1608     xnn_subgraph_t subgraph_ptr,
1609     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1610     const NodePtr node,
1611     const fb_xnnpack::XNNGraph* graph) noexcept {
1612   MAYBE_UNUSED(graph);
1613 
1614   auto graph_node = node->xnode_union_as_XNNConcatenate3();
1615 
1616   xnn_status status = xnn_define_concatenate3(
1617       subgraph_ptr,
1618       graph_node->axis(),
1619       remapped_ids.at(graph_node->input1_id()),
1620       remapped_ids.at(graph_node->input2_id()),
1621       remapped_ids.at(graph_node->input3_id()),
1622       remapped_ids.at(graph_node->output_id()),
1623       graph_node->flags());
1624 
1625   ET_CHECK_OR_RETURN_ERROR(
1626       status == xnn_status_success,
1627       Internal,
1628       "Failed to create cat3 node %i with code: %s",
1629       node->debug_handle(),
1630       xnn_status_to_string(status));
1631 
1632   return Error::Ok;
1633 }
1634 
1635 /*
1636 Defines serialized concatenate4 node into the subgraph,
1637 using the remapped ids to map the serialized ids,
1638 to the new ids generated when defining the tensor value
1639 */
1640 Error defineConcatenate4Node(
1641     xnn_subgraph_t subgraph_ptr,
1642     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1643     const NodePtr node,
1644     const fb_xnnpack::XNNGraph* graph) noexcept {
1645   MAYBE_UNUSED(graph);
1646 
1647   auto graph_node = node->xnode_union_as_XNNConcatenate4();
1648 
1649   xnn_status status = xnn_define_concatenate4(
1650       subgraph_ptr,
1651       graph_node->axis(),
1652       remapped_ids.at(graph_node->input1_id()),
1653       remapped_ids.at(graph_node->input2_id()),
1654       remapped_ids.at(graph_node->input3_id()),
1655       remapped_ids.at(graph_node->input4_id()),
1656       remapped_ids.at(graph_node->output_id()),
1657       graph_node->flags());
1658 
1659   ET_CHECK_OR_RETURN_ERROR(
1660       status == xnn_status_success,
1661       Internal,
1662       "Failed to create cat4 node %i with code: %s",
1663       node->debug_handle(),
1664       xnn_status_to_string(status));
1665 
1666   return Error::Ok;
1667 }
1668 
1669 /*
1670 Defines serialized static_slice node into the subgraph,
1671 using the remapped ids to map the serialized ids,
1672 to the new ids generated when defining the tensor value
1673 */
1674 Error defineStaticSliceNode(
1675     xnn_subgraph_t subgraph_ptr,
1676     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1677     const NodePtr node,
1678     const fb_xnnpack::XNNGraph* graph) noexcept {
1679   MAYBE_UNUSED(graph);
1680 
1681   auto graph_node = node->xnode_union_as_XNNStaticSlice();
1682 
1683   std::vector<size_t> offsets = flatbufferDimsToVector(graph_node->offsets());
1684   std::vector<size_t> sizes = flatbufferDimsToVector(graph_node->sizes());
1685 
1686   xnn_status status = xnn_define_static_slice(
1687       subgraph_ptr,
1688       graph_node->num_dims(),
1689       offsets.data(),
1690       sizes.data(),
1691       remapped_ids.at(graph_node->input_id()),
1692       remapped_ids.at(graph_node->output_id()),
1693       graph_node->flags());
1694 
1695   ET_CHECK_OR_RETURN_ERROR(
1696       status == xnn_status_success,
1697       Internal,
1698       "Failed to create static slice node %i with code: %s",
1699       node->debug_handle(),
1700       xnn_status_to_string(status));
1701 
1702   return Error::Ok;
1703 }
1704 
1705 /*
1706 Defines Scaled Dot Product Attention (SDPA) node into the subgraph,
1707 using the remapped ids to map the serialized ids,
1708 to the new ids generated when defining the tensor value
1709 */
1710 Error defineScaledDotProductAttentionNode(
1711     xnn_subgraph_t subgraph_ptr,
1712     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1713     const NodePtr node,
1714     const fb_xnnpack::XNNGraph* graph) noexcept {
1715   MAYBE_UNUSED(graph);
1716 
1717   auto graph_node = node->xnode_union_as_XNNScaledDotProductAttention();
1718 
1719   xnn_status status = xnn_define_scaled_dot_product_attention(
1720       subgraph_ptr,
1721       xnn_attention_logits_cap_type_none, // cap_type
1722       nullptr, // cap_value - not used
1723       remapped_ids.at(graph_node->query_id()),
1724       remapped_ids.at(graph_node->key_id()),
1725       remapped_ids.at(graph_node->value_id()),
1726       remapped_ids.at(graph_node->scale_id()),
1727       remapped_ids.at(graph_node->mask_id()),
1728       remapped_ids.at(graph_node->output_id()),
1729       graph_node->flags());
1730 
1731   ET_CHECK_OR_RETURN_ERROR(
1732       status == xnn_status_success,
1733       Internal,
1734       "Failed to create SDPA node %i with code: %s",
1735       node->debug_handle(),
1736       xnn_status_to_string(status));
1737 
1738   return Error::Ok;
1739 }
1740 
1741 /*
1742 Defines batch matrix multiply node into the subgraph,
1743 using the remapped ids to map the serialized ids,
1744 to the new ids generated when defining the tensor value
1745 */
1746 Error defineBatchMatrixMultiplyNode(
1747     xnn_subgraph_t subgraph_ptr,
1748     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1749     const NodePtr node,
1750     const fb_xnnpack::XNNGraph* graph) noexcept {
1751   MAYBE_UNUSED(graph);
1752 
1753   auto graph_node = node->xnode_union_as_XNNBatchMatrixMultiply();
1754 
1755   xnn_status status = xnn_define_batch_matrix_multiply(
1756       subgraph_ptr,
1757       remapped_ids.at(graph_node->input1_id()),
1758       remapped_ids.at(graph_node->input2_id()),
1759       remapped_ids.at(graph_node->output_id()),
1760       graph_node->flags());
1761 
1762   ET_CHECK_OR_RETURN_ERROR(
1763       status == xnn_status_success,
1764       Internal,
1765       "Failed to create BMM node %i with code: %s",
1766       node->debug_handle(),
1767       xnn_status_to_string(status));
1768 
1769   return Error::Ok;
1770 }
1771 
1772 /*
1773 Returns a NotImplemented error code. This function is meant to be
1774 called when the compiler encounters an XNodeType from the flatbuffer
1775 that has not yet been implemented
1776 */
1777 Error defineNotImplementedNode(
1778     xnn_subgraph_t subgraph_ptr,
1779     const std::unordered_map<uint32_t, uint32_t>& remapped_ids,
1780     const NodePtr node,
1781     const fb_xnnpack::XNNGraph* graph) noexcept {
1782   MAYBE_UNUSED(graph);
1783 
1784   ET_CHECK_OR_RETURN_ERROR(
1785       false,
1786       NotImplemented,
1787       "Unhandled node type: %s",
1788       fb_xnnpack::EnumNameXNodeUnion(node->xnode_union_type()));
1789 }
1790 
1791 /*
1792 Returns the pointer to the defineNode function that handles the given
1793 XNode type
1794 */
1795 #define _DEFINE(name)                     \
1796   case fb_xnnpack::XNodeUnion::XNN##name: \
1797     return &define##name##Node;
1798 
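// For illustration only: _DEFINE(Ceiling) expands to
//   case fb_xnnpack::XNodeUnion::XNNCeiling:
//     return &defineCeilingNode;
// so each serialized XNode union tag resolves to its define*Node handler,
// and anything not listed falls through to defineNotImplementedNode.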
1799 DefineNodeFunc getDefineNodeFunc(fb_xnnpack::XNodeUnion nodeType) {
1800   switch (nodeType) {
1801     _DEFINE(Add)
1802     _DEFINE(FullyConnected)
1803     _DEFINE(Softmax)
1804     _DEFINE(Sigmoid)
1805     _DEFINE(StaticTranspose)
1806     _DEFINE(Clamp)
1807     _DEFINE(Conv2d)
1808     _DEFINE(Div)
1809     _DEFINE(StaticResizeBilinear2D)
1810     _DEFINE(StaticConstantPad)
1811     _DEFINE(AvgPooling2d)
1812     _DEFINE(Minimum)
1813     _DEFINE(DepthwiseConv2d)
1814     _DEFINE(MaxPooling2d)
1815     _DEFINE(Multiply)
1816     _DEFINE(Subtract)
1817     _DEFINE(Floor)
1818     _DEFINE(Convert)
1819     _DEFINE(GlobalAvgPooling2d)
1820     _DEFINE(StaticReshape)
1821     _DEFINE(ArgMaxPooling2d)
1822     _DEFINE(SquareRoot)
1823     _DEFINE(Ceiling)
1824     _DEFINE(Hardswish)
1825     _DEFINE(LeakyReLU)
1826     _DEFINE(Maximum)
1827     _DEFINE(Negate)
1828     _DEFINE(Square)
1829     _DEFINE(ELU)
1830     _DEFINE(Abs)
1831     _DEFINE(PReLU)
1832     _DEFINE(Concatenate2)
1833     _DEFINE(Concatenate3)
1834     _DEFINE(Concatenate4)
1835     _DEFINE(StaticSlice)
1836     _DEFINE(ScaledDotProductAttention)
1837     _DEFINE(BatchMatrixMultiply)
1838     case fb_xnnpack::XNodeUnion::NONE:
1839     default: // Adding here as a catch all, just in case
1840       return &defineNotImplementedNode;
1841   }
1842 }
1843 #undef _DEFINE
1844 
1845 /*
1846 Builds the xnnpack runtime object using the buffer pointer. The buffer pointer
1847 must be a valid pointer to the serialized xnnpack object. It also fills the
1848 XNNExecutor object with the built xnn_runtime and the input/output ids.
1849 */
1850 ET_NODISCARD Error XNNCompiler::compileModel(
1851     const void* buffer_pointer,
1852     size_t num_bytes,
1853     XNNExecutor* executor,
1854     MemoryAllocator* runtime_allocator,
1855     xnn_workspace_t workspace) {
1856   Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
1857   const uint8_t* flatbuffer_data = nullptr;
1858   const uint8_t* constant_data = nullptr;
1859   CompileAllocator compile_allocator;
1860 
1861   // Header status can only be Error::Ok or Error::NotFound
1862   if (header.ok()) {
1863     flatbuffer_data = reinterpret_cast<const uint8_t*>(buffer_pointer) +
1864         header->flatbuffer_offset;
1865     constant_data = reinterpret_cast<const uint8_t*>(buffer_pointer) +
1866         header->constant_data_offset;
1867   } else if (header.error() == Error::NotFound) {
1868     flatbuffer_data = reinterpret_cast<const uint8_t*>(buffer_pointer);
1869   } else {
1870     ET_LOG(Error, "XNNHeader may be corrupt");
1871     return header.error();
1872   }
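  // At this point flatbuffer_data points at the serialized XNNGraph: when an
  // XNNHeader is present, the flatbuffer and constant data are located via
  // byte offsets relative to buffer_pointer; without a header, the buffer is
  // assumed to start directly with the flatbuffer and carries no separate
  // constant data segment (constant_data stays nullptr).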
1873 
1874   // Temporarily support identifier XN00 and XN01
1875   bool is_supported_version =
1876       strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN00", 4) ==
1877           0 ||
1878       strncmp(flatbuffers::GetBufferIdentifier(flatbuffer_data), "XN01", 4) ==
1879           0;
1880   ET_CHECK_OR_RETURN_ERROR(
1881       is_supported_version,
1882       DelegateInvalidCompatibility,
1883       "XNNPACK Delegate Serialization Format version identifier '%.4s' != expected XN00 or XN01'",
1884       flatbuffers::GetBufferIdentifier(flatbuffer_data));
1885 
1886   auto flatbuffer_graph = fb_xnnpack::GetXNNGraph(flatbuffer_data);
1887   // initialize xnnpack
1888   xnn_status status = xnn_initialize(/*allocator =*/nullptr);
1889   ET_CHECK_OR_RETURN_ERROR(
1890       xnn_status_success == status,
1891       Internal,
1892       "XNN Initialize failed with code: %s",
1893       xnn_status_to_string(status));
1894 
1895   // create xnnpack subgraph
1896   xnn_subgraph_t subgraph_ptr = nullptr;
1897   status = xnn_create_subgraph(
1898       /*external_value_ids=*/flatbuffer_graph->num_externs(),
1899       /*flags=*/0,
1900       &subgraph_ptr);
1901   ET_CHECK_OR_RETURN_ERROR(
1902       xnn_status_success == status,
1903       Internal,
1904       "XNN Subgraph creation failed with code: %s",
1905       xnn_status_to_string(status));
1906 
1907   std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph(
1908       subgraph_ptr, &xnn_delete_subgraph);
1909 
1910   // Mapping from old ids to newly created value ids.
1911   // The old ids that were serialized were generated AoT. Since
1912   // we are re-defining tensor values, the defined IDs could be
1913   // different from the ones generated AoT; as a result, we need
1914   // a new mapping from the old ids to the newly created ones.
1915   std::unordered_map<uint32_t, uint32_t> remapped_ids;
1916   // Invalid ids do not need to be remapped
1917   remapped_ids.emplace(XNN_INVALID_VALUE_ID, XNN_INVALID_VALUE_ID);
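  // Illustrative example (ids are hypothetical): if a tensor was serialized
  // AoT with id 7 but is assigned id 0 when it is re-defined below, then
  // remapped_ids[7] == 0, and every define*Node call resolves its operands
  // through this map.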
1918 
1919   // External Ids for inputs and outputs
1920   std::vector<uint32_t> input_ids;
1921   std::vector<uint32_t> output_ids;
1922   Error err = Error::Ok;
1923   for (auto value : *flatbuffer_graph->xvalues()) {
1924     err = defineTensor(
1925         subgraph.get(),
1926         remapped_ids,
1927         value,
1928         flatbuffer_graph,
1929         constant_data,
1930         input_ids,
1931         output_ids,
1932         compile_allocator);
1933 
1934     if (err != Error::Ok) {
1935       return err;
1936     }
1937   }
1938 
1939   for (auto node : *flatbuffer_graph->xnodes()) {
1940     err = getDefineNodeFunc(node->xnode_union_type())(
1941         subgraph.get(), remapped_ids, node, flatbuffer_graph);
1942     if (err != Error::Ok) {
1943       return err;
1944     }
1945   }
1946   uint32_t runtime_flags = 0;
1947 
1948 #if defined(ENABLE_XNNPACK_PROFILING) || defined(ET_EVENT_TRACER_ENABLED)
1949   runtime_flags |= XNN_FLAG_BASIC_PROFILING;
1950 #endif
1951 
1952   xnn_runtime_t runtime_ptr = nullptr;
1953 
1954 #ifdef ENABLE_XNNPACK_SHARED_WORKSPACE
1955   ET_CHECK_OR_RETURN_ERROR(
1956       workspace != nullptr, Internal, "Failed to initialize XNNPACK workspace");
1957   status = xnn_create_runtime_v4(
1958       subgraph.get(),
1959       /*weight_cache=*/nullptr, // TODO - support weight cache
1960       workspace,
1961       ::executorch::extension::threadpool::get_pthreadpool(),
1962       runtime_flags,
1963       &runtime_ptr);
1964 #else
1965   status = xnn_create_runtime_v3(
1966       subgraph.get(),
1967       /*weight_cache=*/nullptr, // TODO - support weight cache
1968       ::executorch::extension::threadpool::get_pthreadpool(),
1969       runtime_flags,
1970       &runtime_ptr);
1971 #endif
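  // With ENABLE_XNNPACK_SHARED_WORKSPACE defined, the caller-provided
  // workspace is handed to xnn_create_runtime_v4 (presumably so it can be
  // shared across delegate instances); otherwise xnn_create_runtime_v3 lets
  // the runtime manage its own scratch memory. Both paths use the shared
  // threadpool and the runtime_flags computed above.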
1972 
1973   ET_CHECK_OR_RETURN_ERROR(
1974       xnn_status_success == status,
1975       Internal,
1976       "XNN Runtime creation failed with code: %s",
1977       xnn_status_to_string(status));
1978 
1979   err = executor->initialize( // NOLINT: runtime_ptr is non-null
1980       runtime_ptr,
1981       std::move(input_ids),
1982       std::move(output_ids));
1983 
1984   return err;
1985 }
1986 
1987 } // namespace delegate
1988 } // namespace xnnpack
1989 } // namespace backends
1990 } // namespace executorch
1991