/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "xnnpack.h"  // from @XNNPACK
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/xnnpack/quantization_util.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/tools/optimize/reduced_precision_support.h"

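// For orientation, a minimal usage sketch of this delegate (illustrative
// only, assuming the public C API declared in xnnpack_delegate.h and a
// tflite::Interpreter named `interpreter`):
//
//   TfLiteXNNPackDelegateOptions options =
//       TfLiteXNNPackDelegateOptionsDefault();
//   options.num_threads = 2;  // any value > 1 enables the thread pool below
//   TfLiteDelegate* delegate = TfLiteXNNPackDelegateCreate(&options);
//   interpreter->ModifyGraphWithDelegate(delegate);
//   // ... AllocateTensors() / Invoke() as usual ...
//   TfLiteXNNPackDelegateDelete(delegate);
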
struct TfLiteXNNPackDelegateWeightsCache;

namespace tflite {
namespace xnnpack {
namespace {

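// Copies a node's custom initial data into `target`. The copy size is clamped
// to sizeof(T), so a malformed custom data blob that is shorter or longer
// than T cannot overrun the destination.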
template <typename T>
void SafeCopyCustomData(const TfLiteNode& node, T* target) {
  const size_t safe_size =
      std::min(static_cast<size_t>(node.custom_initial_data_size), sizeof(T));
  std::memcpy(target, node.custom_initial_data, safe_size);
}

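// Maps a TFLite tensor type, together with its affine quantization parameters
// where applicable, onto the corresponding XNNPACK datatype. Returns
// xnn_datatype_invalid whenever the type or its quantization parameters are
// not representable in XNNPACK.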
xnn_datatype GetXNNPackDatatype(const TfLiteTensor& tensor) {
  switch (tensor.type) {
    case kTfLiteFloat32:
      return xnn_datatype_fp32;
    case kTfLiteFloat16:
      return xnn_datatype_fp16;
    case kTfLiteUInt8:
      if (tensor.quantization.type == kTfLiteAffineQuantization) {
        const auto quantization_params =
            static_cast<const TfLiteAffineQuantization*>(
                tensor.quantization.params);
        if (quantization_params->scale == nullptr ||
            quantization_params->zero_point == nullptr ||
            quantization_params->scale->size != 1 ||
            quantization_params->zero_point->size != 1) {
          return xnn_datatype_invalid;
        }

        const float scale = quantization_params->scale->data[0];
        if (!std::isnormal(scale) || scale <= 0.0f) {
          return xnn_datatype_invalid;
        }

        const int zero_point = quantization_params->zero_point->data[0];
        if (zero_point < std::numeric_limits<uint8_t>::min() ||
            zero_point > std::numeric_limits<uint8_t>::max()) {
          return xnn_datatype_invalid;
        }

        return xnn_datatype_quint8;
      }
      break;
    case kTfLiteInt8:
      if (tensor.quantization.type == kTfLiteAffineQuantization) {
        const auto quantization_params =
            static_cast<const TfLiteAffineQuantization*>(
                tensor.quantization.params);
        if (quantization_params->scale == nullptr ||
            quantization_params->zero_point == nullptr ||
            quantization_params->scale->size <= 0 ||
            quantization_params->zero_point->size != 1) {
          return xnn_datatype_invalid;
        }

        const int zero_point = quantization_params->zero_point->data[0];
        if (zero_point < std::numeric_limits<int8_t>::min() ||
            zero_point > std::numeric_limits<int8_t>::max()) {
          return xnn_datatype_invalid;
        }

        for (int i = 0; i < quantization_params->scale->size; i++) {
          const float scale = quantization_params->scale->data[i];
          if (!std::isnormal(scale) || scale <= 0.0f) {
            return xnn_datatype_invalid;
          }
        }

        return quantization_params->scale->size == 1 ? xnn_datatype_qint8
                                                     : xnn_datatype_qcint8;
      }
      break;
    case kTfLiteInt32:
      if (tensor.quantization.type == kTfLiteAffineQuantization) {
        const auto quantization_params =
            static_cast<const TfLiteAffineQuantization*>(
                tensor.quantization.params);
        if (quantization_params->scale == nullptr ||
            quantization_params->zero_point == nullptr ||
            quantization_params->scale->size <= 0 ||
            quantization_params->zero_point->size != 1) {
          return xnn_datatype_invalid;
        }

        const int zero_point = quantization_params->zero_point->data[0];
        // Quantized INT32 (bias) tensors must have a zero point of exactly 0.
        if (zero_point != 0) {
          return xnn_datatype_invalid;
        }

        for (int i = 0; i < quantization_params->scale->size; i++) {
          const float scale = quantization_params->scale->data[i];
          if (!std::isnormal(scale) || scale <= 0.0f) {
            return xnn_datatype_invalid;
          }
        }

        return quantization_params->scale->size == 1 ? xnn_datatype_qint32
                                                     : xnn_datatype_qcint32;
      }
      break;
    default:
      break;
  }
  return xnn_datatype_invalid;
}

// Forward declaration.
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate);

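// Owns the resources shared by all XNNPACK subgraphs created for one TFLite
// delegate instance: the thread pool, the XNNPACK workspace, the unpacked
// data of quasi-static tensors, and the option flags queried during
// delegation.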
class Delegate {
  friend class Subgraph;

 public:
  explicit Delegate(const TfLiteXNNPackDelegateOptions* options,
                    xnn_workspace_t workspace) {
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
    if (options != nullptr && options->num_threads > 1) {
      threadpool_.reset(
          pthreadpool_create(static_cast<size_t>(options->num_threads)));
    }
#endif
    TFLITE_LOG_PROD_ONCE(tflite::TFLITE_LOG_INFO,
                         "Created TensorFlow Lite XNNPACK delegate for CPU.");

    options_ =
        options != nullptr ? *options : TfLiteXNNPackDelegateOptionsDefault();
    workspace_.reset(workspace);
  }

  TfLiteIntArray* PrepareOpsToDelegate(TfLiteContext* context);
  TfLiteDelegate* tflite_delegate() { return &delegate_; }

  bool support_signed_8bit_quantization() const {
    return (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_QS8) != 0;
  }

  bool support_unsigned_8bit_quantization() const {
    return (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_QU8) != 0;
  }

  bool support_any_8bit_quantization() const {
    return (options_.flags & (TFLITE_XNNPACK_DELEGATE_FLAG_QU8 |
                              TFLITE_XNNPACK_DELEGATE_FLAG_QS8)) != 0;
  }

  bool force_fp16() const {
#ifdef XNNPACK_DELEGATE_FORCE_PRECISION_FP16
    return true;
#else
    return (options_.flags & TFLITE_XNNPACK_DELEGATE_FLAG_FORCE_FP16) != 0;
#endif
  }

  pthreadpool_t threadpool() const {
#if defined(__EMSCRIPTEN__) && !defined(__EMSCRIPTEN_PTHREADS__)
    return nullptr;
#else
    return threadpool_.get();
#endif
  }

  xnn_weights_cache_t weights_cache() const {
    if (options_.weights_cache == nullptr) {
      return nullptr;
    } else {
      return reinterpret_cast<xnn_weights_cache_t>(options_.weights_cache);
    }
  }

  xnn_workspace_t workspace() const { return workspace_.get(); }

 private:
  TfLiteDelegate delegate_ = {
      reinterpret_cast<void*>(this),  // .data_
      DelegatePrepare,                // .Prepare
      nullptr,                        // .CopyFromBufferHandle
      nullptr,                        // .CopyToBufferHandle
      nullptr,                        // .FreeBufferHandle
      kTfLiteDelegateFlagsNone,       // .flags
  };

  // Unpacked data for quasi-static tensors, i.e. tensors produced by
  // dequantizing or unpacking static buffers.
  std::vector<char> static_unpacked_data_;
  // Mapping from a tensor index for a quasi-static tensor to the offset to
  // its unpacked data within static_unpacked_data_.
  std::unordered_map<int, size_t> static_unpacked_data_map_;
  // Set of indices of nodes which unpack static data, e.g. Dequantize
  // operators which convert FP16 static weights to FP32. These nodes are simply
  // ignored in the delegate implementation, because their outputs are
  // pre-unpacked in DelegatePrepare.
  std::unordered_set<int> static_unpack_nodes_;
  // Set of indices of tensors with unpacked static sparse weights.
  std::unordered_set<int> static_sparse_weights_;
#if !defined(__EMSCRIPTEN__) || defined(__EMSCRIPTEN_PTHREADS__)
  // Thread pool with smart-pointer for lifetime management.
  std::unique_ptr<pthreadpool, decltype(&pthreadpool_destroy)> threadpool_{
      nullptr, &pthreadpool_destroy};
#endif
  std::unique_ptr<xnn_workspace, decltype(&xnn_release_workspace)> workspace_{
      nullptr, &xnn_release_workspace};

  TfLiteXNNPackDelegateOptions options_;
};

class Subgraph {
 public:
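  // Builds an XNNPACK subgraph mirroring the TFLite nodes in `params`,
  // defines an XNNPACK Value for every tensor those nodes touch, and wraps
  // the resulting XNNPACK runtime in a new Subgraph. Returns nullptr (after
  // logging to `context`) on any unsupported construct or XNNPACK failure.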
  static Subgraph* Create(TfLiteContext* context,
                          const TfLiteDelegateParams* params,
                          const Delegate& delegate) {
    // Convert subgraph inputs and outputs to hash sets for faster lookup.
    const std::unordered_set<int> inputs(
        &params->input_tensors->data[0],
        &params->input_tensors->data[params->input_tensors->size]);
    std::unordered_set<int> outputs;
    for (int o = 0; o < params->output_tensors->size; o++) {
      const int output_tensor_idx = params->output_tensors->data[o];
      // Exclude quasi-static tensors which may have become subgraph outputs
      // after partitioning.
      if (delegate.static_unpacked_data_map_.count(output_tensor_idx) == 0) {
        outputs.insert(output_tensor_idx);
      }
    }
    std::unordered_set<int> externals(outputs);

    TfLiteIntArray* execution_plan;
    if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
      return nullptr;
    }

    xnn_subgraph_t subgraph_ptr = nullptr;
    xnn_status status = xnn_create_subgraph(
        /*external_value_ids=*/context->tensors_size, /*flags=*/0,
        &subgraph_ptr);
    if (status != xnn_status_success) {
      TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK subgraph");
      return nullptr;
    }

    // Smart pointer to automatically release subgraph on exit.
    std::unique_ptr<xnn_subgraph, decltype(&xnn_delete_subgraph)> subgraph(
        subgraph_ptr, &xnn_delete_subgraph);

    bool has_sparse_weights = false;
    // Detect which tensors are used as inputs or outputs of any subgraph nodes.
    // -1 denotes tensor not used in the subgraph. These indexes will be
    // filtered out and removed later.
    std::vector<int> tensors(context->tensors_size, -1);
    for (int i = 0; i < params->nodes_to_replace->size; i++) {
      const int node_index = params->nodes_to_replace->data[i];

      TfLiteNode* node = nullptr;
      TfLiteRegistration* registration = nullptr;
      if (context->GetNodeAndRegistration(context, node_index, &node,
                                          &registration) != kTfLiteOk) {
        return nullptr;
      }

      // Detect if any of the node's inputs are sparse weights.
      if (!has_sparse_weights) {
        for (int i = 0; i < node->inputs->size; i++) {
          if (delegate.static_sparse_weights_.count(node->inputs->data[i]) !=
              0) {
            has_sparse_weights = true;
          }
        }
      }

      if (delegate.static_unpack_nodes_.count(node_index) != 0) {
        // The node unpacks static input and can be skipped because its input
        // was pre-unpacked in DelegatePrepare.
        continue;
      }

      switch (registration->builtin_code) {
        case kTfLiteBuiltinMean:
        case kTfLiteBuiltinPad:
        case kTfLiteBuiltinReshape:
        case kTfLiteBuiltinResizeBilinear:
          // Ignore the second input (axes, static padding, or new shape),
          // because it is represented as parameters of the XNNPACK operator
          // rather than extra input.
          {
            const int t = node->inputs->data[0];
            tensors[t] = t;
          }
          break;
        case kTfLiteBuiltinSplit:
          // Ignore the first input (split_dim), as it is represented as
          // parameters of the XNNPACK operator rather than extra input.
          {
            const int t = node->inputs->data[1];
            tensors[t] = t;
            break;
          }
        case kTfLiteBuiltinTranspose:
          // Ignore the second input (perm), as it is represented as
          // parameters of the XNNPACK operator rather than extra input.
          {
            const int t = node->inputs->data[0];
            tensors[t] = t;
            break;
          }
        default:
          // All other operators: process all inputs
          for (int k = 0; k < node->inputs->size; k++) {
            if (registration->builtin_code == kTfLiteBuiltinTransposeConv &&
                k == 0) {
              // Ignore the output size parameter (see above).
              continue;
            }
            const int t = node->inputs->data[k];
            if (t >= 0) {
              tensors[t] = t;
            }
          }
      }
      for (int k = 0; k < node->outputs->size; k++) {
        const int t = node->outputs->data[k];
        if (t >= 0) {
          tensors[t] = t;
        }
      }
    }
    // Filter out and remove -1 (unused) indexes.
    tensors.erase(std::remove_if(tensors.begin(), tensors.end(),
                                 [](int i) { return i < 0; }),
                  tensors.end());
    std::sort(tensors.begin(), tensors.end());

    // XNNPACK Value IDs for TFLite tensors
    std::vector<uint32_t> xnnpack_tensors(tensors.back() + 1);
    for (int t : tensors) {
      xnn_datatype datatype = xnn_datatype_invalid;
      switch (context->tensors[t].type) {
        case kTfLiteFloat32:
          datatype = xnn_datatype_fp32;
          break;
        case kTfLiteInt8: {
          if (context->tensors[t].quantization.type !=
              kTfLiteAffineQuantization) {
            TF_LITE_KERNEL_LOG(context,
                               "unsupported quantization type %d for INT8 "
                               "tensor %d in XNNPACK delegate",
                               context->tensors[t].quantization.type, t);
            return nullptr;
          }
          const auto quantization_params =
              static_cast<const TfLiteAffineQuantization*>(
                  context->tensors[t].quantization.params);
          if (quantization_params->scale == nullptr) {
            TF_LITE_KERNEL_LOG(context,
                               "missing scale quantization parameters for INT8 "
                               "tensor %d in XNNPACK delegate",
                               t);
            return nullptr;
          }
          if (quantization_params->zero_point == nullptr) {
            TF_LITE_KERNEL_LOG(context,
                               "missing zero point quantization parameters for "
                               "INT8 tensor %d in XNNPACK delegate",
                               t);
            return nullptr;
          }
          if (quantization_params->scale->size !=
              quantization_params->zero_point->size) {
            TF_LITE_KERNEL_LOG(context,
                               "mismatching number of scale (%d) and zero "
                               "point (%d) quantization parameters for INT8 "
                               "tensor %d in XNNPACK delegate",
                               quantization_params->scale->size,
                               quantization_params->zero_point->size, t);
            return nullptr;
          }
          if (quantization_params->scale->size == 1) {
            // Per-tensor quantization parameters
            datatype = xnn_datatype_qint8;
          } else if (NumDimensions(&context->tensors[t]) >= 1 &&
                     quantization_params->scale->size ==
                         SizeOfDimension(
                             &context->tensors[t],
                             quantization_params->quantized_dimension)) {
            // Per-channel quantization parameters
            for (int c = 0;
                 c < SizeOfDimension(&context->tensors[t],
                                     quantization_params->quantized_dimension);
                 c++) {
              if (quantization_params->zero_point->data[c] != 0) {
                TF_LITE_KERNEL_LOG(context,
                                   "unsupported zero-point value %d in channel "
                                   "%d of INT8 tensor %d in XNNPACK delegate",
                                   quantization_params->zero_point->data[c], c,
                                   t);
                return nullptr;
              }
            }
            datatype = xnn_datatype_qcint8;
          } else {
            TF_LITE_KERNEL_LOG(
                context,
                "mismatching number of quantization parameters %d and outer "
                "dimension %d for INT8 tensor %d in XNNPACK delegate",
                quantization_params->scale->size,
                SizeOfDimension(&context->tensors[t], 0), t);
            return nullptr;
          }
          break;
        }
        case kTfLiteUInt8: {
          if (context->tensors[t].quantization.type !=
              kTfLiteAffineQuantization) {
            TF_LITE_KERNEL_LOG(context,
                               "unsupported quantization type %d for UINT8 "
                               "tensor %d in XNNPACK delegate",
                               context->tensors[t].quantization.type, t);
            return nullptr;
          }
          const auto quantization_params =
              static_cast<const TfLiteAffineQuantization*>(
                  context->tensors[t].quantization.params);
          if (quantization_params->scale == nullptr) {
            TF_LITE_KERNEL_LOG(
                context,
                "missing scale quantization parameters for UINT8 "
                "tensor %d in XNNPACK delegate",
                t);
            return nullptr;
          }
          if (quantization_params->zero_point == nullptr) {
            TF_LITE_KERNEL_LOG(context,
                               "missing zero point quantization parameters for "
                               "UINT8 tensor %d in XNNPACK delegate",
                               t);
            return nullptr;
          }
          if (quantization_params->scale->size != 1) {
            TF_LITE_KERNEL_LOG(
                context,
                "unsupported number (%d) of scale quantization parameters for "
                "UINT8 tensor %d in XNNPACK delegate",
                quantization_params->scale->size, t);
            return nullptr;
          }
          if (quantization_params->zero_point->size != 1) {
            TF_LITE_KERNEL_LOG(
                context,
                "unsupported number (%d) of zero point quantization parameters "
                "for UINT8 tensor %d in XNNPACK delegate",
                quantization_params->zero_point->size, t);
            return nullptr;
          }
          datatype = xnn_datatype_quint8;
          break;
        }
        case kTfLiteInt32: {
          if (context->tensors[t].quantization.type !=
              kTfLiteAffineQuantization) {
            TF_LITE_KERNEL_LOG(context,
                               "unsupported quantization type %d for INT32 "
                               "tensor %d in XNNPACK delegate",
                               context->tensors[t].quantization.type, t);
            return nullptr;
          }
          const auto quantization_params =
              static_cast<const TfLiteAffineQuantization*>(
                  context->tensors[t].quantization.params);
          if (quantization_params->scale == nullptr) {
            TF_LITE_KERNEL_LOG(context,
                               "missing scale quantization parameters for "
                               "INT32 tensor %d in XNNPACK delegate",
                               t);
            return nullptr;
          }
          if (quantization_params->zero_point == nullptr) {
            TF_LITE_KERNEL_LOG(context,
                               "missing zero point quantization parameters for "
                               "INT32 tensor %d in XNNPACK delegate",
                               t);
            return nullptr;
          }
          if (quantization_params->scale->size !=
              quantization_params->zero_point->size) {
            TF_LITE_KERNEL_LOG(context,
                               "mismatching number of scale (%d) and zero "
                               "point (%d) quantization parameters for INT32 "
                               "tensor %d in XNNPACK delegate",
                               quantization_params->scale->size,
                               quantization_params->zero_point->size, t);
            return nullptr;
          }
          if (quantization_params->quantized_dimension != 0) {
            TF_LITE_KERNEL_LOG(context,
                               "unsupported quantized dimension %d for INT32 "
                               "tensor %d in XNNPACK delegate",
                               quantization_params->quantized_dimension, t);
            return nullptr;
          }
          if (quantization_params->scale->size == 1) {
            // Per-tensor quantization parameters
            if (quantization_params->zero_point->data[0] != 0) {
              TF_LITE_KERNEL_LOG(context,
                                 "unsupported zero-point value %d for INT32 "
                                 "tensor %d in XNNPACK delegate",
                                 quantization_params->zero_point->data[0], t);
              return nullptr;
            }
            datatype = xnn_datatype_qint32;
          } else if (NumDimensions(&context->tensors[t]) >= 1 &&
                     quantization_params->scale->size ==
                         SizeOfDimension(&context->tensors[t], 0)) {
            // Per-channel quantization parameters
            for (int c = 0; c < SizeOfDimension(&context->tensors[t], 0); c++) {
              if (quantization_params->zero_point->data[c] != 0) {
                TF_LITE_KERNEL_LOG(context,
                                   "unsupported zero-point value %d in channel "
                                   "%d of INT32 tensor %d in XNNPACK delegate",
                                   quantization_params->zero_point->data[c], c,
                                   t);
                return nullptr;
              }
            }
            datatype = xnn_datatype_qcint32;
          } else {
            TF_LITE_KERNEL_LOG(
                context,
                "mismatching number of quantization parameters %d and outer "
                "dimension %d for INT32 tensor %d in XNNPACK delegate",
                quantization_params->scale->size,
                SizeOfDimension(&context->tensors[t], 0), t);
            return nullptr;
          }
          break;
        }
        default:
          TF_LITE_KERNEL_LOG(
              context,
              "unsupported datatype (%s) of tensor %d in XNNPACK delegate",
              TfLiteTypeGetName(context->tensors[t].type), t);
          return nullptr;
      }

      uint32_t flags = 0;
      const void* data = nullptr;
      if (context->tensors[t].allocation_type == kTfLiteMmapRo) {
        data = context->tensors[t].data.raw_const;
      } else {
        // Check for quasi-static data.
        const auto it = delegate.static_unpacked_data_map_.find(t);
        if (it != delegate.static_unpacked_data_map_.end()) {
          data = delegate.static_unpacked_data_.data() + it->second;
        }
      }
      if (inputs.count(t) != 0) {
        flags |= XNN_VALUE_FLAG_EXTERNAL_INPUT;
        if (data == nullptr) {
          externals.insert(t);
        }
      }
      if (outputs.count(t) != 0) {
        flags |= XNN_VALUE_FLAG_EXTERNAL_OUTPUT;
      }

      std::vector<size_t> dims(
          &context->tensors[t].dims->data[0],
          &context->tensors[t].dims->data[NumDimensions(&context->tensors[t])]);

      xnn_status status = xnn_status_success;
      switch (datatype) {
        case xnn_datatype_qint8:
        case xnn_datatype_quint8:
        case xnn_datatype_qint32:
          status = xnn_define_quantized_tensor_value(
              subgraph.get(), datatype,
              static_cast<const TfLiteAffineQuantization*>(
                  context->tensors[t].quantization.params)
                  ->zero_point->data[0],
              static_cast<const TfLiteAffineQuantization*>(
                  context->tensors[t].quantization.params)
                  ->scale->data[0],
              dims.size(), dims.data(), data, static_cast<uint32_t>(t), flags,
              &xnnpack_tensors[t]);
          break;
        case xnn_datatype_qcint8:
        case xnn_datatype_qcint32:
          status = xnn_define_channelwise_quantized_tensor_value(
              subgraph.get(), datatype,
              static_cast<const TfLiteAffineQuantization*>(
                  context->tensors[t].quantization.params)
                  ->scale->data,
              dims.size(),
              static_cast<const TfLiteAffineQuantization*>(
                  context->tensors[t].quantization.params)
                  ->quantized_dimension,
              dims.data(), data, static_cast<uint32_t>(t), flags,
              &xnnpack_tensors[t]);
          break;
        default:
          status = xnn_define_tensor_value(
              subgraph.get(), datatype, dims.size(), dims.data(), data,
              static_cast<uint32_t>(t), flags, &xnnpack_tensors[t]);
          break;
      }
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(context,
                           "failed to create XNNPACK Value for tensor %d", t);
        return nullptr;
      }
    }

    // Create a set of quasi-static tensors for VisitNode function
    std::unordered_set<int> quasi_static_tensors;
    for (const std::pair<const int, size_t>& entry :
         delegate.static_unpacked_data_map_) {
      quasi_static_tensors.insert(entry.first);
    }

    // Create XNNPACK nodes for TFLite delegate nodes
    for (int i = 0; i < params->nodes_to_replace->size; i++) {
      const int node_index = params->nodes_to_replace->data[i];
      if (delegate.static_unpack_nodes_.count(node_index)) {
        // The node unpacks static input and can be skipped because its input
        // was pre-unpacked in DelegatePrepare.
        continue;
      }

      TfLiteNode* node = nullptr;
      TfLiteRegistration* registration = nullptr;
      if (context->GetNodeAndRegistration(context, node_index, &node,
                                          &registration) != kTfLiteOk) {
        return nullptr;
      }

      if (VisitNode(subgraph.get(), delegate, context, registration, node,
                    node_index, quasi_static_tensors,
                    xnnpack_tensors) != kTfLiteOk) {
        return nullptr;
      }
    }

    xnn_runtime_t runtime_ptr = nullptr;
    uint32_t flags = XNN_FLAG_YIELD_WORKERS;
    if (has_sparse_weights) {
      flags |= XNN_FLAG_HINT_SPARSE_INFERENCE;
    }
    if (delegate.force_fp16()) {
      flags |= XNN_FLAG_FORCE_FP16_INFERENCE;
    } else {
      const char* precision_metadata_ptr = nullptr;
      size_t precision_metadata_size = 0;
      if (context->GetModelMetadata(
              context, optimize::kTfLiteReducedPrecisionKey,
              &precision_metadata_ptr, &precision_metadata_size) == kTfLiteOk) {
        const std::string precision_metadata(precision_metadata_ptr,
                                             precision_metadata_size);
        optimize::ReducedPrecisionSupport precision_mask =
            optimize::ReducedPrecisionSupport::None;
        if (optimize::SetMaskFromReducedPrecisionMetadata(precision_metadata,
                                                          &precision_mask)) {
          if (optimize::SupportsFP16Inference(precision_mask) &&
              optimize::SupportsFP16Accumulation(precision_mask)) {
            flags |= XNN_FLAG_HINT_FP16_INFERENCE;
          }
        }
      }
    }
    if (context->profiler) {
      flags |= XNN_FLAG_BASIC_PROFILING;
    }
    status = xnn_create_runtime_v4(subgraph.get(), delegate.weights_cache(),
                                   delegate.workspace(), delegate.threadpool(),
                                   flags, &runtime_ptr);
    if (status != xnn_status_success) {
      TF_LITE_KERNEL_LOG(context, "failed to create XNNPACK runtime");
      return nullptr;
    }

    return new Subgraph(delegate, runtime_ptr, externals);
  }

  TfLiteStatus Prepare(TfLiteContext* context) { return kTfLiteOk; }

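  // Runs the XNNPACK runtime for this subgraph. External (input/output)
  // tensor pointers are cached in externals_, and xnn_setup_runtime is only
  // re-invoked when at least one of those pointers changed since the previous
  // call, keeping the steady-state invocation path cheap.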
  TfLiteStatus Invoke(TfLiteContext* context) {
    bool any_pointers_changed = false;
    for (std::pair<int, void*> io_info : externals_) {
      const TfLiteTensor& tensor = context->tensors[io_info.first];
      void* data_pointer = &dummy_data_;
      if (tensor.data.raw != nullptr) {
        data_pointer = tensor.data.raw;
      } else {
        if (tensor.bytes != 0) {
          TF_LITE_KERNEL_LOG(
              context, "unexpected null data pointer in external tensor %d",
              io_info.first);
          return kTfLiteError;
        }
      }
      if (data_pointer != io_info.second) {
        any_pointers_changed = true;
        externals_[io_info.first] = data_pointer;
      }
    }

    if (any_pointers_changed) {
      std::vector<xnn_external_value> external_values;
      for (std::pair<int, void*> io_info : externals_) {
        xnn_external_value value = {0};
        value.id = static_cast<uint32_t>(io_info.first);
        value.data = io_info.second;
        external_values.push_back(value);
      }

      const xnn_status status = xnn_setup_runtime(
          runtime_.get(), external_values.size(), external_values.data());
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(context, "failed to setup XNNPACK runtime");
        return kTfLiteError;
      }
    }

    xnn_status status = xnn_invoke_runtime(runtime_.get());
    if (status != xnn_status_success) {
      TF_LITE_KERNEL_LOG(context, "failed to invoke XNNPACK runtime");
      return kTfLiteError;
    }

    if (context->profiler) {
      if (AddEventsToProfiler(reinterpret_cast<Profiler*>(context->profiler),
                              runtime_.get()) != kTfLiteOk) {
        TF_LITE_KERNEL_LOG(context,
                           "failed to get XNNPACK profile information.");
      }
    }

    return kTfLiteOk;
  }

  // Fetch the profile information from XNNPACK and add the events to TfLite's
  // profiler.
  static TfLiteStatus AddEventsToProfiler(Profiler* profiler,
                                          const xnn_runtime_t runtime) {
    size_t required_size = 0;

    // xnn_get_runtime_profiling_info is called twice. The first time it sets
    // required_size to the required size of the buffer to store the result and
    // returns xnn_status_out_of_memory. The second time it writes the result to
    // the buffer provided that the buffer is large enough and returns
    // xnn_status_success.
    xnn_status status = xnn_get_runtime_profiling_info(
        runtime, xnn_profile_info_operator_name, /*param_value_size*/ 0,
        /*param_value*/ nullptr, &required_size);
    std::vector<char> operator_names;
    if (status == xnn_status_out_of_memory) {
      operator_names.resize(required_size);
      status = xnn_get_runtime_profiling_info(
          runtime, xnn_profile_info_operator_name, operator_names.size(),
          operator_names.data(), &required_size);
    }
    if (status != xnn_status_success) {
      return kTfLiteError;
    }
    size_t num_operators;
    status = xnn_get_runtime_profiling_info(
        runtime, xnn_profile_info_num_operators, sizeof(num_operators),
        &num_operators, &required_size);
    if (status != xnn_status_success) {
      return kTfLiteError;
    }
    status = xnn_get_runtime_profiling_info(
        runtime, xnn_profile_info_operator_timing, /*param_value_size*/ 0,
        /*param_value*/ nullptr, &required_size);
    std::vector<uint64_t> operator_timings;
    if (status == xnn_status_out_of_memory) {
      operator_timings.resize(required_size / sizeof(uint64_t));
      status = xnn_get_runtime_profiling_info(
          runtime, xnn_profile_info_operator_timing,
          operator_timings.size() * sizeof(uint64_t), operator_timings.data(),
          &required_size);
    }
    if (status != xnn_status_success) {
      return kTfLiteError;
    }
    const char* operator_name = nullptr;
    size_t name_len = 0;
    for (size_t node_index = 0; node_index < num_operators; ++node_index) {
      operator_name = &operator_names[name_len];
      name_len += strlen(operator_name) + 1;
      profiler->AddEvent(operator_name,
                         Profiler::EventType::DELEGATE_OPERATOR_INVOKE_EVENT,
                         operator_timings[node_index], node_index);
    }
    return kTfLiteOk;
  }

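  // Translates a TFLite padding mode into XNNPACK operator flags: SAME maps
  // to XNN_FLAG_TENSORFLOW_SAME_PADDING, VALID maps to no flags, and any
  // other mode is rejected.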
  static TfLiteStatus CalculatePadding(TfLiteContext* context,
                                       TfLitePadding padding, uint32_t* flags,
                                       int node_index) {
    switch (padding) {
      case kTfLitePaddingSame: {
        *flags = XNN_FLAG_TENSORFLOW_SAME_PADDING;
        return kTfLiteOk;
      }
      case kTfLitePaddingValid:
        *flags = 0;
        return kTfLiteOk;
      default:
        TF_LITE_MAYBE_KERNEL_LOG(context,
                                 "invalid padding mode (%d) in node #%d",
                                 static_cast<int>(padding), node_index);
        return kTfLiteError;
    }
  }

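  // Computes the explicit paddings and output-size adjustments that XNNPACK
  // needs for TRANSPOSE_CONV. As an illustrative example (values chosen here,
  // not taken from any model): with VALID padding, a 4x4 kernel, 2x2 strides,
  // and a 10x10 output, all four paddings are 0 and each adjustment is
  // (10 - 4) % 2 = 0.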
  static TfLiteStatus CalculateTransposeConvPaddings(
      TfLiteContext* context, TfLitePadding padding, int input_height,
      int input_width, int kernel_height, int kernel_width, int dilation_height,
      int dilation_width, int stride_height, int stride_width, int node_index,
      int output_height, int output_width, int* padding_top,
      int* padding_bottom, int* padding_left, int* padding_right,
      int* adjustment_height, int* adjustment_width) {
    const int effective_kernel_height =
        (kernel_height - 1) * dilation_height + 1;
    const int effective_kernel_width = (kernel_width - 1) * dilation_width + 1;
    switch (padding) {
      case kTfLitePaddingValid: {
        if (effective_kernel_height > output_height ||
            effective_kernel_width > output_width) {
          TF_LITE_MAYBE_KERNEL_LOG(
              context,
              "output smaller than effective kernel dimensions unsupported "
              "with VALID padding in TRANSPOSE_CONV node #%d: "
              "effective kernel size %dx%d (HxW), output %dx%d",
              node_index, effective_kernel_height, effective_kernel_width,
              output_height, output_width);
          return kTfLiteError;
        }

        *padding_top = *padding_bottom = *padding_left = *padding_right = 0;
        *adjustment_height = (output_height - kernel_height) % stride_height;
        *adjustment_width = (output_width - kernel_width) % stride_width;
        break;
      }
      case kTfLitePaddingSame: {
        int expected_input_height = 0;
        int expected_input_width = 0;
        TfLitePaddingValues paddings = ComputePaddingHeightWidth(
            stride_height, stride_width, dilation_height, dilation_width,
            output_height, output_width, kernel_height, kernel_width, padding,
            &expected_input_height, &expected_input_width);
        if (expected_input_height != input_height ||
            expected_input_width != input_width) {
          TF_LITE_MAYBE_KERNEL_LOG(
              context,
              "inconsistent combination of parameters for TRANSPOSE_CONV op "
              "in node #%d: computed input size %dx%d (HxW), actual %dx%d",
              node_index, expected_input_height, expected_input_width,
              input_height, input_width);
          return kTfLiteError;
        }

        // Note: In the derivation of the adjustments below, it was assumed that
        //       `effective_kernel_...` >= `stride_...` so that `ComputePadding`
        //       in TFLite doesn't encounter a negative value clamped to zero.
        if (kernel_height < stride_height || kernel_width < stride_width) {
          TF_LITE_MAYBE_KERNEL_LOG(
              context,
              "strides larger than effective kernel dimensions unsupported in "
              "TRANSPOSE_CONV node #%d: effective kernel size %dx%d (HxW), "
              "strides %dx%d",
              node_index, effective_kernel_height, effective_kernel_width,
              stride_height, stride_width);
          return kTfLiteError;
        }

        *padding_top = paddings.height;
        *padding_bottom = paddings.height + paddings.height_offset;
        *adjustment_height = 0;
        *padding_left = paddings.width;
        *padding_right = paddings.width + paddings.width_offset;
        *adjustment_width = 0;
        break;
      }
      default:
        TF_LITE_MAYBE_KERNEL_LOG(context,
                                 "invalid padding mode (%d) in node #%d",
                                 static_cast<int>(padding), node_index);
        return kTfLiteError;
    }

    return kTfLiteOk;
  }

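  // Converts a fused activation into the [*output_min, *output_max] clamping
  // range applied by the XNNPACK operator, e.g. RELU becomes [0, +inf) and
  // RELU6 becomes [0, 6]; non-clamping activations (Tanh, Sign, Sigmoid) are
  // rejected.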
  static TfLiteStatus ConvertActivationToOutputRange(
      TfLiteContext* context, int node_index, TfLiteFusedActivation activation,
      float* output_min, float* output_max) {
    switch (activation) {
      case kTfLiteActNone:
        *output_min = -std::numeric_limits<float>::infinity();
        *output_max = +std::numeric_limits<float>::infinity();
        return kTfLiteOk;
      case kTfLiteActRelu:
        *output_min = 0.0f;
        *output_max = +std::numeric_limits<float>::infinity();
        return kTfLiteOk;
      case kTfLiteActReluN1To1:
        *output_min = -1.0f;
        *output_max = +1.0f;
        return kTfLiteOk;
      case kTfLiteActRelu6:
        *output_min = 0.0f;
        *output_max = 6.0f;
        return kTfLiteOk;
      case kTfLiteActTanh:
        TF_LITE_MAYBE_KERNEL_LOG(
            context, "unsupported fused activation (Tanh) in node #%d",
            node_index);
        return kTfLiteError;
      case kTfLiteActSignBit:
        TF_LITE_MAYBE_KERNEL_LOG(
            context, "unsupported fused activation (Sign) in node #%d",
            node_index);
        return kTfLiteError;
      case kTfLiteActSigmoid:
        TF_LITE_MAYBE_KERNEL_LOG(
            context, "unsupported fused activation (Sigmoid) in node #%d",
            node_index);
        return kTfLiteError;
      default:
        TF_LITE_MAYBE_KERNEL_LOG(context,
                                 "invalid fused activation (%d) in node #%d",
                                 static_cast<int>(activation), node_index);
        return kTfLiteError;
    }
  }

  static TfLiteStatus CheckConvolutionParams(TfLiteContext* context,
                                             const TfLiteConvParams* params,
                                             int node_index) {
    if (params->stride_width <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
                               params->stride_width, node_index);
      return kTfLiteError;
    }
    if (params->stride_height <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
                               params->stride_height, node_index);
      return kTfLiteError;
    }

    if (params->dilation_width_factor <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context,
                               "invalid dilation width factor %d in node #%d",
                               params->dilation_width_factor, node_index);
      return kTfLiteError;
    }
    if (params->dilation_height_factor <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context,
                               "invalid dilation height factor %d in node #%d",
                               params->dilation_height_factor, node_index);
      return kTfLiteError;
    }

    return kTfLiteOk;
  }

  static TfLiteStatus CheckDepthwiseConvolutionParams(
      TfLiteContext* context, const TfLiteDepthwiseConvParams* params,
      int output_channels, int node_index) {
    if (params->stride_width <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
                               params->stride_width, node_index);
      return kTfLiteError;
    }
    if (params->stride_height <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
                               params->stride_height, node_index);
      return kTfLiteError;
    }

    if (params->depth_multiplier <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context,
                               "invalid depth multiplier %d in node #%d",
                               params->depth_multiplier, node_index);
      return kTfLiteError;
    }
    if (output_channels % params->depth_multiplier != 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context,
                               "depth multiplier %d is incompatible with "
                               "number of output channels %d in node #%d",
                               params->depth_multiplier, output_channels,
                               node_index);
      return kTfLiteError;
    }

    if (params->dilation_width_factor <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context,
                               "invalid dilation width factor %d in node #%d",
                               params->dilation_width_factor, node_index);
      return kTfLiteError;
    }
    if (params->dilation_height_factor <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context,
                               "invalid dilation height factor %d in node #%d",
                               params->dilation_height_factor, node_index);
      return kTfLiteError;
    }

    return kTfLiteOk;
  }

  static TfLiteStatus CheckMediaPipeTransposedConvolutionParams(
      TfLiteContext* context, const TfLiteTransposeConvParams* params,
      int node_index) {
    if (params->stride_width <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
                               params->stride_width, node_index);
      return kTfLiteError;
    }
    if (params->stride_height <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
                               params->stride_height, node_index);
      return kTfLiteError;
    }

    return kTfLiteOk;
  }

  static TfLiteStatus CheckMediaPipePoolParams(TfLiteContext* context,
                                               const TfLitePoolParams* params,
                                               int node_index) {
    if (params->stride_width <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
                               params->stride_width, node_index);
      return kTfLiteError;
    }
    if (params->stride_height <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
                               params->stride_height, node_index);
      return kTfLiteError;
    }
    if (params->filter_width <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid filter width %d in node #%d",
                               params->filter_width, node_index);
      return kTfLiteError;
    }
    if (params->filter_height <= 0) {
      TF_LITE_MAYBE_KERNEL_LOG(context, "invalid filter height %d in node #%d",
                               params->filter_height, node_index);
      return kTfLiteError;
    }
    if (params->filter_width != params->stride_width) {
      TF_LITE_MAYBE_KERNEL_LOG(
          context, "filter width %d does not match stride width %d in node #%d",
          params->filter_width, params->stride_width, node_index);
      return kTfLiteError;
    }
    if (params->filter_height != params->stride_height) {
      TF_LITE_MAYBE_KERNEL_LOG(
          context,
          "filter height %d does not match stride height %d in node #%d",
          params->filter_height, params->stride_height, node_index);
      return kTfLiteError;
    }
    switch (params->activation) {
      case kTfLiteActNone:
        break;
      case kTfLiteActRelu:
        TF_LITE_MAYBE_KERNEL_LOG(
            context, "unsupported fused activation (Relu) in node #%d",
            node_index);
        return kTfLiteError;
      case kTfLiteActReluN1To1:
        TF_LITE_MAYBE_KERNEL_LOG(
            context, "unsupported fused activation (ReluMinus1To1) in node #%d",
            node_index);
        return kTfLiteError;
      case kTfLiteActRelu6:
        TF_LITE_MAYBE_KERNEL_LOG(
            context, "unsupported fused activation (Relu6) in node #%d",
            node_index);
        return kTfLiteError;
1125       case kTfLiteActTanh:
1126         TF_LITE_MAYBE_KERNEL_LOG(
1127             context, "unsupported fused activation (Tanh) in node #%d",
1128             node_index);
1129         return kTfLiteError;
1130       case kTfLiteActSignBit:
1131         TF_LITE_MAYBE_KERNEL_LOG(
1132             context, "unsupported fused activation (Sign) in node #%d",
1133             node_index);
1134         return kTfLiteError;
1135       case kTfLiteActSigmoid:
1136         TF_LITE_MAYBE_KERNEL_LOG(
1137             context, "unsupported fused activation (Sigmoid) in node #%d",
1138             node_index);
1139         return kTfLiteError;
1140       default:
1141         TF_LITE_MAYBE_KERNEL_LOG(
1142             context, "invalid fused activation (%d) in node #%d",
1143             static_cast<int>(params->activation), node_index);
1144         return kTfLiteError;
1145     }
1146 
1147     return kTfLiteOk;
1148   }
1149 
CheckFullyConnectedParams(TfLiteContext * context,const TfLiteFullyConnectedParams * params,int node_index)1150   static TfLiteStatus CheckFullyConnectedParams(
1151       TfLiteContext* context, const TfLiteFullyConnectedParams* params,
1152       int node_index) {
1153     if (params->weights_format != kTfLiteFullyConnectedWeightsFormatDefault) {
1154       TF_LITE_MAYBE_KERNEL_LOG(
1155           context, "unsupported non-default weights format in node #%d",
1156           node_index);
1157       return kTfLiteError;
1158     }
1159 
1160     return kTfLiteOk;
1161   }
1162 
CheckPoolingParams(TfLiteContext * context,const TfLitePoolParams * params,int node_index)1163   static TfLiteStatus CheckPoolingParams(TfLiteContext* context,
1164                                          const TfLitePoolParams* params,
1165                                          int node_index) {
1166     if (params->stride_width <= 0) {
1167       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride width %d in node #%d",
1168                                params->stride_width, node_index);
1169       return kTfLiteError;
1170     }
1171     if (params->stride_height <= 0) {
1172       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid stride height %d in node #%d",
1173                                params->stride_height, node_index);
1174       return kTfLiteError;
1175     }
1176 
1177     if (params->filter_width <= 0) {
1178       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid filter width %d in node #%d",
1179                                params->filter_width, node_index);
1180       return kTfLiteError;
1181     }
1182     if (params->filter_height <= 0) {
1183       TF_LITE_MAYBE_KERNEL_LOG(context, "invalid filter height %d in node #%d",
1184                                params->filter_height, node_index);
1185       return kTfLiteError;
1186     }
1187 
1188     if (params->stride_width > params->filter_width) {
1189       TF_LITE_MAYBE_KERNEL_LOG(
1190           context,
1191           "unsupported width stride %d exceeding filter width %d in node #%d",
1192           params->stride_width, params->filter_width, node_index);
1193       return kTfLiteError;
1194     }
1195 
1196     if (params->stride_height > params->filter_height) {
1197       TF_LITE_MAYBE_KERNEL_LOG(
1198           context,
1199           "unsupported height stride %d exceeding filter height %d in node #%d",
1200           params->stride_height, params->filter_height, node_index);
1201       return kTfLiteError;
1202     }
1203 
1204     if (params->filter_width == 1 && params->filter_height == 1 &&
1205         std::max(params->stride_width, params->stride_height) > 1) {
1206       TF_LITE_MAYBE_KERNEL_LOG(context,
1207                                "unsupported pooling with 1x1 filter "
1208                                "and %dx%d stride in node #%d",
1209                                params->stride_width, params->stride_height,
1210                                node_index);
1211       return kTfLiteError;
1212     }
1213 
1214     return kTfLiteOk;
1215   }
1216 
CheckNumInputs(TfLiteContext * context,TfLiteNode * node,int expected_num_inputs,int node_index)1217   static TfLiteStatus CheckNumInputs(TfLiteContext* context, TfLiteNode* node,
1218                                      int expected_num_inputs, int node_index) {
1219     if (node->inputs->size != expected_num_inputs) {
1220       TF_LITE_MAYBE_KERNEL_LOG(
1221           context, "unexpected number of inputs (%d != %d) in node #%d",
1222           node->inputs->size, expected_num_inputs, node_index);
1223       return kTfLiteError;
1224     }
1225     return kTfLiteOk;
1226   }
1227 
CheckNumInputs(TfLiteContext * context,TfLiteNode * node,int min_num_inputs,int max_num_inputs,int node_index)1228   static TfLiteStatus CheckNumInputs(TfLiteContext* context, TfLiteNode* node,
1229                                      int min_num_inputs, int max_num_inputs,
1230                                      int node_index) {
1231     if (node->inputs->size < min_num_inputs ||
1232         node->inputs->size > max_num_inputs) {
1233       TF_LITE_MAYBE_KERNEL_LOG(context,
1234                                "unexpected number of inputs (%d) in node #%d",
1235                                node->inputs->size, node_index);
1236       return kTfLiteError;
1237     }
1238     return kTfLiteOk;
1239   }
1240 
CheckNumOutputs(TfLiteContext * context,TfLiteNode * node,int expected_num_outputs,int node_index)1241   static TfLiteStatus CheckNumOutputs(TfLiteContext* context, TfLiteNode* node,
1242                                       int expected_num_outputs,
1243                                       int node_index) {
1244     if (node->outputs->size != expected_num_outputs) {
1245       TF_LITE_MAYBE_KERNEL_LOG(
1246           context, "unexpected number of outputs (%d != %d) in node #%d",
1247           node->outputs->size, expected_num_outputs, node_index);
1248       return kTfLiteError;
1249     }
1250     return kTfLiteOk;
1251   }
1252 
1253   static TfLiteStatus CheckNumOutputs(TfLiteContext* context, TfLiteNode* node,
1254                                       int min_num_outputs, int max_num_outputs,
1255                                       int node_index) {
1256     if (node->outputs->size < min_num_outputs ||
1257         node->outputs->size > max_num_outputs) {
1258       TF_LITE_MAYBE_KERNEL_LOG(context,
1259                                "unexpected number of outputs (%d) in node #%d",
1260                                node->outputs->size, node_index);
1261       return kTfLiteError;
1262     }
1263     return kTfLiteOk;
1264   }
1265 
1266   static TfLiteStatus CheckNumInputsAndOutputs(
1267       TfLiteContext* context, TfLiteNode* node, int min_num_inputs,
1268       int max_num_inputs, int expected_num_outputs, int node_index) {
1269     TF_LITE_ENSURE_STATUS(CheckNumInputs(context, node, min_num_inputs,
1270                                          max_num_inputs, node_index));
1271     TF_LITE_ENSURE_STATUS(
1272         CheckNumOutputs(context, node, expected_num_outputs, node_index));
1273     return kTfLiteOk;
1274   }
1275 
1276   static TfLiteStatus CheckNumInputsAndOutputs(TfLiteContext* context,
1277                                                TfLiteNode* node,
1278                                                int expected_num_inputs,
1279                                                int expected_num_outputs,
1280                                                int node_index) {
1281     TF_LITE_ENSURE_STATUS(
1282         CheckNumInputs(context, node, expected_num_inputs, node_index));
1283     TF_LITE_ENSURE_STATUS(
1284         CheckNumOutputs(context, node, expected_num_outputs, node_index));
1285     return kTfLiteOk;
1286   }
1287 
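  // Typical usage, as in the Visit*Node functions further below: a unary
  // operator such as ABS validates its signature with
  //   CheckNumInputsAndOutputs(context, node, /*expected_num_inputs=*/1,
  //                            /*expected_num_outputs=*/1, node_index);
  // operators with optional inputs use the (min, max) overloads instead.
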
1288   static TfLiteStatus CheckTensorType(TfLiteContext* context,
1289                                       const TfLiteTensor& tensor,
1290                                       TfLiteType expected_type,
1291                                       int tensor_index, int node_index) {
1292     if (tensor.type != expected_type) {
1293       TF_LITE_MAYBE_KERNEL_LOG(
1294           context, "unsupported type %s in tensor #%d in node #%d",
1295           TfLiteTypeGetName(tensor.type), tensor_index, node_index);
1296       return kTfLiteError;
1297     }
1298     return kTfLiteOk;
1299   }
1300 
1301   static TfLiteStatus CheckTensorFloat32Type(TfLiteContext* context,
1302                                              const TfLiteTensor& tensor,
1303                                              int tensor_index, int node_index) {
1304     return CheckTensorType(context, tensor, kTfLiteFloat32, tensor_index,
1305                            node_index);
1306   }
1307 
1308   static TfLiteStatus CheckTensorFloat32OrQInt8Type(const Delegate& delegate,
1309                                                     TfLiteContext* context,
1310                                                     const TfLiteTensor& tensor,
1311                                                     int tensor_index,
1312                                                     int node_index) {
1313     switch (tensor.type) {
1314       case kTfLiteFloat32:
1315         return kTfLiteOk;
1316       case kTfLiteInt8:
1317         if (delegate.support_signed_8bit_quantization()) {
1318           const auto* quantization_params =
1319               static_cast<const TfLiteAffineQuantization*>(
1320                   tensor.quantization.params);
1321           if (tensor.quantization.type != kTfLiteAffineQuantization ||
1322               quantization_params->quantized_dimension != 0 ||
1323               quantization_params->scale == nullptr ||
1324               quantization_params->scale->size != 1) {
1325             TF_LITE_MAYBE_KERNEL_LOG(
1326                 context,
1327                 "unsupported quantization type %d in tensor #%d in node #%d",
1328                 tensor.quantization.type, tensor_index, node_index);
1329             return kTfLiteError;
1330           }
1331           return kTfLiteOk;
1332         }
1333         break;
1334       default:
1335         break;
1336     }
1337 
1338     TF_LITE_MAYBE_KERNEL_LOG(
1339         context, "unsupported type %s in tensor #%d in node #%d",
1340         TfLiteTypeGetName(tensor.type), tensor_index, node_index);
1341     return kTfLiteError;
1342   }
1343 
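  // Note that the INT8 branch above only admits per-tensor affine
  // quantization: a single scale with quantized_dimension == 0. Per-channel
  // quantized tensors (scale->size > 1) are rejected here and are instead
  // handled by CheckTensorFloat32OrQCInt8Type further below.
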
1344   static TfLiteStatus CheckTensorQInt8OrQUInt8Type(const Delegate& delegate,
1345                                                    TfLiteContext* context,
1346                                                    const TfLiteTensor& tensor,
1347                                                    int tensor_index,
1348                                                    int node_index) {
1349     switch (tensor.type) {
1350       case kTfLiteInt8:
1351         if (delegate.support_signed_8bit_quantization()) {
1352           const auto* quantization_params =
1353               static_cast<const TfLiteAffineQuantization*>(
1354                   tensor.quantization.params);
1355           if (tensor.quantization.type != kTfLiteAffineQuantization ||
1356               quantization_params->quantized_dimension != 0 ||
1357               quantization_params->scale == nullptr ||
1358               quantization_params->scale->size != 1) {
1359             TF_LITE_MAYBE_KERNEL_LOG(
1360                 context,
1361                 "unsupported quantization type %d in tensor #%d in node #%d",
1362                 tensor.quantization.type, tensor_index, node_index);
1363             return kTfLiteError;
1364           }
1365           return kTfLiteOk;
1366         }
1367         break;
1368       case kTfLiteUInt8:
1369         if (delegate.support_unsigned_8bit_quantization()) {
1370           const auto* quantization_params =
1371               static_cast<const TfLiteAffineQuantization*>(
1372                   tensor.quantization.params);
1373           if (tensor.quantization.type != kTfLiteAffineQuantization ||
1374               quantization_params->quantized_dimension != 0 ||
1375               quantization_params->scale == nullptr ||
1376               quantization_params->zero_point == nullptr ||
1377               quantization_params->scale->size != 1 ||
1378               quantization_params->zero_point->size != 1) {
1379             TF_LITE_MAYBE_KERNEL_LOG(
1380                 context,
1381                 "unsupported quantization type %d in tensor #%d in node #%d",
1382                 tensor.quantization.type, tensor_index, node_index);
1383             return kTfLiteError;
1384           }
1385           return kTfLiteOk;
1386         }
1387         break;
1388       default:
1389         break;
1390     }
1391 
1392     TF_LITE_MAYBE_KERNEL_LOG(
1393         context, "unsupported type %s in tensor #%d in node #%d",
1394         TfLiteTypeGetName(tensor.type), tensor_index, node_index);
1395     return kTfLiteError;
1396   }
1397 
1398   static TfLiteStatus CheckTensorFloat32OrQUInt8Type(const Delegate& delegate,
1399                                                      TfLiteContext* context,
1400                                                      const TfLiteTensor& tensor,
1401                                                      int tensor_index,
1402                                                      int node_index) {
1403     switch (tensor.type) {
1404       case kTfLiteFloat32:
1405         return kTfLiteOk;
1406       case kTfLiteInt8:
1407         if (delegate.support_signed_8bit_quantization()) {
1408           const auto* quantization_params =
1409               static_cast<const TfLiteAffineQuantization*>(
1410                   tensor.quantization.params);
1411           if (tensor.quantization.type != kTfLiteAffineQuantization ||
1412               quantization_params->quantized_dimension != 0 ||
1413               quantization_params->scale == nullptr ||
1414               quantization_params->scale->size != 1) {
1415             TF_LITE_MAYBE_KERNEL_LOG(
1416                 context,
1417                 "unsupported quantization type %d in tensor #%d in node #%d",
1418                 tensor.quantization.type, tensor_index, node_index);
1419             return kTfLiteError;
1420           }
1421           return kTfLiteOk;
1422         }
1423         break;
1424       case kTfLiteUInt8:
1425         if (delegate.support_unsigned_8bit_quantization()) {
1426           const auto* quantization_params =
1427               static_cast<const TfLiteAffineQuantization*>(
1428                   tensor.quantization.params);
1429           if (tensor.quantization.type != kTfLiteAffineQuantization ||
1430               quantization_params->quantized_dimension != 0 ||
1431               quantization_params->scale == nullptr ||
1432               quantization_params->zero_point == nullptr ||
1433               quantization_params->scale->size != 1 ||
1434               quantization_params->zero_point->size != 1) {
1435             TF_LITE_MAYBE_KERNEL_LOG(
1436                 context,
1437                 "unsupported quantization type %d in tensor #%d in node #%d",
1438                 tensor.quantization.type, tensor_index, node_index);
1439             return kTfLiteError;
1440           }
1441           return kTfLiteOk;
1442         }
1443         break;
1444       default:
1445         break;
1446     }
1447 
1448     TF_LITE_MAYBE_KERNEL_LOG(
1449         context, "unsupported type %s in tensor #%d in node #%d",
1450         TfLiteTypeGetName(tensor.type), tensor_index, node_index);
1451     return kTfLiteError;
1452   }
1453 
1454   static TfLiteStatus CheckTensorFloat32OrQCInt8Type(
1455       const Delegate& delegate, TfLiteContext* context,
1456       const TfLiteTensor& tensor, int expected_quantized_dimension,
1457       int tensor_index, int node_index) {
1458     switch (tensor.type) {
1459       case kTfLiteFloat32:
1460         return kTfLiteOk;
1461       case kTfLiteInt8:
1462         if (delegate.support_signed_8bit_quantization()) {
1463           if (tensor.quantization.type != kTfLiteAffineQuantization) {
1464             TF_LITE_MAYBE_KERNEL_LOG(
1465                 context,
1466                 "unsupported quantization type %d in tensor #%d in node #%d",
1467                 tensor.quantization.type, tensor_index, node_index);
1468             return kTfLiteError;
1469           }
1470           const TfLiteAffineQuantization* quantization_params =
1471               static_cast<const TfLiteAffineQuantization*>(
1472                   tensor.quantization.params);
1473           if (quantization_params->scale == nullptr) {
1474             TF_LITE_MAYBE_KERNEL_LOG(context,
1475                                      "missing scale quantization parameters in "
1476                                      "tensor #%d in node #%d",
1477                                      tensor_index, node_index);
1478             return kTfLiteError;
1479           }
1480           if (quantization_params->scale->size > 1 &&
1481               quantization_params->quantized_dimension !=
1482                   expected_quantized_dimension) {
1483             TF_LITE_MAYBE_KERNEL_LOG(
1484                 context,
1485                 "unsupported quantized dimension %d in tensor #%d in node #%d",
1486                 quantization_params->quantized_dimension, tensor_index,
1487                 node_index);
1488             return kTfLiteError;
1489           }
1490           return kTfLiteOk;
1491         }
1492         break;
1493       case kTfLiteUInt8:
1494         if (delegate.support_unsigned_8bit_quantization()) {
1495           const auto* quantization_params =
1496               static_cast<const TfLiteAffineQuantization*>(
1497                   tensor.quantization.params);
1498           if (tensor.quantization.type != kTfLiteAffineQuantization ||
1499               quantization_params->quantized_dimension != 0 ||
1500               quantization_params->scale == nullptr ||
1501               quantization_params->zero_point == nullptr ||
1502               quantization_params->scale->size != 1 ||
1503               quantization_params->zero_point->size != 1) {
1504             TF_LITE_MAYBE_KERNEL_LOG(
1505                 context,
1506                 "unsupported quantization type %d in tensor #%d in node #%d",
1507                 tensor.quantization.type, tensor_index, node_index);
1508             return kTfLiteError;
1509           }
1510           return kTfLiteOk;
1511         }
1512         break;
1513       default:
1514         break;
1515     }
1516 
1517     TF_LITE_MAYBE_KERNEL_LOG(
1518         context, "unsupported type %s in tensor #%d in node #%d",
1519         TfLiteTypeGetName(tensor.type), tensor_index, node_index);
1520     return kTfLiteError;
1521   }
1522 
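  // Illustrative example (hypothetical shapes): a per-channel quantized
  // weight tensor of shape [32, 3, 3, 16] carrying 32 scales passes this
  // check only if its quantized_dimension equals expected_quantized_dimension
  // (0 for Conv2D-style weights, where dimension 0 indexes output channels);
  // a per-tensor weight with a single scale passes regardless of its
  // quantized_dimension.
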
1523   static TfLiteStatus CheckTensorFloat32OrQInt32Type(const Delegate& delegate,
1524                                                      TfLiteContext* context,
1525                                                      const TfLiteTensor& tensor,
1526                                                      int tensor_index,
1527                                                      int node_index) {
1528     switch (tensor.type) {
1529       case kTfLiteFloat32:
1530         return kTfLiteOk;
1531       case kTfLiteInt32:
1532         if (delegate.support_any_8bit_quantization()) {
1533           if (tensor.quantization.type != kTfLiteAffineQuantization ||
1534               static_cast<const TfLiteAffineQuantization*>(
1535                   tensor.quantization.params)
1536                       ->quantized_dimension != 0 ||
1537               static_cast<const TfLiteAffineQuantization*>(
1538                   tensor.quantization.params)
1539                       ->scale == nullptr ||
1540               static_cast<const TfLiteAffineQuantization*>(
1541                   tensor.quantization.params)
1542                       ->scale->size != 1) {
1543             TF_LITE_MAYBE_KERNEL_LOG(
1544                 context,
1545                 "unsupported quantization type %d in tensor #%d in node #%d",
1546                 tensor.quantization.type, tensor_index, node_index);
1547             return kTfLiteError;
1548           }
1549           return kTfLiteOk;
1550         }
1551         break;
1552       default:
1553         break;
1554     }
1555 
1556     TF_LITE_MAYBE_KERNEL_LOG(
1557         context, "unsupported type %s in tensor #%d in node #%d",
1558         TfLiteTypeGetName(tensor.type), tensor_index, node_index);
1559     return kTfLiteError;
1560   }
1561 
1562   static TfLiteStatus CheckTensorFloat32OrQCInt32Type(
1563       const Delegate& delegate, TfLiteContext* context,
1564       const TfLiteTensor& tensor, int tensor_index, int node_index) {
1565     switch (tensor.type) {
1566       case kTfLiteFloat32:
1567         return kTfLiteOk;
1568       case kTfLiteInt32:
1569         if (delegate.support_signed_8bit_quantization()) {
1570           if (tensor.quantization.type != kTfLiteAffineQuantization ||
1571               static_cast<const TfLiteAffineQuantization*>(
1572                   tensor.quantization.params)
1573                       ->quantized_dimension != 0) {
1574             TF_LITE_MAYBE_KERNEL_LOG(
1575                 context,
1576                 "unsupported quantization type %d in tensor #%d in node #%d",
1577                 tensor.quantization.type, tensor_index, node_index);
1578             return kTfLiteError;
1579           }
1580           return kTfLiteOk;
1581         }
1582         break;
1583       default:
1584         break;
1585     }
1586     TF_LITE_MAYBE_KERNEL_LOG(
1587         context, "unsupported type %s in tensor #%d in node #%d",
1588         TfLiteTypeGetName(tensor.type), tensor_index, node_index);
1589     return kTfLiteError;
1590   }
1591 
1592   static TfLiteStatus CheckTensorShape(TfLiteContext* context,
1593                                        const TfLiteTensor& tensor,
1594                                        int min_num_dims, int max_num_dims,
1595                                        int tensor_index) {
1596     if (min_num_dims == max_num_dims) {
1597       if (NumDimensions(&tensor) != min_num_dims) {
1598         TF_LITE_MAYBE_KERNEL_LOG(
1599             context,
1600             "unsupported number of shape dimensions (%d) in tensor #%d: "
1601             "%d dimensions expected",
1602             NumDimensions(&tensor), tensor_index, min_num_dims);
1603         return kTfLiteError;
1604       }
1605     } else {
1606       if (NumDimensions(&tensor) < min_num_dims) {
1607         TF_LITE_MAYBE_KERNEL_LOG(
1608             context,
1609             "unsupported number of shape dimensions (%d) in tensor #%d: "
1610             "at least %d dimensions expected",
1611             NumDimensions(&tensor), tensor_index, min_num_dims);
1612         return kTfLiteError;
1613       }
1614       if (NumDimensions(&tensor) > max_num_dims) {
1615         TF_LITE_MAYBE_KERNEL_LOG(
1616             context,
1617             "unsupported number of shape dimensions (%d) in tensor #%d: "
1618             "at most %d dimensions expected",
1619             NumDimensions(&tensor), tensor_index, max_num_dims);
1620         return kTfLiteError;
1621       }
1622     }
1623     for (int i = 0; i < NumDimensions(&tensor); i++) {
1624       if (SizeOfDimension(&tensor, i) <= 0) {
1625         TF_LITE_MAYBE_KERNEL_LOG(context,
1626                                  "invalid number of elements (%d) in "
1627                                  "dimension #%d in tensor #%d",
1628                                  SizeOfDimension(&tensor, i), i, tensor_index);
1629         return kTfLiteError;
1630       }
1631     }
1632     return kTfLiteOk;
1633   }
1634 
1635   static TfLiteStatus CheckTensorShape(TfLiteContext* context,
1636                                        const TfLiteTensor& tensor,
1637                                        int expected_num_dims,
1638                                        int tensor_index) {
1639     return CheckTensorShape(context, tensor, expected_num_dims,
1640                             expected_num_dims, tensor_index);
1641   }
1642 
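  // For example, a rank-4 NHWC activation would be validated with
  //   CheckTensorShape(context, tensor, /*expected_num_dims=*/4, tensor_index);
  // note that the dimension loop above also rejects sizes of zero or less,
  // so only fully materialized static shapes pass.
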
1643   static TfLiteStatus CheckSlopeTensorShape(TfLiteContext* context,
1644                                             const TfLiteTensor& tensor,
1645                                             int tensor_index, int node_index) {
1646     if (NumDimensions(&tensor) < 1) {
1647       TF_LITE_MAYBE_KERNEL_LOG(context,
1648                                "unexpected number of shape dimensions (%d) in "
1649                                "tensor #%d in node #%d: "
1650                                "expected at least a 1D tensor",
1651                                NumDimensions(&tensor), tensor_index,
1652                                node_index);
1653       return kTfLiteError;
1654     }
1655     // Validate that all non-channel dimensions (if any) are exactly 1.
1656     for (int i = 0; i < NumDimensions(&tensor) - 1; i++) {
1657       if (SizeOfDimension(&tensor, i) != 1) {
1658         TF_LITE_MAYBE_KERNEL_LOG(
1659             context,
1660             "unexpected value %d of shape dimension #%d in "
1661             "tensor #%d in node #%d: "
1662             "expected 1 for non-channel dimensions",
1663             SizeOfDimension(&tensor, i), i, tensor_index, node_index);
1664         return kTfLiteError;
1665       }
1666     }
1667     return kTfLiteOk;
1668   }
1669 
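  // Example: a PRELU slope tensor of shape [1, 1, 32] passes (all
  // non-channel dimensions are 1), while a slope of shape [2, 1, 32] is
  // rejected by the loop above.
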
1670   static TfLiteStatus CheckPaddingsTensorShape(TfLiteContext* context,
1671                                                const TfLiteTensor& tensor,
1672                                                int expected_rows,
1673                                                int tensor_index,
1674                                                int node_index) {
1675     if (NumDimensions(&tensor) != 2) {
1676       TF_LITE_MAYBE_KERNEL_LOG(context,
1677                                "unexpected number of shape dimensions (%d) in "
1678                                "padding tensor #%d in node #%d: "
1679                                "expected a 2D tensor",
1680                                NumDimensions(&tensor), tensor_index,
1681                                node_index);
1682       return kTfLiteError;
1683     }
1684     if (SizeOfDimension(&tensor, 0) != expected_rows) {
1685       TF_LITE_MAYBE_KERNEL_LOG(context,
1686                                "unexpected number of rows (%d) in "
1687                                "padding tensor #%d in node #%d: "
1688                                "%d rows expected",
1689                                SizeOfDimension(&tensor, 0), tensor_index, node_index,
1690                                expected_rows);
1691       return kTfLiteError;
1692     }
1693     if (SizeOfDimension(&tensor, 1) != 2) {
1694       TF_LITE_MAYBE_KERNEL_LOG(context,
1695                                "unexpected number of columns (%d) in "
1696                                "padding tensor #%d in node #%d: "
1697                                "2 columns expected",
1698                                SizeOfDimension(&tensor, 1), tensor_index,
1699                                node_index);
1700       return kTfLiteError;
1701     }
1702     return kTfLiteOk;
1703   }
1704 
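  // Example: a PAD operator over a 4D input supplies a [4, 2] paddings
  // tensor, one {pre, post} pair per input dimension, so callers pass the
  // input rank as expected_rows.
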
1705   static TfLiteStatus CheckAxesTensorShape(TfLiteContext* context,
1706                                            const TfLiteTensor& tensor,
1707                                            int tensor_index, int node_index) {
1708     const int num_tensor_dims = NumDimensions(&tensor);
1709     if (num_tensor_dims > 1) {
1710       TF_LITE_MAYBE_KERNEL_LOG(context,
1711                                "unexpected number of shape dimensions (%d) in "
1712                                "axes tensor #%d in node #%d: "
1713                                "expected a 1D tensor",
1714                                num_tensor_dims, tensor_index, node_index);
1715       return kTfLiteError;
1716     }
1717     return kTfLiteOk;
1718   }
1719 
1720   static TfLiteStatus CheckShapeTensorShape(TfLiteContext* context,
1721                                             const TfLiteTensor& tensor,
1722                                             int tensor_index, int node_index) {
1723     if (NumDimensions(&tensor) != 1) {
1724       TF_LITE_MAYBE_KERNEL_LOG(context,
1725                                "unexpected number of shape dimensions (%d) in "
1726                                "shape tensor #%d in node #%d: "
1727                                "expected a 1D tensor",
1728                                NumDimensions(&tensor), tensor_index,
1729                                node_index);
1730       return kTfLiteError;
1731     }
1732     return kTfLiteOk;
1733   }
1734 
1735   static TfLiteStatus CheckTensorNonDynamicAllocation(
1736       TfLiteContext* context, const TfLiteTensor& tensor, int tensor_index,
1737       int node_index) {
1738     // TODO(b/149120844): remove checks once dynamic tensors are supported
1739     if (tensor.allocation_type == kTfLiteDynamic) {
1740       TF_LITE_MAYBE_KERNEL_LOG(
1741           context,
1742           "invalid allocation type in tensor #%d in node #%d: "
1743           "expected non-dynamic tensor",
1744           tensor_index, node_index);
1745       return kTfLiteError;
1746     }
1747     return kTfLiteOk;
1748   }
1749 
1750   static TfLiteStatus CheckTensorStaticAllocation(TfLiteContext* context,
1751                                                   const TfLiteTensor& tensor,
1752                                                   int tensor_index,
1753                                                   int node_index) {
1754     if (tensor.allocation_type != kTfLiteMmapRo ||
1755         tensor.data.raw_const == nullptr) {
1756       TF_LITE_MAYBE_KERNEL_LOG(
1757           context,
1758           "invalid allocation type in tensor #%d in node #%d: "
1759           "expected static read-only tensor",
1760           tensor_index, node_index);
1761       return kTfLiteError;
1762     }
1763     return kTfLiteOk;
1764   }
1765 
1766   static TfLiteStatus CheckTensorsDimensionMatch(
1767       TfLiteContext* context, const TfLiteTensor& input_tensor,
1768       const TfLiteTensor& output_tensor, int dimension_index, int node_index,
1769       const char* op_name) {
1770     if (SizeOfDimension(&input_tensor, dimension_index) !=
1771         SizeOfDimension(&output_tensor, dimension_index)) {
1772       TF_LITE_MAYBE_KERNEL_LOG(
1773           context,
1774           "mismatch in shape dimension %d (%d != %d) in input and output "
1775           "tensors of %s operator #%d",
1776           dimension_index, SizeOfDimension(&input_tensor, dimension_index),
1777           SizeOfDimension(&output_tensor, dimension_index), op_name,
1778           node_index);
1779       return kTfLiteError;
1780     }
1781     return kTfLiteOk;
1782   }
1783 
1784   static float GetTensorScaleOrDefault(const TfLiteTensor& tensor,
1785                                        float default_scale) {
1786     switch (tensor.type) {
1787       case kTfLiteInt8:
1788       case kTfLiteUInt8: {
1789         if (tensor.quantization.type != kTfLiteAffineQuantization) {
1790           return default_scale;
1791         }
1792 
1793         const auto* quantization_params =
1794             static_cast<const TfLiteAffineQuantization*>(
1795                 tensor.quantization.params);
1796         if (quantization_params->quantized_dimension != 0 ||
1797             quantization_params->scale == nullptr ||
1798             quantization_params->scale->size != 1) {
1799           return default_scale;
1800         }
1801 
1802         return quantization_params->scale->data[0];
1803       }
1804       default:
1805         break;
1806     }
1807     return default_scale;
1808   }
1809 
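  // Unlike the Check* helpers above, this function never fails: any tensor
  // that is not per-tensor affine-quantized (wrong type, wrong quantization
  // kind, or per-channel scales) silently yields default_scale.
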
1810   static TfLiteStatus CheckTensorsInputOutputScale(
1811       TfLiteContext* context, const TfLiteTensor& input_tensor,
1812       const TfLiteTensor& output_tensor, float scale_min, float scale_max,
1813       int node_index, const char* op_name) {
1814     if (input_tensor.type != output_tensor.type) {
1815       // No validation needed
1816       return kTfLiteOk;
1817     }
1818 
1819     if (input_tensor.type == kTfLiteInt8 || input_tensor.type == kTfLiteUInt8) {
1820       const float input_scale = static_cast<const TfLiteAffineQuantization*>(
1821                                     input_tensor.quantization.params)
1822                                     ->scale->data[0];
1823       const float output_scale = static_cast<const TfLiteAffineQuantization*>(
1824                                      output_tensor.quantization.params)
1825                                      ->scale->data[0];
1826 
1827       const float input_output_scale = input_scale / output_scale;
1828       if (input_output_scale < scale_min || input_output_scale >= scale_max) {
1829         TF_LITE_MAYBE_KERNEL_LOG(
1830             context, "unsupported input-to-output scale in node #%d",
1831             node_index);
1832         return kTfLiteError;
1833       }
1834     }
1835     return kTfLiteOk;
1836   }
1837 
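  // Worked example (hypothetical scales): an input scale of 0.5 and an
  // output scale of 0.25 give an input-to-output ratio of 2.0, which fits
  // the [1/1024, 256) window that VisitAddNode uses below.
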
1838   static TfLiteStatus CheckTensorsInputProductOutputScale(
1839       TfLiteContext* context, const TfLiteTensor& input1_tensor,
1840       const TfLiteTensor& input2_tensor, const TfLiteTensor& output_tensor,
1841       float scale_min, float scale_max, int node_index, const char* op_name) {
1842     if (input1_tensor.type != input2_tensor.type ||
1843         input1_tensor.type != output_tensor.type) {
1844       // No validation needed
1845       return kTfLiteOk;
1846     }
1847 
1848     if (input1_tensor.type == kTfLiteInt8 ||
1849         input1_tensor.type == kTfLiteUInt8) {
1850       const float input1_scale = static_cast<const TfLiteAffineQuantization*>(
1851                                      input1_tensor.quantization.params)
1852                                      ->scale->data[0];
1853       const float input2_scale = static_cast<const TfLiteAffineQuantization*>(
1854                                      input2_tensor.quantization.params)
1855                                      ->scale->data[0];
1856       const float output_scale = static_cast<const TfLiteAffineQuantization*>(
1857                                      output_tensor.quantization.params)
1858                                      ->scale->data[0];
1859 
1860       const float product_scale = input1_scale * input2_scale;
1861       const float product_output_scale = product_scale / output_scale;
1862       if (product_output_scale < scale_min ||
1863           product_output_scale >= scale_max) {
1864         TF_LITE_MAYBE_KERNEL_LOG(
1865             context, "unsupported input-product-to-output scale in node #%d",
1866             node_index);
1867         return kTfLiteError;
1868       }
1869     }
1870     return kTfLiteOk;
1871   }
1872 
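  // Worked example (hypothetical scales): for a quantized multiply with
  // input scales 0.5 and 0.02 and output scale 0.1,
  // product_output_scale = (0.5 * 0.02) / 0.1 = 0.1, so the node is accepted
  // by any [scale_min, scale_max) window that contains 0.1.
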
1873   static TfLiteStatus VisitNode(
1874       xnn_subgraph_t subgraph, const Delegate& delegate, TfLiteContext* context,
1875       TfLiteRegistration* registration, TfLiteNode* node, int node_index,
1876       const std::unordered_set<int>& quasi_static_tensors,
1877       const std::vector<uint32_t>& xnnpack_tensors) {
1878     // TFLite context used for logging purposes. When we create a new node
1879     // (subgraph is non-null), logging context is the same as context, and error
1880     // messages are passed to TFLite. When we detect supported operations
1881     // (subgraph is null), logging context is null, and error messages are
1882     // suppressed.
1883     TfLiteContext* logging_context = subgraph == nullptr ? nullptr : context;
1884     switch (registration->builtin_code) {
1885       case kTfLiteBuiltinAbs:
1886         return VisitAbsNode(subgraph, delegate, logging_context, node_index,
1887                             node, context->tensors, xnnpack_tensors);
1888       case kTfLiteBuiltinAdd: {
1889         const TfLiteAddParams* add_params =
1890             static_cast<const TfLiteAddParams*>(node->builtin_data);
1891 
1892         return VisitAddNode(subgraph, delegate, logging_context, node_index,
1893                             node, context->tensors, add_params,
1894                             xnnpack_tensors);
1895       }
1896       case kTfLiteBuiltinAveragePool2d: {
1897         const TfLitePoolParams* pool_params =
1898             static_cast<const TfLitePoolParams*>(node->builtin_data);
1899 
1900         return VisitAveragePool2DNode(subgraph, delegate, logging_context,
1901                                       node_index, node, context->tensors,
1902                                       pool_params, xnnpack_tensors);
1903       }
1904       case kTfLiteBuiltinCeil:
1905         return VisitCeilNode(subgraph, delegate, logging_context, node_index,
1906                              node, context->tensors, xnnpack_tensors);
1907       case kTfLiteBuiltinConcatenation: {
1908         const TfLiteConcatenationParams* concat_params =
1909             static_cast<const TfLiteConcatenationParams*>(node->builtin_data);
1910         return VisitConcatenationNode(subgraph, delegate, logging_context,
1911                                       node_index, node, context->tensors,
1912                                       concat_params, xnnpack_tensors);
1913       }
1914       case kTfLiteBuiltinConv2d: {
1915         const TfLiteConvParams* conv_params =
1916             static_cast<const TfLiteConvParams*>(node->builtin_data);
1917 
1918         return VisitConv2DNode(subgraph, delegate, logging_context, node_index,
1919                                node, context->tensors, conv_params,
1920                                quasi_static_tensors, xnnpack_tensors);
1921       }
1922       case kTfLiteBuiltinDepthwiseConv2d: {
1923         const TfLiteDepthwiseConvParams* dwconv_params =
1924             static_cast<const TfLiteDepthwiseConvParams*>(node->builtin_data);
1925 
1926         return VisitDepthwiseConv2DNode(subgraph, delegate, logging_context,
1927                                         node_index, node, context->tensors,
1928                                         dwconv_params, quasi_static_tensors,
1929                                         xnnpack_tensors);
1930       }
1931       case kTfLiteBuiltinDepthToSpace: {
1932         const TfLiteDepthToSpaceParams* depth_to_space_params =
1933             static_cast<const TfLiteDepthToSpaceParams*>(node->builtin_data);
1934 
1935         return VisitDepthToSpaceNode(subgraph, delegate, logging_context,
1936                                      node_index, node, context->tensors,
1937                                      depth_to_space_params, xnnpack_tensors);
1938       }
1939       case kTfLiteBuiltinDequantize:
1940         return VisitDequantizeNode(subgraph, delegate, logging_context,
1941                                    node_index, node, context->tensors,
1942                                    xnnpack_tensors);
1943       case kTfLiteBuiltinDiv: {
1944         const TfLiteDivParams* div_params =
1945             static_cast<const TfLiteDivParams*>(node->builtin_data);
1946 
1947         return VisitDivNode(subgraph, delegate, logging_context, node_index,
1948                             node, context->tensors, div_params,
1949                             xnnpack_tensors);
1950       }
1951       case kTfLiteBuiltinElu:
1952         return VisitEluNode(subgraph, delegate, logging_context, node_index,
1953                             node, context->tensors, xnnpack_tensors);
1954       case kTfLiteBuiltinFullyConnected: {
1955         // FullyConnected with sparse weight has version 8, which cannot be
1956         // delegated to XNNPack.
1957         if (registration->version == 8) {
1958           TF_LITE_MAYBE_KERNEL_LOG(logging_context,
1959                                    "Unsupported version %d of FullyConnected.",
1960                                    registration->version);
1961           return kTfLiteError;
1962         }
1963 
1964         const TfLiteFullyConnectedParams* fc_params =
1965             static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
1966 
1967         return VisitFullyConnectedNode(
1968             subgraph, delegate, logging_context, node_index, node,
1969             context->tensors, fc_params, quasi_static_tensors, xnnpack_tensors);
1970       }
1971       case kTfLiteBuiltinFloor:
1972         return VisitFloorNode(subgraph, delegate, logging_context, node_index,
1973                               node, context->tensors, xnnpack_tensors);
1974       case kTfLiteBuiltinHardSwish:
1975         return VisitHardSwishNode(subgraph, delegate, logging_context,
1976                                   node_index, node, context->tensors,
1977                                   xnnpack_tensors);
1978       case kTfLiteBuiltinLeakyRelu: {
1979         const TfLiteLeakyReluParams* leaky_relu_params =
1980             static_cast<const TfLiteLeakyReluParams*>(node->builtin_data);
1981 
1982         return VisitLeakyReluNode(subgraph, delegate, logging_context,
1983                                   node_index, node, context->tensors,
1984                                   leaky_relu_params, xnnpack_tensors);
1985       }
1986       case kTfLiteBuiltinLogistic:
1987         return VisitLogisticNode(subgraph, delegate, logging_context,
1988                                  node_index, node, context->tensors,
1989                                  xnnpack_tensors);
1990       case kTfLiteBuiltinMaxPool2d: {
1991         const TfLitePoolParams* pool_params =
1992             static_cast<const TfLitePoolParams*>(node->builtin_data);
1993 
1994         return VisitMaxPool2DNode(subgraph, delegate, logging_context,
1995                                   node_index, node, context->tensors,
1996                                   pool_params, xnnpack_tensors);
1997       }
1998       case kTfLiteBuiltinMaximum:
1999         return VisitMaximumNode(subgraph, delegate, logging_context, node_index,
2000                                 node, context->tensors, xnnpack_tensors);
2001       case kTfLiteBuiltinMean: {
2002         const TfLiteReducerParams* reducer_params =
2003             static_cast<const TfLiteReducerParams*>(node->builtin_data);
2004 
2005         return VisitMeanNode(subgraph, delegate, logging_context, node_index,
2006                              node, context->tensors, reducer_params,
2007                              xnnpack_tensors);
2008       }
2009       case kTfLiteBuiltinMinimum:
2010         return VisitMinimumNode(subgraph, delegate, logging_context, node_index,
2011                                 node, context->tensors, xnnpack_tensors);
2012       case kTfLiteBuiltinMul: {
2013         const TfLiteMulParams* mul_params =
2014             static_cast<const TfLiteMulParams*>(node->builtin_data);
2015 
2016         return VisitMulNode(subgraph, delegate, logging_context, node_index,
2017                             node, context->tensors, mul_params,
2018                             xnnpack_tensors);
2019       }
2020       case kTfLiteBuiltinNeg:
2021         return VisitNegNode(subgraph, delegate, logging_context, node_index,
2022                             node, context->tensors, xnnpack_tensors);
2023       case kTfLiteBuiltinPad:
2024         return VisitPadNode(subgraph, delegate, logging_context, node_index,
2025                             node, context->tensors, xnnpack_tensors);
2026       case kTfLiteBuiltinPrelu:
2027         return VisitPreluNode(subgraph, delegate, logging_context, node_index,
2028                               node, context->tensors, quasi_static_tensors,
2029                               xnnpack_tensors);
2030       case kTfLiteBuiltinQuantize:
2031         return VisitQuantizeNode(subgraph, delegate, logging_context,
2032                                  node_index, node, context->tensors,
2033                                  xnnpack_tensors);
2034       case kTfLiteBuiltinRelu:
2035         return VisitReluNode(subgraph, delegate, logging_context, node_index,
2036                              node, context->tensors, 0.0f,
2037                              std::numeric_limits<float>::infinity(),
2038                              xnnpack_tensors);
2039       case kTfLiteBuiltinReluN1To1:
2040         return VisitReluNode(subgraph, delegate, logging_context, node_index,
2041                              node, context->tensors, -1.0f, 1.0f,
2042                              xnnpack_tensors);
2043       case kTfLiteBuiltinRelu6:
2044         return VisitReluNode(subgraph, delegate, logging_context, node_index,
2045                              node, context->tensors, 0.0f, 6.0f,
2046                              xnnpack_tensors);
2047       case kTfLiteBuiltinReshape: {
2048         const TfLiteReshapeParams* reshape_params =
2049             static_cast<const TfLiteReshapeParams*>(node->builtin_data);
2050 
2051         return VisitReshapeNode(subgraph, delegate, logging_context, node_index,
2052                                 node, context->tensors, reshape_params,
2053                                 xnnpack_tensors);
2054       }
2055       case kTfLiteBuiltinResizeBilinear: {
2056         const TfLiteResizeBilinearParams* resize_params =
2057             static_cast<const TfLiteResizeBilinearParams*>(node->builtin_data);
2058 
2059         return VisitResizeBilinearNode(subgraph, delegate, logging_context,
2060                                        node_index, node, context->tensors,
2061                                        resize_params, xnnpack_tensors);
2062       }
2063       case kTfLiteBuiltinRound:
2064         return VisitRoundNode(subgraph, delegate, logging_context, node_index,
2065                               node, context->tensors, xnnpack_tensors);
2066       case kTfLiteBuiltinSoftmax: {
2067         const TfLiteSoftmaxParams* softmax_params =
2068             static_cast<const TfLiteSoftmaxParams*>(node->builtin_data);
2069 
2070         return VisitSoftmaxNode(subgraph, delegate, logging_context, node_index,
2071                                 node, context->tensors, softmax_params,
2072                                 xnnpack_tensors);
2073       }
2074       case kTfLiteBuiltinSplit: {
2075         const TfLiteSplitParams* split_params =
2076             static_cast<const TfLiteSplitParams*>(node->builtin_data);
2077         return VisitSplitNode(subgraph, delegate, logging_context, node_index,
2078                               node, context->tensors, split_params,
2079                               xnnpack_tensors);
2080       }
2081       case kTfLiteBuiltinSqrt:
2082         return VisitSqrtNode(subgraph, delegate, logging_context, node_index,
2083                              node, context->tensors, xnnpack_tensors);
2084       case kTfLiteBuiltinSquare:
2085         return VisitSquareNode(subgraph, delegate, logging_context, node_index,
2086                                node, context->tensors, xnnpack_tensors);
2087       case kTfLiteBuiltinSquaredDifference:
2088         return VisitSquaredDifferenceNode(subgraph, delegate, logging_context,
2089                                           node_index, node, context->tensors,
2090                                           xnnpack_tensors);
2091       case kTfLiteBuiltinSub: {
2092         const TfLiteSubParams* sub_params =
2093             static_cast<const TfLiteSubParams*>(node->builtin_data);
2094 
2095         return VisitSubNode(subgraph, delegate, logging_context, node_index,
2096                             node, context->tensors, sub_params,
2097                             xnnpack_tensors);
2098       }
2099       case kTfLiteBuiltinTranspose: {
2100         return VisitTransposeNode(subgraph, delegate, logging_context,
2101                                   node_index, node, context->tensors,
2102                                   xnnpack_tensors);
2103       }
2104       case kTfLiteBuiltinTransposeConv: {
2105         const TfLiteTransposeConvParams* deconv_params =
2106             static_cast<const TfLiteTransposeConvParams*>(node->builtin_data);
2107 
2108         return VisitTransposeConvNode(subgraph, delegate, logging_context,
2109                                       node_index, node, context->tensors,
2110                                       deconv_params, quasi_static_tensors,
2111                                       xnnpack_tensors);
2112       }
2113       case kTfLiteBuiltinCustom: {
2114         if (strcmp(registration->custom_name, "Convolution2DTransposeBias") ==
2115             0) {
2116           TfLiteTransposeConvParams deconv_params = {kTfLitePaddingUnknown};
2117           SafeCopyCustomData(*node, &deconv_params);
2118 
2119           return VisitMediaPipeDeconvolutionNode(
2120               subgraph, delegate, context, node_index, node, context->tensors,
2121               &deconv_params, quasi_static_tensors, xnnpack_tensors);
2122         } else if (strcmp(registration->custom_name,
2123                           "MaxPoolingWithArgmax2D") == 0) {
2124           TfLitePoolParams pool_params = {kTfLitePaddingUnknown};
2125           SafeCopyCustomData(*node, &pool_params);
2126 
2127           return VisitMediaPipeMaxPoolingNode(
2128               subgraph, delegate, context, node_index, node, context->tensors,
2129               &pool_params, xnnpack_tensors);
2130         } else if (strcmp(registration->custom_name, "MaxUnpooling2D") == 0) {
2131           TfLitePoolParams pool_params = {kTfLitePaddingUnknown};
2132           SafeCopyCustomData(*node, &pool_params);
2133 
2134           return VisitMediaPipeUnpoolingNode(subgraph, delegate, context,
2135                                              node_index, node, context->tensors,
2136                                              &pool_params, xnnpack_tensors);
2137         }
2138         return kTfLiteError;
2139       }
2140       default:
2141         return kTfLiteError;
2142     }
2143   }
2144 
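  // Every Visit*Node function below follows the same two-phase contract as
  // VisitNode itself: when subgraph is null it only answers whether the node
  // is supported (Check* failures stay silent), and when subgraph is non-null
  // it additionally emits the corresponding xnn_define_* call.
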
2145   static TfLiteStatus VisitAbsNode(
2146       xnn_subgraph_t subgraph, const Delegate& delegate,
2147       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2148       const TfLiteTensor* tensors,
2149       const std::vector<uint32_t>& xnnpack_tensors) {
2150     TF_LITE_ENSURE_STATUS(
2151         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2152 
2153     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2154     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2155         logging_context, input_tensor, node->inputs->data[0], node_index));
2156     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2157         logging_context, input_tensor, node->inputs->data[0], node_index));
2158 
2159     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2160     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2161         logging_context, output_tensor, node->outputs->data[0], node_index));
2162     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2163         logging_context, output_tensor, node->outputs->data[0], node_index));
2164 
2165     if (subgraph != nullptr) {
2166       const xnn_status status = xnn_define_abs(
2167           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2168           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2169       if (status != xnn_status_success) {
2170         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ABS node #%d",
2171                            node_index);
2172         return kTfLiteError;
2173       }
2174     }
2175 
2176     return kTfLiteOk;
2177   }
2178 
2179   static TfLiteStatus VisitAddNode(
2180       xnn_subgraph_t subgraph, const Delegate& delegate,
2181       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2182       const TfLiteTensor* tensors, const TfLiteAddParams* add_params,
2183       const std::vector<uint32_t>& xnnpack_tensors) {
2184     TF_LITE_ENSURE_STATUS(
2185         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2186 
2187     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
2188     TF_LITE_ENSURE_STATUS(
2189         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor,
2190                                        node->inputs->data[0], node_index));
2191     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2192         logging_context, input1_tensor, node->inputs->data[0], node_index));
2193 
2194     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
2195     TF_LITE_ENSURE_STATUS(
2196         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor,
2197                                        node->inputs->data[1], node_index));
2198     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2199         logging_context, input2_tensor, node->inputs->data[1], node_index));
2200 
2201     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2202     TF_LITE_ENSURE_STATUS(
2203         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
2204                                        node->outputs->data[0], node_index));
2205     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2206         logging_context, output_tensor, node->outputs->data[0], node_index));
2207 
2208     const float scale_min = 1.0f / 1024.0f;
2209     const float scale_max = 256.0f;
2210     TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale(
2211         logging_context, input1_tensor, output_tensor, scale_min, scale_max,
2212         node_index, "ADD"));
2213     TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale(
2214         logging_context, input2_tensor, output_tensor, scale_min, scale_max,
2215         node_index, "ADD"));
2216 
2217     float output_min = -std::numeric_limits<float>::infinity();
2218     float output_max = +std::numeric_limits<float>::infinity();
2219     if (add_params != nullptr) {
2220       TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
2221           logging_context, node_index, add_params->activation, &output_min,
2222           &output_max));
2223     }
2224 
2225     if (subgraph != nullptr) {
2226       const xnn_status status = xnn_define_add2(
2227           subgraph, output_min, output_max,
2228           /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2229           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2230           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2231       if (status != xnn_status_success) {
2232         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ADD node #%d",
2233                            node_index);
2234         return kTfLiteError;
2235       }
2236     }
2237 
2238     return kTfLiteOk;
2239   }
2240 
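  // Example of the requantization window above (hypothetical scales): an ADD
  // with input scale 1.0 and output scale 2048.0 has an input-to-output
  // ratio of about 0.00049, below scale_min = 1/1024, so it is not
  // delegated.
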
2241   static TfLiteStatus VisitAveragePool2DNode(
2242       xnn_subgraph_t subgraph, const Delegate& delegate,
2243       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2244       const TfLiteTensor* tensors, const TfLitePoolParams* pool_params,
2245       const std::vector<uint32_t>& xnnpack_tensors) {
2246     TF_LITE_ENSURE_STATUS(
2247         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2248 
2249     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2250     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2251         logging_context, input_tensor, node->inputs->data[0], node_index));
2252     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2253         logging_context, input_tensor, node->inputs->data[0], node_index));
2254 
2255     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2256     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2257         logging_context, output_tensor, node->outputs->data[0], node_index));
2258     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2259         logging_context, output_tensor, node->outputs->data[0], node_index));
2260 
2261     TF_LITE_ENSURE_STATUS(
2262         CheckPoolingParams(logging_context, pool_params, node_index));
2263 
2264     uint32_t flags = 0;
2265     TF_LITE_ENSURE_STATUS(CalculatePadding(
2266         logging_context, pool_params->padding, &flags, node_index));
2267 
2268     float output_min = -std::numeric_limits<float>::infinity();
2269     float output_max = +std::numeric_limits<float>::infinity();
2270     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
2271         logging_context, node_index, pool_params->activation, &output_min,
2272         &output_max));
2273 
2274     if (subgraph != nullptr) {
2275       xnn_status status = xnn_status_success;
2276       if (pool_params->filter_height == 1 && pool_params->filter_width == 1) {
2277         status = xnn_define_clamp(
2278             subgraph, output_min, output_max,
2279             /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2280             /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2281       } else {
2282         status = xnn_define_average_pooling_2d(
2283             subgraph,
2284             /*input_padding_top=*/0,
2285             /*input_padding_right=*/0,
2286             /*input_padding_bottom=*/0,
2287             /*input_padding_left=*/0,
2288             static_cast<uint32_t>(pool_params->filter_height),
2289             static_cast<uint32_t>(pool_params->filter_width),
2290             static_cast<uint32_t>(pool_params->stride_height),
2291             static_cast<uint32_t>(pool_params->stride_width), output_min,
2292             output_max,
2293             /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2294             /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
2295       }
2296       if (status != xnn_status_success) {
2297         TF_LITE_KERNEL_LOG(logging_context,
2298                            "failed to delegate AVERAGE_POOL_2D node #%d",
2299                            node_index);
2300         return kTfLiteError;
2301       }
2302     }
2303 
2304     return kTfLiteOk;
2305   }
2306 
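  // Design note: CheckPoolingParams has already rejected 1x1 filters with
  // strides larger than 1, so a 1x1 AVERAGE_POOL_2D reaching the code above
  // is an elementwise identity; it is lowered to xnn_define_clamp, which
  // still enforces the fused activation's [output_min, output_max] range.
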
2307   static TfLiteStatus VisitCeilNode(
2308       xnn_subgraph_t subgraph, const Delegate& delegate,
2309       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2310       const TfLiteTensor* tensors,
2311       const std::vector<uint32_t>& xnnpack_tensors) {
2312     TF_LITE_ENSURE_STATUS(
2313         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2314 
2315     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2316     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2317         logging_context, input_tensor, node->inputs->data[0], node_index));
2318     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2319         logging_context, input_tensor, node->inputs->data[0], node_index));
2320 
2321     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2322     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2323         logging_context, output_tensor, node->outputs->data[0], node_index));
2324     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2325         logging_context, output_tensor, node->outputs->data[0], node_index));
2326 
2327     if (subgraph != nullptr) {
2328       const xnn_status status = xnn_define_ceiling(
2329           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2330           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2331       if (status != xnn_status_success) {
2332         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate CEIL node #%d",
2333                            node_index);
2334         return kTfLiteError;
2335       }
2336     }
2337 
2338     return kTfLiteOk;
2339   }
2340 
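  // Lowers a TFLite SPLIT node to XNNPACK's even-split operators
  // (xnn_define_even_split2/3/4). Only 2, 3, or 4 outputs are supported,
  // num_splits must match the actual output count, the split_dim input must
  // be a static int32 tensor (read as a scalar), and the split dimension
  // must divide evenly across the outputs.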
VisitSplitNode(xnn_subgraph_t subgraph,const Delegate & delegate,TfLiteContext * logging_context,int node_index,TfLiteNode * node,const TfLiteTensor * tensors,const TfLiteSplitParams * split_params,const std::vector<uint32_t> & xnnpack_tensors)2341   static TfLiteStatus VisitSplitNode(
2342       xnn_subgraph_t subgraph, const Delegate& delegate,
2343       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2344       const TfLiteTensor* tensors, const TfLiteSplitParams* split_params,
2345       const std::vector<uint32_t>& xnnpack_tensors) {
2346     const int num_outputs = NumOutputs(node);
2347     TF_LITE_ENSURE_EQ(logging_context, split_params->num_splits, num_outputs);
2348     TF_LITE_ENSURE_STATUS(CheckNumInputs(logging_context, node, 2, node_index));
2349     TF_LITE_ENSURE_STATUS(
2350         CheckNumOutputs(logging_context, node, 2, 4, node_index));
2351 
2352     const int split_dim_idx = node->inputs->data[0];
2353     const TfLiteTensor& split_dim_tensor = tensors[split_dim_idx];
2354     TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, split_dim_tensor,
2355                                           kTfLiteInt32, split_dim_idx,
2356                                           node_index));
2357     TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2358         logging_context, split_dim_tensor, split_dim_idx, node_index));
2359 
2360     const int input_idx = node->inputs->data[1];
2361     const TfLiteTensor& input_tensor = tensors[input_idx];
2362     TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
2363         delegate, logging_context, input_tensor, input_idx, node_index));
2364     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2365         logging_context, input_tensor, input_idx, node_index));
2366 
2367     int32_t split_dim = GetTensorData<int32_t>(&split_dim_tensor)[0];
2368     if (split_dim < 0) split_dim += NumDimensions(&input_tensor);
2369     TF_LITE_ENSURE(logging_context, split_dim >= 0);
2370     TF_LITE_ENSURE(logging_context, split_dim < NumDimensions(&input_tensor));
2371 
2372     const int input_split_dim_size = SizeOfDimension(&input_tensor, split_dim);
2373     if (input_split_dim_size % num_outputs != 0) {
2374       TF_LITE_MAYBE_KERNEL_LOG(
2375           logging_context,
2376           "cannot evenly split dimension %d (size %d) into %d outputs",
2377           split_dim, input_split_dim_size, num_outputs);
2378       return kTfLiteError;
2379     }
2380 
2381     const int32_t expected_output_split_dim_size =
2382         input_split_dim_size / num_outputs;
2383 
2384     for (int i = 0; i < NumOutputs(node); i++) {
2385       const int output_idx = node->outputs->data[i];
2386       const TfLiteTensor& output_tensor = tensors[output_idx];
2387 
2388       TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
2389           delegate, logging_context, output_tensor, output_idx, node_index));
2390       TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2391           logging_context, output_tensor, output_idx, node_index));
2392       TF_LITE_ENSURE_EQ(logging_context, NumDimensions(&input_tensor),
2393                         NumDimensions(&output_tensor));
2394 
2395       for (int d = 0; d < NumDimensions(&input_tensor); d++) {
2396         if (d == split_dim) {
2397           if (SizeOfDimension(&output_tensor, split_dim) !=
2398               expected_output_split_dim_size) {
2399             TF_LITE_MAYBE_KERNEL_LOG(
2400                 logging_context,
2401                 "mismatch in split dimension %d (%d != %d) "
2402                 "in output %d and input "
2403                 "tensors of SPLIT operator #%d",
2404                 split_dim, SizeOfDimension(&output_tensor, split_dim),
2405                 expected_output_split_dim_size, i, node_index);
2406             return kTfLiteError;
2407           }
2408         } else {
2409           TF_LITE_ENSURE_STATUS(CheckTensorsDimensionMatch(
2410               logging_context, input_tensor, output_tensor, d, node_index,
2411               "SPLIT"));
2412         }
2413       }
2414     }
2415 
2416     if (subgraph != nullptr) {
2417       xnn_status status = xnn_status_invalid_parameter;
2418       if (num_outputs == 2) {
2419         status = xnn_define_even_split2(
2420             subgraph, split_dim,
2421             /*input_id=*/xnnpack_tensors[input_idx],
2422             /*output1_id=*/xnnpack_tensors[node->outputs->data[0]],
2423             /*output2_id=*/xnnpack_tensors[node->outputs->data[1]],
2424             /*flags=*/0);
2425       } else if (num_outputs == 3) {
2426         status = xnn_define_even_split3(
2427             subgraph, split_dim,
2428             /*input_id=*/xnnpack_tensors[input_idx],
2429             /*output1_id=*/xnnpack_tensors[node->outputs->data[0]],
2430             /*output2_id=*/xnnpack_tensors[node->outputs->data[1]],
2431             /*output3_id=*/xnnpack_tensors[node->outputs->data[2]],
2432             /*flags=*/0);
2433       } else if (num_outputs == 4) {
2434         status = xnn_define_even_split4(
2435             subgraph, split_dim,
2436             /*input_id=*/xnnpack_tensors[input_idx],
2437             /*output1_id=*/xnnpack_tensors[node->outputs->data[0]],
2438             /*output2_id=*/xnnpack_tensors[node->outputs->data[1]],
2439             /*output3_id=*/xnnpack_tensors[node->outputs->data[2]],
2440             /*output4_id=*/xnnpack_tensors[node->outputs->data[3]],
2441             /*flags=*/0);
2442       }
2443 
2444       if (status != xnn_status_success) {
2445         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate SPLIT node #%d",
2446                            node_index);
2447         return kTfLiteError;
2448       }
2449     }
2450 
2451     return kTfLiteOk;
2452   }
2453 
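  // Lowers a TFLite CONCATENATION node to xnn_define_concatenate2/3/4.
  // Between 2 and 4 inputs are supported; a negative axis is normalized
  // first, every input must match the output shape except along that axis,
  // and the per-input axis sizes must sum to the output's axis size.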
2454   static TfLiteStatus VisitConcatenationNode(
2455       xnn_subgraph_t subgraph, const Delegate& delegate,
2456       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2457       const TfLiteTensor* tensors,
2458       const TfLiteConcatenationParams* concat_params,
2459       const std::vector<uint32_t>& xnnpack_tensors) {
2460     TF_LITE_ENSURE_STATUS(
2461         CheckNumInputsAndOutputs(logging_context, node, 2, 4, 1, node_index));
2462     const int num_inputs = NumInputs(node);
2463 
2464     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2465     TF_LITE_ENSURE_STATUS(
2466         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
2467                                        node->outputs->data[0], node_index));
2468     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2469         logging_context, output_tensor, node->outputs->data[0], node_index));
2470 
2471     // Check dimensions
2472     int axis = concat_params->axis;
2473     if (axis < 0) axis += NumDimensions(&output_tensor);
2474     int sum_axis = 0;
2475 
2476     for (int i = 0; i < num_inputs; i++) {
2477       const TfLiteTensor& input_tensor = tensors[node->inputs->data[i]];
2478       TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQUInt8Type(
2479           delegate, logging_context, input_tensor, node->inputs->data[i],
2480           node_index));
2481       TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2482           logging_context, input_tensor, node->inputs->data[i], node_index));
2483 
2484       TF_LITE_ENSURE_EQ(logging_context, NumDimensions(&input_tensor),
2485                         NumDimensions(&output_tensor));
2486 
2487       for (int d = 0; d < NumDimensions(&output_tensor); d++) {
2488         // All dimensions must match except the 'axis'.
2489         if (d == axis) {
2490           continue;
2491         }
2492         const TfLiteTensor& input_tensor = tensors[node->inputs->data[i]];
2493         TF_LITE_ENSURE_STATUS(CheckTensorsDimensionMatch(
2494             logging_context, input_tensor, output_tensor, d, node_index,
2495             "CONCATENATE"));
2496       }
2497       sum_axis += SizeOfDimension(&input_tensor, axis);
2498     }
2499 
2500     if (SizeOfDimension(&output_tensor, axis) != sum_axis) {
2501       TF_LITE_MAYBE_KERNEL_LOG(
2502           logging_context,
2503           "mismatch in axis dimension %d (%d != %d) in output and input "
2504           "tensors of CONCATENATE operator #%d",
2505           axis, SizeOfDimension(&output_tensor, axis), sum_axis, node_index);
2506       return kTfLiteError;
2507     }
2508 
2509     if (subgraph != nullptr) {
2510       xnn_status status = xnn_status_invalid_parameter;
2511       if (num_inputs == 2) {
2512         status = xnn_define_concatenate2(
2513             subgraph, axis,
2514             /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2515             /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2516             /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
2517             /*flags=*/0);
2518       } else if (num_inputs == 3) {
2519         status = xnn_define_concatenate3(
2520             subgraph, axis,
2521             /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2522             /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2523             /*input3_id=*/xnnpack_tensors[node->inputs->data[2]],
2524             /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
2525             /*flags=*/0);
2526       } else if (num_inputs == 4) {
2527         status = xnn_define_concatenate4(
2528             subgraph, axis,
2529             /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2530             /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2531             /*input3_id=*/xnnpack_tensors[node->inputs->data[2]],
2532             /*input4_id=*/xnnpack_tensors[node->inputs->data[3]],
2533             /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
2534             /*flags=*/0);
2535       }
2536       if (status != xnn_status_success) {
2537         TF_LITE_KERNEL_LOG(logging_context,
2538                            "failed to delegate CONCATENATION node #%d",
2539                            node_index);
2540         return kTfLiteError;
2541       }
2542     }
2543     return kTfLiteOk;
2544   }
2545 
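  // Lowers a TFLite CONV_2D node to xnn_define_convolution_2d. Expects a 4D
  // input, a 4D OHWI filter that is static or quasi-static (derived from
  // static data, e.g. via DEQUANTIZE), and a mandatory 1D bias; mixed
  // input/filter/output types are rejected.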
2546   static TfLiteStatus VisitConv2DNode(
2547       xnn_subgraph_t subgraph, const Delegate& delegate,
2548       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2549       const TfLiteTensor* tensors, const TfLiteConvParams* conv_params,
2550       const std::unordered_set<int>& quasi_static_tensors,
2551       const std::vector<uint32_t>& xnnpack_tensors) {
2552     TF_LITE_ENSURE_STATUS(
2553         CheckConvolutionParams(logging_context, conv_params, node_index));
2554 
2555     TF_LITE_ENSURE_STATUS(
2556         CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
2557 
2558     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2559     TF_LITE_ENSURE_STATUS(
2560         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
2561                                        node->inputs->data[0], node_index));
2562     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
2563                                            node->inputs->data[0]));
2564     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2565         logging_context, input_tensor, node->inputs->data[0], node_index));
2566 
2567     const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
2568     TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQCInt8Type(
2569         delegate, logging_context, filter_tensor,
2570         /*expected_quantized_dimension=*/0, node->inputs->data[1], node_index));
2571     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
2572                                            node->inputs->data[1]));
2573     if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
2574       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2575           logging_context, filter_tensor, node->inputs->data[1], node_index));
2576     }
2577 
2578     const int bias_tensor_id = node->inputs->data[2];
2579     if (bias_tensor_id < 0) {
2580       TF_LITE_MAYBE_KERNEL_LOG(logging_context,
2581                                "unsupported CONV_2D node #%d without bias",
2582                                node_index);
2583       return kTfLiteError;
2584     }
2585     const TfLiteTensor& bias_tensor = tensors[bias_tensor_id];
2586     TF_LITE_ENSURE_STATUS(
2587         CheckTensorFloat32OrQCInt32Type(delegate, logging_context, bias_tensor,
2588                                         node->inputs->data[2], node_index));
2589     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
2590                                            node->inputs->data[2]));
2591     if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
2592       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2593           logging_context, bias_tensor, node->inputs->data[2], node_index));
2594     }
2595 
2596     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2597     TF_LITE_ENSURE_STATUS(
2598         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
2599                                        node->outputs->data[0], node_index));
2600     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
2601                                            node->outputs->data[0]));
2602     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2603         logging_context, output_tensor, node->outputs->data[0], node_index));
2604 
2605     if (input_tensor.type != output_tensor.type ||
2606         input_tensor.type != filter_tensor.type) {
2607       TF_LITE_MAYBE_KERNEL_LOG(
2608           logging_context, "unsupported mixed types in CONV_2D operator #%d",
2609           node_index);
2610       return kTfLiteError;
2611     }
2612 
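    // The filter's innermost dimension holds the per-group input channels;
    // dividing the input tensor's channel count by it recovers the number of
    // convolution groups (1 for an ordinary convolution).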
2613     const int output_channels = SizeOfDimension(&filter_tensor, 0);
2614     const int kernel_height = SizeOfDimension(&filter_tensor, 1);
2615     const int kernel_width = SizeOfDimension(&filter_tensor, 2);
2616     const int input_channels = SizeOfDimension(&filter_tensor, 3);
2617     const int groups = SizeOfDimension(&input_tensor, 3) / input_channels;
2618 
2619     uint32_t flags = 0;
2620     TF_LITE_ENSURE_STATUS(CalculatePadding(
2621         logging_context, conv_params->padding, &flags, node_index));
2622 
2623     float output_min = -std::numeric_limits<float>::infinity();
2624     float output_max = +std::numeric_limits<float>::infinity();
2625     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
2626         logging_context, node_index, conv_params->activation, &output_min,
2627         &output_max));
2628 
2629     if (subgraph != nullptr) {
2630       const xnn_status status = xnn_define_convolution_2d(
2631           subgraph,
2632           /*input_padding_top=*/0,
2633           /*input_padding_right=*/0,
2634           /*input_padding_bottom=*/0,
2635           /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
2636           static_cast<uint32_t>(kernel_width),
2637           static_cast<uint32_t>(conv_params->stride_height),
2638           static_cast<uint32_t>(conv_params->stride_width),
2639           static_cast<uint32_t>(conv_params->dilation_height_factor),
2640           static_cast<uint32_t>(conv_params->dilation_width_factor), groups,
2641           static_cast<size_t>(input_channels),
2642           static_cast<size_t>(output_channels) / groups, output_min, output_max,
2643           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2644           /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
2645           /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
2646           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
2647       if (status != xnn_status_success) {
2648         TF_LITE_KERNEL_LOG(logging_context,
2649                            "failed to delegate CONV_2D node #%d", node_index);
2650         return kTfLiteError;
2651       }
2652     }
2653 
2654     return kTfLiteOk;
2655   }
2656 
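  // Lowers a TFLite DEPTHWISE_CONV_2D node to
  // xnn_define_depthwise_convolution_2d. The filter is expected in 1HWO
  // layout (channels last), the bias is mandatory, and the pre-multiplier
  // input channel count is recovered as output_channels / depth_multiplier.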
2657   static TfLiteStatus VisitDepthwiseConv2DNode(
2658       xnn_subgraph_t subgraph, const Delegate& delegate,
2659       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2660       const TfLiteTensor* tensors,
2661       const TfLiteDepthwiseConvParams* dwconv_params,
2662       const std::unordered_set<int>& quasi_static_tensors,
2663       const std::vector<uint32_t>& xnnpack_tensors) {
2664     TF_LITE_ENSURE_STATUS(
2665         CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));
2666 
2667     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2668     TF_LITE_ENSURE_STATUS(
2669         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
2670                                        node->inputs->data[0], node_index));
2671     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
2672                                            node->inputs->data[0]));
2673     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2674         logging_context, input_tensor, node->inputs->data[0], node_index));
2675 
2676     const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
2677     TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQCInt8Type(
2678         delegate, logging_context, filter_tensor,
2679         /*expected_quantized_dimension=*/3, node->inputs->data[1], node_index));
2680     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
2681                                            node->inputs->data[1]));
2682     if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
2683       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2684           logging_context, filter_tensor, node->inputs->data[1], node_index));
2685     }
2686 
2687     const int bias_tensor_id = node->inputs->data[2];
2688     if (bias_tensor_id < 0) {
2689       TF_LITE_MAYBE_KERNEL_LOG(
2690           logging_context,
2691           "unsupported DEPTHWISE_CONV_2D node #%d without bias", node_index);
2692       return kTfLiteError;
2693     }
2694     const TfLiteTensor& bias_tensor = tensors[bias_tensor_id];
2695     TF_LITE_ENSURE_STATUS(
2696         CheckTensorFloat32OrQCInt32Type(delegate, logging_context, bias_tensor,
2697                                         node->inputs->data[2], node_index));
2698     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
2699                                            node->inputs->data[2]));
2700     if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
2701       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2702           logging_context, bias_tensor, node->inputs->data[2], node_index));
2703     }
2704 
2705     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2706     TF_LITE_ENSURE_STATUS(
2707         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
2708                                        node->outputs->data[0], node_index));
2709     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
2710                                            node->outputs->data[0]));
2711     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2712         logging_context, output_tensor, node->outputs->data[0], node_index));
2713 
2714     if (input_tensor.type != output_tensor.type ||
2715         input_tensor.type != filter_tensor.type) {
2716       TF_LITE_MAYBE_KERNEL_LOG(
2717           logging_context,
2718           "unsupported mixed types in DEPTHWISE_CONV_2D operator #%d",
2719           node_index);
2720       return kTfLiteError;
2721     }
2722 
2723     const int kernel_height = SizeOfDimension(&filter_tensor, 1);
2724     const int kernel_width = SizeOfDimension(&filter_tensor, 2);
2725     const int output_channels = SizeOfDimension(&filter_tensor, 3);
2726 
2727     TF_LITE_ENSURE_STATUS(CheckDepthwiseConvolutionParams(
2728         logging_context, dwconv_params, output_channels, node_index));
2729 
2730     uint32_t flags = 0;
2731     TF_LITE_ENSURE_STATUS(CalculatePadding(
2732         logging_context, dwconv_params->padding, &flags, node_index));
2733 
2734     float output_min = -std::numeric_limits<float>::infinity();
2735     float output_max = +std::numeric_limits<float>::infinity();
2736     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
2737         logging_context, node_index, dwconv_params->activation, &output_min,
2738         &output_max));
2739 
2740     if (subgraph != nullptr) {
2741       const xnn_status status = xnn_define_depthwise_convolution_2d(
2742           subgraph,
2743           /*input_padding_top=*/0,
2744           /*input_padding_right=*/0,
2745           /*input_padding_bottom=*/0,
2746           /*input_padding_left=*/0, static_cast<uint32_t>(kernel_height),
2747           static_cast<uint32_t>(kernel_width),
2748           static_cast<uint32_t>(dwconv_params->stride_height),
2749           static_cast<uint32_t>(dwconv_params->stride_width),
2750           static_cast<uint32_t>(dwconv_params->dilation_height_factor),
2751           static_cast<uint32_t>(dwconv_params->dilation_width_factor),
2752           static_cast<uint32_t>(dwconv_params->depth_multiplier),
2753           /*input_channels=*/
2754           static_cast<uint32_t>(output_channels /
2755                                 dwconv_params->depth_multiplier),
2756           output_min, output_max,
2757           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2758           /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
2759           /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
2760           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
2761       if (status != xnn_status_success) {
2762         TF_LITE_KERNEL_LOG(logging_context,
2763                            "failed to delegate DEPTHWISE_CONV_2D node #%d",
2764                            node_index);
2765         return kTfLiteError;
2766       }
2767     }
2768 
2769     return kTfLiteOk;
2770   }
2771 
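  // Lowers a TFLite DEPTH_TO_SPACE node to xnn_define_depth_to_space.
  // Block sizes <= 1 are rejected: 1 would make the op an identity and
  // smaller values are invalid.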
2772   static TfLiteStatus VisitDepthToSpaceNode(
2773       xnn_subgraph_t subgraph, const Delegate& delegate,
2774       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2775       const TfLiteTensor* tensors,
2776       const TfLiteDepthToSpaceParams* depth_to_space_params,
2777       const std::vector<uint32_t>& xnnpack_tensors) {
2778     TF_LITE_ENSURE_STATUS(
2779         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2780 
2781     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2782     TF_LITE_ENSURE_STATUS(
2783         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
2784                                        node->inputs->data[0], node_index));
2785     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2786         logging_context, input_tensor, node->inputs->data[0], node_index));
2787 
2788     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2789     TF_LITE_ENSURE_STATUS(
2790         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
2791                                        node->outputs->data[0], node_index));
2792     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2793         logging_context, output_tensor, node->outputs->data[0], node_index));
2794 
2795     if (depth_to_space_params->block_size <= 1) {
2796       TF_LITE_MAYBE_KERNEL_LOG(
2797           logging_context, "invalid block size (%d) in DEPTH_TO_SPACE node #%d",
2798           depth_to_space_params->block_size, node_index);
2799       return kTfLiteError;
2800     }
2801 
2802     if (subgraph != nullptr) {
2803       const xnn_status status = xnn_define_depth_to_space(
2804           subgraph,
2805           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2806           /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
2807           /*block_size=*/
2808           static_cast<uint32_t>(depth_to_space_params->block_size),
2809           /*flags=*/0);
2810       if (status != xnn_status_success) {
2811         TF_LITE_KERNEL_LOG(logging_context,
2812                            "failed to delegate DEPTH_TO_SPACE node #%d",
2813                            node_index);
2814         return kTfLiteError;
2815       }
2816     }
2817 
2818     return kTfLiteOk;
2819   }
2820 
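  // Lowers a TFLite DEQUANTIZE node to xnn_define_convert, turning a
  // quantized (U)INT8 input tensor into an FP32 output.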
2821   static TfLiteStatus VisitDequantizeNode(
2822       xnn_subgraph_t subgraph, const Delegate& delegate,
2823       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2824       const TfLiteTensor* tensors,
2825       const std::vector<uint32_t>& xnnpack_tensors) {
2826     TF_LITE_ENSURE_STATUS(
2827         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2828 
2829     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2830     TF_LITE_ENSURE_STATUS(
2831         CheckTensorQInt8OrQUInt8Type(delegate, logging_context, input_tensor,
2832                                      node->inputs->data[0], node_index));
2833     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2834         logging_context, input_tensor, node->inputs->data[0], node_index));
2835 
2836     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2837     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2838         logging_context, output_tensor, node->outputs->data[0], node_index));
2839     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2840         logging_context, output_tensor, node->outputs->data[0], node_index));
2841 
2842     if (subgraph != nullptr) {
2843       const xnn_status status = xnn_define_convert(
2844           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2845           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2846       if (status != xnn_status_success) {
2847         TF_LITE_KERNEL_LOG(logging_context,
2848                            "failed to delegate DEQUANTIZE node #%d",
2849                            node_index);
2850         return kTfLiteError;
2851       }
2852     }
2853 
2854     return kTfLiteOk;
2855   }
2856 
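  // Lowers a TFLite DIV node to xnn_define_divide. Only FP32 operands are
  // supported; the fused activation, if any, is folded into the output
  // [min, max] clamping range.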
2857   static TfLiteStatus VisitDivNode(
2858       xnn_subgraph_t subgraph, const Delegate& delegate,
2859       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2860       const TfLiteTensor* tensors, const TfLiteDivParams* div_params,
2861       const std::vector<uint32_t>& xnnpack_tensors) {
2862     TF_LITE_ENSURE_STATUS(
2863         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
2864 
2865     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
2866     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2867         logging_context, input1_tensor, node->inputs->data[0], node_index));
2868     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2869         logging_context, input1_tensor, node->inputs->data[0], node_index));
2870 
2871     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
2872     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2873         logging_context, input2_tensor, node->inputs->data[1], node_index));
2874     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2875         logging_context, input2_tensor, node->inputs->data[1], node_index));
2876 
2877     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2878     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
2879         logging_context, output_tensor, node->outputs->data[0], node_index));
2880     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2881         logging_context, output_tensor, node->outputs->data[0], node_index));
2882 
2883     float output_min = -std::numeric_limits<float>::infinity();
2884     float output_max = +std::numeric_limits<float>::infinity();
2885     if (div_params != nullptr) {
2886       TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
2887           logging_context, node_index, div_params->activation, &output_min,
2888           &output_max));
2889     }
2890 
2891     if (subgraph != nullptr) {
2892       const xnn_status status = xnn_define_divide(
2893           subgraph, output_min, output_max,
2894           /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
2895           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
2896           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
2897       if (status != xnn_status_success) {
2898         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate DIV node #%d",
2899                            node_index);
2900         return kTfLiteError;
2901       }
2902     }
2903 
2904     return kTfLiteOk;
2905   }
2906 
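  // Lowers a TFLite ELU node to xnn_define_elu with alpha fixed at 1.0,
  // since the TFLite ELU op carries no alpha attribute.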
2907   static TfLiteStatus VisitEluNode(
2908       xnn_subgraph_t subgraph, const Delegate& delegate,
2909       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2910       const TfLiteTensor* tensors,
2911       const std::vector<uint32_t>& xnnpack_tensors) {
2912     TF_LITE_ENSURE_STATUS(
2913         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
2914 
2915     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2916     TF_LITE_ENSURE_STATUS(
2917         CheckTensorFloat32OrQInt8Type(delegate, logging_context, input_tensor,
2918                                       node->inputs->data[0], node_index));
2919     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2920         logging_context, input_tensor, node->inputs->data[0], node_index));
2921 
2922     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2923     TF_LITE_ENSURE_STATUS(
2924         CheckTensorFloat32OrQInt8Type(delegate, logging_context, output_tensor,
2925                                       node->outputs->data[0], node_index));
2926     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2927         logging_context, output_tensor, node->outputs->data[0], node_index));
2928 
2929     if (subgraph != nullptr) {
2930       const xnn_status status =
2931           xnn_define_elu(subgraph, /*alpha=*/1.0f,
2932                          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
2933                          /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
2934                          /*flags=*/0);
2935       if (status != xnn_status_success) {
2936         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ELU node #%d",
2937                            node_index);
2938         return kTfLiteError;
2939       }
2940     }
2941 
2942     return kTfLiteOk;
2943   }
2944 
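  // Lowers a TFLite FULLY_CONNECTED node to xnn_define_fully_connected.
  // The bias is optional (2 or 3 inputs), the filter must be a 2D
  // [output_channels, input_channels] tensor, and mixed input/filter/output
  // types are rejected.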
2945   static TfLiteStatus VisitFullyConnectedNode(
2946       xnn_subgraph_t subgraph, const Delegate& delegate,
2947       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
2948       const TfLiteTensor* tensors, const TfLiteFullyConnectedParams* fc_params,
2949       const std::unordered_set<int>& quasi_static_tensors,
2950       const std::vector<uint32_t>& xnnpack_tensors) {
2951     TF_LITE_ENSURE_STATUS(
2952         CheckFullyConnectedParams(logging_context, fc_params, node_index));
2953 
2954     TF_LITE_ENSURE_STATUS(
2955         CheckNumInputsAndOutputs(logging_context, node, 2, 3, 1, node_index));
2956 
2957     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
2958     TF_LITE_ENSURE_STATUS(
2959         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
2960                                        node->inputs->data[0], node_index));
2961     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2962         logging_context, input_tensor, node->inputs->data[0], node_index));
2963 
2964     const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
2965     TF_LITE_ENSURE_STATUS(
2966         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, filter_tensor,
2967                                        node->inputs->data[1], node_index));
2968     TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 2,
2969                                            node->inputs->data[1]));
2970     if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
2971       TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2972           logging_context, filter_tensor, node->inputs->data[1], node_index));
2973     }
2974 
2975     int bias_tensor_id = -1;
2976     if (node->inputs->size >= 3) {
2977       bias_tensor_id = node->inputs->data[2];
2978       if (bias_tensor_id >= 0) {
2979         const TfLiteTensor& bias_tensor = tensors[bias_tensor_id];
2980         TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQInt32Type(
2981             delegate, logging_context, bias_tensor, node->inputs->data[2],
2982             node_index));
2983         TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
2984                                                node->inputs->data[2]));
2985         if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
2986           TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
2987               logging_context, bias_tensor, node->inputs->data[2], node_index));
2988         }
2989       }
2990     }
2991 
2992     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
2993     TF_LITE_ENSURE_STATUS(
2994         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
2995                                        node->outputs->data[0], node_index));
2996     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
2997         logging_context, output_tensor, node->outputs->data[0], node_index));
2998 
2999     const int32_t output_channels = SizeOfDimension(&filter_tensor, 0);
3000     const int32_t input_channels = SizeOfDimension(&filter_tensor, 1);
3001 
3002     if (input_tensor.type != output_tensor.type ||
3003         input_tensor.type != filter_tensor.type) {
3004       TF_LITE_MAYBE_KERNEL_LOG(
3005           logging_context,
3006           "unsupported mixed types in FULLY_CONNECTED operator #%d",
3007           node_index);
3008       return kTfLiteError;
3009     }
3010 
3011     if (NumDimensions(&input_tensor) == 0) {
3012       TF_LITE_MAYBE_KERNEL_LOG(
3013           logging_context,
3014           "unexpected number of shape dimensions %d in tensor #%d",
3015           NumDimensions(&input_tensor), node->inputs->data[0]);
3016       return kTfLiteError;
3017     }
3018 
3019     int32_t num_input_elements = 1;
3020     for (int i = 0; i < NumDimensions(&input_tensor); i++) {
3021       if (SizeOfDimension(&input_tensor, i) <= 0) {
3022         TF_LITE_MAYBE_KERNEL_LOG(
3023             logging_context, "invalid dimension #%d (%d) in tensor #%d", i,
3024             SizeOfDimension(&input_tensor, i), node->inputs->data[0]);
3025         return kTfLiteError;
3026       }
3027       num_input_elements *= SizeOfDimension(&input_tensor, i);
3028     }
3029 
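    // With keep_num_dims the output must mirror all leading input
    // dimensions; without it, XNNPACK is asked (via
    // XNN_FLAG_TENSORFLOW_RESHAPE_2D) to flatten the input to
    // [num_input_elements / input_channels, input_channels] and produce a
    // 2D output.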
3030     if (fc_params->keep_num_dims) {
3031       TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor,
3032                                              NumDimensions(&input_tensor),
3033                                              node->outputs->data[0]));
3034 
3035       for (int i = 0; i < NumDimensions(&input_tensor) - 1; i++) {
3036         if (SizeOfDimension(&input_tensor, i) !=
3037             SizeOfDimension(&output_tensor, i)) {
3038           TF_LITE_MAYBE_KERNEL_LOG(
3039               logging_context,
3040               "mismatch in shape dimension %d (%d != %d) in input and output "
3041               "tensors of FULLY_CONNECTED operator #%d",
3042               i, SizeOfDimension(&input_tensor, i),
3043               SizeOfDimension(&output_tensor, i), node_index);
3044           return kTfLiteError;
3045         }
3046       }
3047     } else {
3048       if (num_input_elements % input_channels != 0) {
3049         TF_LITE_MAYBE_KERNEL_LOG(
3050             logging_context,
3051             "number of elements in input tensor #%d in FULLY_CONNECTED "
3052             "operator is not divisible by input channels (%d)",
3053             node->inputs->data[0], input_channels);
3054         return kTfLiteError;
3055       }
3056 
3057       TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 2,
3058                                              node->outputs->data[0]));
3059 
3060       if (SizeOfDimension(&output_tensor, 0) !=
3061           num_input_elements / input_channels) {
3062         TF_LITE_MAYBE_KERNEL_LOG(
3063             logging_context,
3064             "batch size %d in output tensor #%d in FULLY_CONNECTED operator "
3065             "does not match batch size %d in reshaped input tensor #%d",
3066             SizeOfDimension(&output_tensor, 0), node->outputs->data[0],
3067             num_input_elements / input_channels, node->inputs->data[0]);
3068         return kTfLiteError;
3069       }
3070     }
3071 
3072     if (SizeOfDimension(&output_tensor, NumDimensions(&output_tensor) - 1) !=
3073         output_channels) {
3074       TF_LITE_MAYBE_KERNEL_LOG(
3075           logging_context,
3076           "number of channels %d in output tensor #%d does not match output "
3077           "channels %d in filter tensor #%d",
3078           SizeOfDimension(&output_tensor, NumDimensions(&output_tensor) - 1),
3079           node->outputs->data[0], output_channels, node->inputs->data[1]);
3080       return kTfLiteError;
3081     }
3082 
3083     float output_min = -std::numeric_limits<float>::infinity();
3084     float output_max = +std::numeric_limits<float>::infinity();
3085     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
3086         logging_context, node_index, fc_params->activation, &output_min,
3087         &output_max));
3088 
3089     if (subgraph != nullptr) {
3090       const xnn_status status = xnn_define_fully_connected(
3091           subgraph, output_min, output_max,
3092           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
3093           /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
3094           /*bias_id=*/bias_tensor_id >= 0 ? xnnpack_tensors[bias_tensor_id]
3095                                           : XNN_INVALID_VALUE_ID,
3096           /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
3097           /*flags=*/fc_params->keep_num_dims ? 0
3098                                              : XNN_FLAG_TENSORFLOW_RESHAPE_2D);
3099       if (status != xnn_status_success) {
3100         TF_LITE_KERNEL_LOG(logging_context,
3101                            "failed to delegate FULLY_CONNECTED node #%d",
3102                            node_index);
3103         return kTfLiteError;
3104       }
3105     }
3106 
3107     return kTfLiteOk;
3108   }
3109 
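  // Lowers a TFLite FLOOR node to xnn_define_floor on an FP32 tensor.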
3110   static TfLiteStatus VisitFloorNode(
3111       xnn_subgraph_t subgraph, const Delegate& delegate,
3112       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
3113       const TfLiteTensor* tensors,
3114       const std::vector<uint32_t>& xnnpack_tensors) {
3115     TF_LITE_ENSURE_STATUS(
3116         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
3117 
3118     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
3119     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
3120         logging_context, input_tensor, node->inputs->data[0], node_index));
3121     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3122         logging_context, input_tensor, node->inputs->data[0], node_index));
3123 
3124     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
3125     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
3126         logging_context, output_tensor, node->outputs->data[0], node_index));
3127     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3128         logging_context, output_tensor, node->outputs->data[0], node_index));
3129 
3130     if (subgraph != nullptr) {
3131       const xnn_status status = xnn_define_floor(
3132           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
3133           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
3134       if (status != xnn_status_success) {
3135         TF_LITE_KERNEL_LOG(logging_context, "failed to delegate FLOOR node #%d",
3136                            node_index);
3137         return kTfLiteError;
3138       }
3139     }
3140 
3141     return kTfLiteOk;
3142   }
3143 
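  // Lowers a TFLite HARD_SWISH node to xnn_define_hardswish on FP32 tensors.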
3144   static TfLiteStatus VisitHardSwishNode(
3145       xnn_subgraph_t subgraph, const Delegate& delegate,
3146       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
3147       const TfLiteTensor* tensors,
3148       const std::vector<uint32_t>& xnnpack_tensors) {
3149     TF_LITE_ENSURE_STATUS(
3150         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
3151 
3152     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
3153     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
3154         logging_context, input_tensor, node->inputs->data[0], node_index));
3155     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3156         logging_context, input_tensor, node->inputs->data[0], node_index));
3157 
3158     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
3159     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
3160         logging_context, output_tensor, node->outputs->data[0], node_index));
3161     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3162         logging_context, output_tensor, node->outputs->data[0], node_index));
3163 
3164     if (subgraph != nullptr) {
3165       const xnn_status status = xnn_define_hardswish(
3166           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
3167           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
3168       if (status != xnn_status_success) {
3169         TF_LITE_KERNEL_LOG(logging_context,
3170                            "failed to delegate HARD_SWISH node #%d",
3171                            node_index);
3172         return kTfLiteError;
3173       }
3174     }
3175 
3176     return kTfLiteOk;
3177   }
3178 
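  // Lowers a TFLite LEAKY_RELU node to xnn_define_leaky_relu. Alpha must be
  // a nonzero normal float, and for quantized tensors the input-to-output
  // scale ratios (checked below) must fall within XNNPACK's supported range.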
3179   static TfLiteStatus VisitLeakyReluNode(
3180       xnn_subgraph_t subgraph, const Delegate& delegate,
3181       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
3182       const TfLiteTensor* tensors,
3183       const TfLiteLeakyReluParams* leaky_relu_params,
3184       const std::vector<uint32_t>& xnnpack_tensors) {
3185     TF_LITE_ENSURE_STATUS(
3186         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
3187 
3188     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
3189     TF_LITE_ENSURE_STATUS(
3190         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
3191                                        node->inputs->data[0], node_index));
3192     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3193         logging_context, input_tensor, node->inputs->data[0], node_index));
3194 
3195     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
3196     TF_LITE_ENSURE_STATUS(
3197         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
3198                                        node->outputs->data[0], node_index));
3199     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3200         logging_context, output_tensor, node->outputs->data[0], node_index));
3201 
3202     if (!std::isnormal(leaky_relu_params->alpha) ||
3203         leaky_relu_params->alpha == 0.0f) {
3204       TF_LITE_MAYBE_KERNEL_LOG(logging_context,
3205                                "unsupported alpha %g in LEAKY_RELU node #%d",
3206                                leaky_relu_params->alpha, node_index);
3207       return kTfLiteError;
3208     }
3209 
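    // For quantized tensors, XNNPACK restricts the input-to-output scale
    // ratio: the positive-slope ratio must lie in [1/256, 128] and the
    // negative-slope ratio in [-32767/256, 128] with magnitude >= 1/256
    // (-127.99609375f == -32767/256).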
3210     const float input_scale =
3211         GetTensorScaleOrDefault(input_tensor, std::nanf(""));
3212     const float output_scale =
3213         GetTensorScaleOrDefault(output_tensor, std::nanf(""));
3214     if (std::isnormal(input_scale) && std::isnormal(output_scale)) {
3215       const float positive_scale = input_scale / output_scale;
3216       if (positive_scale < 1.0f / 256.0f || positive_scale > 128.0f) {
3217         TF_LITE_MAYBE_KERNEL_LOG(logging_context,
3218                                  "unsupported positive input-to-output scale "
3219                                  "%g in LEAKY_RELU node #%d",
3220                                  positive_scale, node_index);
3221         return kTfLiteError;
3222       }
3223 
3224       const float negative_scale = positive_scale * leaky_relu_params->alpha;
3225       if (negative_scale < -127.99609375f || negative_scale > 128.0f ||
3226           std::fabs(negative_scale) < 1.0f / 256.0f) {
3227         TF_LITE_MAYBE_KERNEL_LOG(logging_context,
3228                                  "unsupported negative input-to-output scale "
3229                                  "%g in LEAKY_RELU node #%d",
3230                                  negative_scale, node_index);
3231         return kTfLiteError;
3232       }
3233     }
3234 
3235     if (subgraph != nullptr) {
3236       const xnn_status status = xnn_define_leaky_relu(
3237           subgraph, leaky_relu_params->alpha,
3238           /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
3239           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
3240       if (status != xnn_status_success) {
3241         TF_LITE_KERNEL_LOG(logging_context,
3242                            "failed to delegate LEAKY_RELU node #%d",
3243                            node_index);
3244         return kTfLiteError;
3245       }
3246     }
3247 
3248     return kTfLiteOk;
3249   }
3250 
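  // Lowers a TFLite LOGISTIC node to xnn_define_sigmoid, XNNPACK's name for
  // the same function.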
3251   static TfLiteStatus VisitLogisticNode(
3252       xnn_subgraph_t subgraph, const Delegate& delegate,
3253       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
3254       const TfLiteTensor* tensors,
3255       const std::vector<uint32_t>& xnnpack_tensors) {
3256     TF_LITE_ENSURE_STATUS(
3257         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
3258 
3259     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
3260     TF_LITE_ENSURE_STATUS(
3261         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
3262                                        node->inputs->data[0], node_index));
3263     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3264         logging_context, input_tensor, node->inputs->data[0], node_index));
3265 
3266     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
3267     TF_LITE_ENSURE_STATUS(
3268         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
3269                                        node->outputs->data[0], node_index));
3270     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3271         logging_context, output_tensor, node->outputs->data[0], node_index));
3272 
3273     if (subgraph != nullptr) {
3274       const xnn_status status = xnn_define_sigmoid(
3275           subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
3276           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
3277       if (status != xnn_status_success) {
3278         TF_LITE_KERNEL_LOG(logging_context,
3279                            "failed to delegate LOGISTIC node #%d", node_index);
3280         return kTfLiteError;
3281       }
3282     }
3283 
3284     return kTfLiteOk;
3285   }
3286 
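  // Lowers a TFLite MAX_POOL_2D node to xnn_define_max_pooling_2d. As with
  // average pooling, a 1x1 filter performs no spatial aggregation and is
  // lowered to xnn_define_clamp on the activation range instead.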
3287   static TfLiteStatus VisitMaxPool2DNode(
3288       xnn_subgraph_t subgraph, const Delegate& delegate,
3289       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
3290       const TfLiteTensor* tensors, const TfLitePoolParams* pool_params,
3291       const std::vector<uint32_t>& xnnpack_tensors) {
3292     TF_LITE_ENSURE_STATUS(
3293         CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));
3294 
3295     const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
3296     TF_LITE_ENSURE_STATUS(
3297         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
3298                                        node->inputs->data[0], node_index));
3299     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3300         logging_context, input_tensor, node->inputs->data[0], node_index));
3301 
3302     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
3303     TF_LITE_ENSURE_STATUS(
3304         CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
3305                                        node->outputs->data[0], node_index));
3306     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3307         logging_context, output_tensor, node->outputs->data[0], node_index));
3308 
3309     TF_LITE_ENSURE_STATUS(
3310         CheckPoolingParams(logging_context, pool_params, node_index));
3311 
3312     uint32_t flags = 0;
3313     TF_LITE_ENSURE_STATUS(CalculatePadding(
3314         logging_context, pool_params->padding, &flags, node_index));
3315 
3316     float output_min = -std::numeric_limits<float>::infinity();
3317     float output_max = +std::numeric_limits<float>::infinity();
3318     TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
3319         logging_context, node_index, pool_params->activation, &output_min,
3320         &output_max));
3321 
3322     if (subgraph != nullptr) {
3323       xnn_status status = xnn_status_success;
3324       if (pool_params->filter_height == 1 && pool_params->filter_width == 1) {
3325         status = xnn_define_clamp(
3326             subgraph, output_min, output_max,
3327             /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
3328             /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
3329       } else {
3330         status = xnn_define_max_pooling_2d(
3331             subgraph,
3332             /*input_padding_top=*/0,
3333             /*input_padding_right=*/0,
3334             /*input_padding_bottom=*/0,
3335             /*input_padding_left=*/0,
3336             static_cast<uint32_t>(pool_params->filter_height),
3337             static_cast<uint32_t>(pool_params->filter_width),
3338             static_cast<uint32_t>(pool_params->stride_height),
3339             static_cast<uint32_t>(pool_params->stride_width),
3340             /*dilation_height=*/1, /*dilation_width=*/1, output_min, output_max,
3341             /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
3342             /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
3343       }
3344       if (status != xnn_status_success) {
3345         TF_LITE_KERNEL_LOG(logging_context,
3346                            "failed to delegate MAX_POOL_2D node #%d",
3347                            node_index);
3348         return kTfLiteError;
3349       }
3350     }
3351 
3352     return kTfLiteOk;
3353   }
3354 
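  // Lowers a TFLite MAXIMUM node to xnn_define_maximum2 on two FP32 inputs.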
3355   static TfLiteStatus VisitMaximumNode(
3356       xnn_subgraph_t subgraph, const Delegate& delegate,
3357       TfLiteContext* logging_context, int node_index, TfLiteNode* node,
3358       const TfLiteTensor* tensors,
3359       const std::vector<uint32_t>& xnnpack_tensors) {
3360     TF_LITE_ENSURE_STATUS(
3361         CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));
3362 
3363     const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
3364     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
3365         logging_context, input1_tensor, node->inputs->data[0], node_index));
3366     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3367         logging_context, input1_tensor, node->inputs->data[0], node_index));
3368 
3369     const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
3370     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
3371         logging_context, input2_tensor, node->inputs->data[1], node_index));
3372     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3373         logging_context, input2_tensor, node->inputs->data[1], node_index));
3374 
3375     const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
3376     TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
3377         logging_context, output_tensor, node->outputs->data[0], node_index));
3378     TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
3379         logging_context, output_tensor, node->outputs->data[0], node_index));
3380 
3381     if (subgraph != nullptr) {
3382       const xnn_status status = xnn_define_maximum2(
3383           subgraph, /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
3384           /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
3385           /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
3386       if (status != xnn_status_success) {
3387         TF_LITE_KERNEL_LOG(logging_context,
3388                            "failed to delegate MAXIMUM node #%d", node_index);
3389         return kTfLiteError;
3390       }
3391     }
3392 
3393     return kTfLiteOk;
3394   }
3395 
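  // Lowers a TFLite MEAN node to an XNNPACK global average pooling operator.
  // Only spatial reductions of a 4-D (NHWC) input are supported: a single
  // reduction axis {2} maps to 1-D global pooling, and the axes pair {1, 2}
  // maps to 2-D global pooling; any other axes combination is rejected.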
  static TfLiteStatus VisitMeanNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors, const TfLiteReducerParams* reducer_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
                                       node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
                                           node->inputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& axes_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, axes_tensor,
                                          kTfLiteInt32, node->inputs->data[1],
                                          node_index));
    TF_LITE_ENSURE_STATUS(CheckAxesTensorShape(
        logging_context, axes_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, axes_tensor, node->inputs->data[1], node_index));

    const int32_t* axes_data =
        reinterpret_cast<const int32_t*>(axes_tensor.data.data);
    const int num_reduction_axes = NumElements(&axes_tensor);
    switch (num_reduction_axes) {
      case 1:
        if (axes_data[0] != 2) {
          TF_LITE_MAYBE_KERNEL_LOG(
              logging_context,
              "unsupported MEAN reduction along non-spatial "
              "axis %d in node %d",
              axes_data[0], node_index);
          return kTfLiteError;
        }
        break;
      case 2:
        if (std::min(axes_data[0], axes_data[1]) != 1 ||
            std::max(axes_data[0], axes_data[1]) != 2) {
          TF_LITE_MAYBE_KERNEL_LOG(
              logging_context,
              "unsupported MEAN reduction along non-spatial "
              "axes %d and %d in node %d",
              std::min(axes_data[0], axes_data[1]),
              std::max(axes_data[0], axes_data[1]), node_index);
          return kTfLiteError;
        }
        break;
      default:
        TF_LITE_MAYBE_KERNEL_LOG(
            logging_context,
            "unsupported MEAN reduction along %d axes in node %d",
            SizeOfDimension(&axes_tensor, 0), node_index);
        return kTfLiteError;
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
                                       node->outputs->data[0], node_index));
    int expected_output_dims = 4;
    if (!reducer_params->keep_dims) {
      expected_output_dims -= num_reduction_axes;
    }
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor,
                                           expected_output_dims,
                                           node->outputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      xnn_status status = xnn_status_success;
      switch (num_reduction_axes) {
        case 1:
          status = xnn_define_global_average_pooling_1d(
              subgraph,
              /*output_min=*/-std::numeric_limits<float>::infinity(),
              /*output_max=*/+std::numeric_limits<float>::infinity(),
              /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
              /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
              /*flags=*/0);
          break;
        case 2:
          status = xnn_define_global_average_pooling_2d(
              subgraph,
              /*output_min=*/-std::numeric_limits<float>::infinity(),
              /*output_max=*/+std::numeric_limits<float>::infinity(),
              /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
              /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
              /*flags=*/0);
          break;
        default:
          break;
      }
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate MEAN node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

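  // Lowers the MediaPipe Convolution2DTransposeBias custom node to an XNNPACK
  // deconvolution_2d operator. The filter (OHWI layout) and bias must be
  // static, or quasi-static (i.e. materialized from static data while the
  // delegate is being built); paddings and output adjustments are derived
  // from the input and output shapes.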
  static TfLiteStatus VisitMediaPipeDeconvolutionNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const TfLiteTransposeConvParams* deconv_params,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 3, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
                                           node->inputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& filter_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, filter_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
                                           node->inputs->data[1]));
    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, filter_tensor, node->inputs->data[1], node_index));
    }

    const TfLiteTensor& bias_tensor = tensors[node->inputs->data[2]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, bias_tensor, node->inputs->data[2], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                           node->inputs->data[2]));
    if (quasi_static_tensors.count(node->inputs->data[2]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, bias_tensor, node->inputs->data[2], node_index));
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
                                           node->outputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    const int* input_tensor_dims = input_tensor.dims->data;
    const int input_height = input_tensor_dims[1];
    const int input_width = input_tensor_dims[2];

    const int* output_tensor_dims = output_tensor.dims->data;
    const int output_height = output_tensor_dims[1];
    const int output_width = output_tensor_dims[2];

    const int output_channels = SizeOfDimension(&filter_tensor, 0);
    const int kernel_height = SizeOfDimension(&filter_tensor, 1);
    const int kernel_width = SizeOfDimension(&filter_tensor, 2);
    const int input_channels = SizeOfDimension(&filter_tensor, 3);

    TF_LITE_ENSURE_STATUS(CheckMediaPipeTransposedConvolutionParams(
        logging_context, deconv_params, node_index));

    int padding_top = 0;
    int padding_bottom = 0;
    int padding_left = 0;
    int padding_right = 0;
    int adjustment_height = 0;
    int adjustment_width = 0;
    TF_LITE_ENSURE_STATUS(CalculateTransposeConvPaddings(
        logging_context, deconv_params->padding, input_height, input_width,
        kernel_height, kernel_width, /*dilation_height=*/1,
        /*dilation_width=*/1, deconv_params->stride_height,
        deconv_params->stride_width, node_index, output_height, output_width,
        &padding_top, &padding_bottom, &padding_left, &padding_right,
        &adjustment_height, &adjustment_width));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_deconvolution_2d(
          subgraph,
          /*padding_top=*/padding_top,
          /*padding_right=*/padding_right,
          /*padding_bottom=*/padding_bottom,
          /*padding_left=*/padding_left,
          /*adjustment_height=*/adjustment_height,
          /*adjustment_width=*/adjustment_width,
          static_cast<uint32_t>(kernel_height),
          static_cast<uint32_t>(kernel_width),
          static_cast<uint32_t>(deconv_params->stride_height),
          static_cast<uint32_t>(deconv_params->stride_width),
          /*dilation_height=*/1,
          /*dilation_width=*/1,
          /*groups=*/1,
          /*group_input_channels=*/input_channels,
          /*group_output_channels=*/output_channels,
          /*output_min=*/-std::numeric_limits<float>::infinity(),
          /*output_max=*/+std::numeric_limits<float>::infinity(),
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*filter_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*bias_id=*/xnnpack_tensors[node->inputs->data[2]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
          /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(
            logging_context,
            "failed to delegate Convolution2DTransposeBias node #%d",
            node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

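  // Lowers the MediaPipe MaxPoolingWithArgmax2D custom node to an XNNPACK
  // argmax_pooling_2d operator, which produces both the pooled values
  // (output #0) and the indices of the selected elements (output #1).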
  static TfLiteStatus VisitMediaPipeMaxPoolingNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors, const TfLitePoolParams* pool_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 1, 2, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
                                           node->inputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& output_value_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32Type(logging_context, output_value_tensor,
                               node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_value_tensor,
                                           4, node->outputs->data[0]));
    TF_LITE_ENSURE_STATUS(
        CheckTensorNonDynamicAllocation(logging_context, output_value_tensor,
                                        node->outputs->data[0], node_index));

    const TfLiteTensor& output_index_tensor = tensors[node->outputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_index_tensor,
                                           4, node->outputs->data[1]));
    TF_LITE_ENSURE_STATUS(
        CheckTensorNonDynamicAllocation(logging_context, output_index_tensor,
                                        node->outputs->data[1], node_index));

    TF_LITE_ENSURE_STATUS(
        CheckMediaPipePoolParams(logging_context, pool_params, node_index));

    uint32_t flags = 0;
    TF_LITE_ENSURE_STATUS(CalculatePadding(
        logging_context, pool_params->padding, &flags, node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_argmax_pooling_2d(
          subgraph,
          /*input_padding_top=*/0,
          /*input_padding_right=*/0,
          /*input_padding_bottom=*/0,
          /*input_padding_left=*/0,
          static_cast<uint32_t>(pool_params->filter_height),
          static_cast<uint32_t>(pool_params->filter_width),
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_value_id=*/xnnpack_tensors[node->outputs->data[0]],
          /*output_index_id=*/xnnpack_tensors[node->outputs->data[1]], flags);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(
            logging_context,
            "failed to delegate CUSTOM(MaxPoolingWithArgmax2D) node #%d",
            node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

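  // Lowers the MediaPipe MaxUnpooling2D custom node to an XNNPACK
  // unpooling_2d operator; input #0 holds the pooled values and input #1 the
  // indices previously produced by MaxPoolingWithArgmax2D.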
  static TfLiteStatus VisitMediaPipeUnpoolingNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors, const TfLitePoolParams* pool_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input_value_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32Type(logging_context, input_value_tensor,
                               node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_value_tensor,
                                           4, node->inputs->data[0]));
    TF_LITE_ENSURE_STATUS(
        CheckTensorNonDynamicAllocation(logging_context, input_value_tensor,
                                        node->inputs->data[0], node_index));

    const TfLiteTensor& input_index_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_index_tensor,
                                           4, node->inputs->data[1]));
    TF_LITE_ENSURE_STATUS(
        CheckTensorNonDynamicAllocation(logging_context, input_index_tensor,
                                        node->inputs->data[1], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
                                           node->outputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    TF_LITE_ENSURE_STATUS(
        CheckMediaPipePoolParams(logging_context, pool_params, node_index));

    uint32_t flags = 0;
    TF_LITE_ENSURE_STATUS(CalculatePadding(
        logging_context, pool_params->padding, &flags, node_index));
    // CalculatePadding sets flags for SAME padding, which unpooling does not
    // consume (it always passes /*flags=*/0 below), so reject the node
    // instead of silently delegating it with the wrong semantics.
    if (flags != 0) {
      TF_LITE_MAYBE_KERNEL_LOG(
          logging_context, "invalid padding mode (%d) in node #%d",
          static_cast<int>(pool_params->padding), node_index);
      return kTfLiteError;
    }

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_unpooling_2d(
          subgraph,
          /*padding_top=*/0,
          /*padding_right=*/0,
          /*padding_bottom=*/0,
          /*padding_left=*/0, static_cast<uint32_t>(pool_params->filter_height),
          static_cast<uint32_t>(pool_params->filter_width),
          /*input_value_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*input_index_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
          /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate CUSTOM(MaxUnpooling2D) node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

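  // Lowers a TFLite MINIMUM node to an XNNPACK minimum2 operator; the
  // validation mirrors VisitMaximumNode.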
  static TfLiteStatus VisitMinimumNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input1_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input1_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input2_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input2_tensor, node->inputs->data[1], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_minimum2(
          subgraph, /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate MINIMUM node #%d", node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

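  // Lowers a TFLite MUL node to an XNNPACK multiply2 operator. For quantized
  // tensors the combined requantization scale (judging by the helper's name,
  // the product of the input scales relative to the output scale) must lie
  // within [2^-16, 256]; nodes outside that range are rejected.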
  static TfLiteStatus VisitMulNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors, const TfLiteMulParams* mul_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor,
                                       node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input1_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor,
                                       node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input2_tensor, node->inputs->data[1], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
                                       node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    const float scale_min = 1.0f / 65536.0f;
    const float scale_max = 256.0f;
    TF_LITE_ENSURE_STATUS(CheckTensorsInputProductOutputScale(
        logging_context, input1_tensor, input2_tensor, output_tensor, scale_min,
        scale_max, node_index, "MUL"));

    float output_min = -std::numeric_limits<float>::infinity();
    float output_max = +std::numeric_limits<float>::infinity();
    if (mul_params != nullptr) {
      TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
          logging_context, node_index, mul_params->activation, &output_min,
          &output_max));
    }

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_multiply2(
          subgraph, output_min, output_max,
          /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate MUL node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

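  // Lowers a TFLite NEG node to an XNNPACK negate operator.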
  static TfLiteStatus VisitNegNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_negate(
          subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate NEG node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

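  // Lowers a TFLite PAD node to an XNNPACK static_constant_pad operator. The
  // paddings tensor must be static with shape [num_input_dims, 2], laid out
  // as (pre-padding, post-padding) pairs per dimension; negative paddings
  // (i.e. cropping) are not supported.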
  static TfLiteStatus VisitPadNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
                                       node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 1,
                                           XNN_MAX_TENSOR_DIMS,
                                           node->inputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& paddings_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, paddings_tensor,
                                          kTfLiteInt32, node->inputs->data[1],
                                          node_index));
    TF_LITE_ENSURE_STATUS(CheckPaddingsTensorShape(
        logging_context, paddings_tensor, NumDimensions(&input_tensor),
        node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, paddings_tensor, node->inputs->data[1], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
                                       node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 1,
                                           XNN_MAX_TENSOR_DIMS,
                                           node->outputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    const int32_t* paddings_data =
        reinterpret_cast<const int32_t*>(paddings_tensor.data.data);
    // Validate the (pre, post) padding pair for every padded dimension, i.e.
    // every row of the [num_input_dims, 2] paddings tensor.
    for (int i = 0; i < SizeOfDimension(&paddings_tensor, 0); i++) {
      const int32_t pre_padding = paddings_data[i * 2 + 0];
      if (pre_padding < 0) {
        TF_LITE_MAYBE_KERNEL_LOG(
            logging_context,
            "invalid pre-padding %d for dimension #%d in node %d", pre_padding,
            i, node_index);
        return kTfLiteError;
      }

      const int32_t post_padding = paddings_data[i * 2 + 1];
      if (post_padding < 0) {
        TF_LITE_MAYBE_KERNEL_LOG(
            logging_context,
            "invalid post-padding %d for dimension #%d in node %d",
            post_padding, i, node_index);
        return kTfLiteError;
      }
    }

    if (subgraph != nullptr) {
      std::array<size_t, XNN_MAX_TENSOR_DIMS> pre_paddings{};
      std::array<size_t, XNN_MAX_TENSOR_DIMS> post_paddings{};
      for (int i = 0; i < SizeOfDimension(&paddings_tensor, 0); i++) {
        pre_paddings[i] = static_cast<size_t>(paddings_data[i * 2 + 0]);
        post_paddings[i] = static_cast<size_t>(paddings_data[i * 2 + 1]);
      }

      const xnn_status status = xnn_define_static_constant_pad(
          subgraph, pre_paddings.data(), post_paddings.data(),
          /*padding_value=*/0.0f,
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate PAD node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

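  // Lowers a TFLite PRELU node to an XNNPACK prelu operator. The slope tensor
  // must be static or quasi-static.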
  static TfLiteStatus VisitPreluNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 1,
                                           XNN_MAX_TENSOR_DIMS,
                                           node->inputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& slope_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, slope_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckSlopeTensorShape(
        logging_context, slope_tensor, node->inputs->data[1], node_index));
    if (quasi_static_tensors.count(node->inputs->data[1]) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, slope_tensor, node->inputs->data[1], node_index));
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 1,
                                           XNN_MAX_TENSOR_DIMS,
                                           node->outputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_prelu(
          subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*slope_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate PRELU node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

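  // Lowers a TFLite QUANTIZE node to an XNNPACK convert operator. Supported
  // combinations are float32 -> quantized, and same-type quantized ->
  // quantized requantization where the input-to-output scale ratio lies
  // within [1/256, 128].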
  static TfLiteStatus VisitQuantizeNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
                                       node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorQInt8OrQUInt8Type(delegate, logging_context, output_tensor,
                                     node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    const xnn_datatype input_datatype = GetXNNPackDatatype(input_tensor);
    const xnn_datatype output_datatype = GetXNNPackDatatype(output_tensor);
    bool supported_combination = false;
    switch (input_datatype) {
      case xnn_datatype_fp32:
        supported_combination = true;
        break;
      case xnn_datatype_qint8:
      case xnn_datatype_quint8:
        if (input_datatype == output_datatype) {
          const float input_scale =
              GetTensorScaleOrDefault(input_tensor, std::nanf(""));
          const float output_scale =
              GetTensorScaleOrDefault(output_tensor, std::nanf(""));
          const float input_output_scale = input_scale / output_scale;
          if (input_output_scale < 1.0f / 256.0f ||
              input_output_scale > 128.0f) {
            TF_LITE_MAYBE_KERNEL_LOG(
                logging_context,
                "unsupported input-to-output scale in QUANTIZE node #%d",
                node_index);
            return kTfLiteError;
          }
          supported_combination = true;
        }
        break;
      default:
        break;
    }
    if (!supported_combination) {
      TF_LITE_MAYBE_KERNEL_LOG(logging_context,
                               "unsupported combination of input type (%s) and "
                               "output type (%s) in QUANTIZE node #%d",
                               TfLiteTypeGetName(input_tensor.type),
                               TfLiteTypeGetName(output_tensor.type),
                               node_index);
      return kTfLiteError;
    }

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_convert(
          subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate QUANTIZE node #%d", node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

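  // Lowers TFLite RELU-family nodes to an XNNPACK clamp operator; the caller
  // supplies the clamping range (presumably [0, +inf) for RELU and [0, 6]
  // for RELU6).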
  static TfLiteStatus VisitReluNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors, float output_min, float output_max,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_clamp(
          subgraph, output_min, output_max,
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate RELU node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

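  // Lowers a TFLite RESHAPE node to an XNNPACK static_reshape operator. The
  // new shape is taken from the already-resolved output tensor dimensions;
  // the optional second (shape) input is only validated and must be static.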
  static TfLiteStatus VisitReshapeNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors, const TfLiteReshapeParams* reshape_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    switch (node->inputs->size) {
      case 1:
      case 2:
        break;
      default:
        TF_LITE_MAYBE_KERNEL_LOG(
            logging_context,
            "unexpected number of inputs (%d) in node #%d: "
            "either one or two inputs expected",
            node->inputs->size, node_index);
        return kTfLiteError;
    }
    if (node->outputs->size != 1) {
      TF_LITE_MAYBE_KERNEL_LOG(
          logging_context,
          "unexpected number of outputs (%d) in node #%d: one output expected",
          node->outputs->size, node_index);
      return kTfLiteError;
    }

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 0,
                                           XNN_MAX_TENSOR_DIMS,
                                           node->inputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    if (node->inputs->size == 2) {
      const TfLiteTensor& shape_tensor = tensors[node->inputs->data[1]];
      TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, shape_tensor,
                                            kTfLiteInt32, node->inputs->data[1],
                                            node_index));
      TF_LITE_ENSURE_STATUS(CheckShapeTensorShape(
          logging_context, shape_tensor, node->inputs->data[1], node_index));
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, shape_tensor, node->inputs->data[1], node_index));
    }

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 0,
                                           XNN_MAX_TENSOR_DIMS,
                                           node->outputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      std::array<size_t, XNN_MAX_TENSOR_DIMS> new_shape;
      std::copy(&output_tensor.dims->data[0],
                &output_tensor.dims->data[NumDimensions(&output_tensor)],
                new_shape.begin());
      const xnn_status status = xnn_define_static_reshape(
          subgraph, static_cast<size_t>(NumDimensions(&output_tensor)),
          new_shape.data(),
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate RESHAPE node #%d", node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

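  // Lowers a TFLite RESIZE_BILINEAR node to an XNNPACK
  // static_resize_bilinear_2d operator. The target size must come from a
  // static 1-D int32 tensor with two positive elements (new height, new
  // width); align_corners and half_pixel_centers are mapped onto the
  // corresponding XNNPACK flags.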
  static TfLiteStatus VisitResizeBilinearNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const TfLiteResizeBilinearParams* resize_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
                                       node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
                                           node->inputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& shape_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorType(logging_context, shape_tensor,
                                          kTfLiteInt32, node->inputs->data[1],
                                          node_index));
    TF_LITE_ENSURE_STATUS(CheckShapeTensorShape(
        logging_context, shape_tensor, node->inputs->data[1], node_index));
    if (SizeOfDimension(&shape_tensor, 0) != 2) {
      TF_LITE_MAYBE_KERNEL_LOG(
          logging_context,
          "unexpected number of dimensions %d in the output shape in node %d",
          SizeOfDimension(&shape_tensor, 0), node_index);
      return kTfLiteError;
    }
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, shape_tensor, node->inputs->data[1], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
                                       node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
                                           node->outputs->data[0]));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    const int32_t* shape_data =
        reinterpret_cast<const int32_t*>(shape_tensor.data.data);
    // Validate both target dimensions (the shape tensor is 1-D with exactly
    // two elements, as enforced above).
    for (int i = 0; i < SizeOfDimension(&shape_tensor, 0); i++) {
      const int32_t dim = shape_data[i];
      if (dim <= 0) {
        TF_LITE_MAYBE_KERNEL_LOG(
            logging_context, "invalid output dimension #%d value %d in node %d",
            i, dim, node_index);
        return kTfLiteError;
      }
    }

    if (subgraph != nullptr) {
      uint32_t flags = 0;
      if (resize_params->align_corners) {
        flags |= XNN_FLAG_ALIGN_CORNERS;
      } else if (!resize_params->half_pixel_centers) {
        flags |= XNN_FLAG_TENSORFLOW_LEGACY_MODE;
      }
      const xnn_status status = xnn_define_static_resize_bilinear_2d(
          subgraph, static_cast<size_t>(shape_data[0]),
          static_cast<size_t>(shape_data[1]),
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], flags);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate RESIZE_BILINEAR node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

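  // Lowers a TFLite ROUND node to an XNNPACK bankers_rounding operator
  // (round-half-to-even).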
  static TfLiteStatus VisitRoundNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_bankers_rounding(
          subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate ROUND node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

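  // Lowers a TFLite SOFTMAX node to an XNNPACK softmax operator. XNNPACK only
  // implements the default scaling, so any beta != 1.0 is rejected.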
  static TfLiteStatus VisitSoftmaxNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors, const TfLiteSoftmaxParams* params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    if (params->beta != 1.0f) {
      if (logging_context != nullptr) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "unsupported beta value %.7f in SOFTMAX node #%d",
                           params->beta, node_index);
      }
      return kTfLiteError;
    }

    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_softmax(
          subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate SOFTMAX node #%d", node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

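  // Lowers a TFLite SQUARE node to an XNNPACK square operator.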
  static TfLiteStatus VisitSquareNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_square(
          subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate SQUARE node #%d", node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

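  // Lowers a TFLite TRANSPOSE node to an XNNPACK static_transpose operator;
  // the permutation must come from a static tensor.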
  static TfLiteStatus VisitTransposeNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& perm_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
        logging_context, perm_tensor, node->inputs->data[1], node_index));

    const int* perm_data = GetTensorData<int32_t>(&perm_tensor);

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    const int dims_count = NumDimensions(&input_tensor);
    // Widen the int32 permutation from the static TFLite tensor into the
    // fixed-size size_t array expected by XNNPACK.
    std::array<size_t, XNN_MAX_TENSOR_DIMS> perm;
    std::copy(&perm_data[0], &perm_data[dims_count], perm.begin());
    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_static_transpose(
          subgraph, dims_count, perm.data(),
          /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]],
          /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate TRANSPOSE node #%d", node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

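  // Illustrative sketch (not part of the upstream file): the TRANSPOSE
  // visitor above copies the static int32 perm tensor into the size_t array
  // XNNPACK expects. For an NHWC-to-NCHW transpose the values involved are:
  //
  //   const int32_t perm_data[4] = {0, 3, 1, 2};     // contents of perm_tensor
  //   std::array<size_t, XNN_MAX_TENSOR_DIMS> perm;  // XNN_MAX_TENSOR_DIMS >= 4
  //   std::copy(&perm_data[0], &perm_data[4], perm.begin());
  //   // perm now holds {0, 3, 1, 2}, passed to xnn_define_static_transpose.
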
  static TfLiteStatus VisitSqrtNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 1, 1, node_index));

    const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_square_root(
          subgraph, /*input_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate SQRT node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

  static TfLiteStatus VisitSquaredDifferenceNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input1_tensor, node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input1_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, input2_tensor, node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input2_tensor, node->inputs->data[1], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(CheckTensorFloat32Type(
        logging_context, output_tensor, node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_squared_difference(
          subgraph, /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate SQUARED_DIFFERENCE node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

  static TfLiteStatus VisitSubNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors, const TfLiteSubParams* sub_params,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node, 2, 1, node_index));

    const TfLiteTensor& input1_tensor = tensors[node->inputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input1_tensor,
                                       node->inputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input1_tensor, node->inputs->data[0], node_index));

    const TfLiteTensor& input2_tensor = tensors[node->inputs->data[1]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input2_tensor,
                                       node->inputs->data[1], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input2_tensor, node->inputs->data[1], node_index));

    const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
                                       node->outputs->data[0], node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, node->outputs->data[0], node_index));

    // Allowed range for the quantized input/output scale relation enforced by
    // CheckTensorsInputOutputScale: [2**-10, 2**8].
    const float scale_min = 1.0f / 1024.0f;
    const float scale_max = 256.0f;
    TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale(
        logging_context, input1_tensor, output_tensor, scale_min, scale_max,
        node_index, "SUB"));
    TF_LITE_ENSURE_STATUS(CheckTensorsInputOutputScale(
        logging_context, input2_tensor, output_tensor, scale_min, scale_max,
        node_index, "SUB"));

    float output_min = -std::numeric_limits<float>::infinity();
    float output_max = +std::numeric_limits<float>::infinity();
    if (sub_params != nullptr) {
      TF_LITE_ENSURE_STATUS(ConvertActivationToOutputRange(
          logging_context, node_index, sub_params->activation, &output_min,
          &output_max));
    }

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_subtract(
          subgraph, output_min, output_max,
          /*input1_id=*/xnnpack_tensors[node->inputs->data[0]],
          /*input2_id=*/xnnpack_tensors[node->inputs->data[1]],
          /*output_id=*/xnnpack_tensors[node->outputs->data[0]], /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context, "failed to delegate SUB node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

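  // Illustrative sketch (assumed behavior of ConvertActivationToOutputRange,
  // consistent with standard TFLite fused-activation semantics and with how
  // output_min/output_max are consumed above): each fused activation becomes
  // a clamping range on the defined XNNPACK node.
  //
  //   kTfLiteActNone       -> [-inf, +inf]
  //   kTfLiteActRelu       -> [0, +inf]
  //   kTfLiteActReluN1To1  -> [-1, +1]
  //   kTfLiteActRelu6      -> [0, 6]
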
  static TfLiteStatus VisitTransposeConvNode(
      xnn_subgraph_t subgraph, const Delegate& delegate,
      TfLiteContext* logging_context, int node_index, TfLiteNode* node,
      const TfLiteTensor* tensors,
      const TfLiteTransposeConvParams* deconv_params,
      const std::unordered_set<int>& quasi_static_tensors,
      const std::vector<uint32_t>& xnnpack_tensors) {
    TF_LITE_ENSURE_STATUS(
        CheckNumInputsAndOutputs(logging_context, node,
                                 /*min_num_inputs=*/3, /*max_num_inputs=*/4,
                                 /*expected_num_outputs=*/1, node_index));
    const bool use_bias = node->inputs->size >= 4;

    const int output_shape_tensor_index = node->inputs->data[0];
    const TfLiteTensor& output_shape_tensor =
        tensors[output_shape_tensor_index];
    TF_LITE_ENSURE_STATUS(
        CheckTensorType(logging_context, output_shape_tensor, kTfLiteInt32,
                        output_shape_tensor_index, node_index));
    TF_LITE_ENSURE_STATUS(
        CheckShapeTensorShape(logging_context, output_shape_tensor,
                              output_shape_tensor_index, node_index));
    TF_LITE_ENSURE_STATUS(
        CheckTensorStaticAllocation(logging_context, output_shape_tensor,
                                    output_shape_tensor_index, node_index));
    const int output_shape_dims = SizeOfDimension(&output_shape_tensor, 0);
    if (output_shape_dims != 4) {
      TF_LITE_MAYBE_KERNEL_LOG(
          logging_context,
          "unsupported number of output shape dimensions (%d) in node #%d: "
          "4 dimensions expected",
          output_shape_dims, node_index);
      return kTfLiteError;
    }

    const int filter_tensor_index = node->inputs->data[1];
    const TfLiteTensor& filter_tensor = tensors[filter_tensor_index];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, filter_tensor,
                                       filter_tensor_index, node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, filter_tensor, 4,
                                           filter_tensor_index));
    if (quasi_static_tensors.count(filter_tensor_index) == 0) {
      TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
          logging_context, filter_tensor, filter_tensor_index, node_index));
    }

    const int input_tensor_index = node->inputs->data[2];
    const TfLiteTensor& input_tensor = tensors[input_tensor_index];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, input_tensor,
                                       input_tensor_index, node_index));
    TF_LITE_ENSURE_STATUS(
        CheckTensorShape(logging_context, input_tensor, 4, input_tensor_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, input_tensor, input_tensor_index, node_index));

    uint32_t xnnpack_tensor_bias = XNN_INVALID_VALUE_ID;  // "No bias".
    if (use_bias) {
      const int bias_tensor_index = node->inputs->data[3];
      if (bias_tensor_index != kTfLiteOptionalTensor) {
        const TfLiteTensor& bias_tensor = tensors[bias_tensor_index];
        TF_LITE_ENSURE_STATUS(CheckTensorFloat32OrQInt32Type(
            delegate, logging_context, bias_tensor, bias_tensor_index,
            node_index));
        TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, bias_tensor, 1,
                                               bias_tensor_index));
        if (quasi_static_tensors.count(bias_tensor_index) == 0) {
          TF_LITE_ENSURE_STATUS(CheckTensorStaticAllocation(
              logging_context, bias_tensor, bias_tensor_index, node_index));
        }
        if (subgraph != nullptr) {
          xnnpack_tensor_bias = xnnpack_tensors[bias_tensor_index];
        }
      }
    }

    const int output_tensor_index = node->outputs->data[0];
    const TfLiteTensor& output_tensor = tensors[output_tensor_index];
    TF_LITE_ENSURE_STATUS(
        CheckTensorFloat32OrQUInt8Type(delegate, logging_context, output_tensor,
                                       output_tensor_index, node_index));
    TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
                                           output_tensor_index));
    TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
        logging_context, output_tensor, output_tensor_index, node_index));

    const int* input_tensor_dims = input_tensor.dims->data;
    const int input_height = input_tensor_dims[1];
    const int input_width = input_tensor_dims[2];

    const int* filter_tensor_dims = filter_tensor.dims->data;
    const int output_channels = filter_tensor_dims[0];
    const int kernel_height = filter_tensor_dims[1];
    const int kernel_width = filter_tensor_dims[2];
    const int input_channels = filter_tensor_dims[3];

    const int32_t* output_shape = GetTensorData<int32_t>(&output_shape_tensor);
    const int output_height = output_shape[1];
    const int output_width = output_shape[2];
    const int output_tensor_channels = output_shape[3];
    if (output_channels != output_tensor_channels) {
      TF_LITE_MAYBE_KERNEL_LOG(
          logging_context,
          "transpose convolution kernel output channel dimension (%d) "
          "doesn't match output shape channel dimension (%d) in node #%d",
          output_channels, output_tensor_channels, node_index);
      return kTfLiteError;
    }
    if (input_channels != input_tensor_dims[3]) {
      TF_LITE_MAYBE_KERNEL_LOG(
          logging_context,
          "transpose convolution filter input channel dimension (%d) "
          "doesn't match input tensor channel dimension (%d) in node #%d",
          input_channels, input_tensor_dims[3], node_index);
      return kTfLiteError;
    }

    int padding_top = 0;
    int padding_bottom = 0;
    int padding_left = 0;
    int padding_right = 0;
    int adjustment_height = 0;
    int adjustment_width = 0;
    TF_LITE_ENSURE_STATUS(CalculateTransposeConvPaddings(
        logging_context, deconv_params->padding, input_height, input_width,
        kernel_height, kernel_width, /*dilation_height=*/1,
        /*dilation_width=*/1, deconv_params->stride_height,
        deconv_params->stride_width, node_index, output_height, output_width,
        &padding_top, &padding_bottom, &padding_left, &padding_right,
        &adjustment_height, &adjustment_width));

    if (subgraph != nullptr) {
      const xnn_status status = xnn_define_deconvolution_2d(
          subgraph,
          /*padding_top=*/padding_top,
          /*padding_right=*/padding_right,
          /*padding_bottom=*/padding_bottom,
          /*padding_left=*/padding_left,
          /*adjustment_height=*/adjustment_height,
          /*adjustment_width=*/adjustment_width,
          static_cast<uint32_t>(kernel_height),
          static_cast<uint32_t>(kernel_width),
          static_cast<uint32_t>(deconv_params->stride_height),
          static_cast<uint32_t>(deconv_params->stride_width),
          /*dilation_height=*/1,
          /*dilation_width=*/1,
          /*groups=*/1,
          /*group_input_channels=*/input_channels,
          /*group_output_channels=*/output_channels,
          /*output_min=*/-std::numeric_limits<float>::infinity(),
          /*output_max=*/+std::numeric_limits<float>::infinity(),
          /*input_id=*/xnnpack_tensors[input_tensor_index],
          /*filter_id=*/xnnpack_tensors[filter_tensor_index],
          /*bias_id=*/xnnpack_tensor_bias,
          /*output_id=*/xnnpack_tensors[output_tensor_index],
          /*flags=*/0);
      if (status != xnn_status_success) {
        TF_LITE_KERNEL_LOG(logging_context,
                           "failed to delegate TRANSPOSE_CONV node #%d",
                           node_index);
        return kTfLiteError;
      }
    }

    return kTfLiteOk;
  }

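  // Illustrative note (assumed relationship, matching the parameters passed
  // to xnn_define_deconvolution_2d above): for a transposed convolution the
  // spatial output size satisfies
  //
  //   output = (input - 1) * stride + kernel - padding_total + adjustment
  //
  // so CalculateTransposeConvPaddings solves for the padding_* and
  // adjustment_* values that reproduce the requested output_height and
  // output_width. E.g. input=4, stride=2, kernel=3, padding_total=0 (VALID)
  // gives output = 3 * 2 + 3 = 9.
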
 private:
  Subgraph(const Delegate& delegate, xnn_runtime_t runtime,
           const std::unordered_set<int>& externals)
      : runtime_(runtime, &xnn_delete_runtime) {
    for (int t : externals) {
      externals_[t] = nullptr;
    }
  }

  // XNNPACK Runtime (subgraph + workspace) with smart-pointer for lifetime
  // management.
  std::unique_ptr<xnn_runtime, decltype(&xnn_delete_runtime)> runtime_{
      nullptr, &xnn_delete_runtime};
  // Mapping from TFLite Tensor IDs (same as XNNPACK Value IDs) for
  // input/output tensors in the delegated subgraph to their data locations.
  std::unordered_map<int, void*> externals_;
  // Memory location to use for 0-size external tensors, as TFLite initializes
  // their data pointers to nullptr, while XNNPACK requires valid data
  // pointers.
  char dummy_data_{0};
};

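// Illustrative sketch (an assumption about how externals_ is consumed; this
// mirrors the public XNNPACK runtime API rather than code shown here): at
// Invoke time the externals_ map would be flattened into xnn_external_value
// records and handed to the runtime.
//
//   std::vector<xnn_external_value> external_values;
//   for (const auto& io : externals_) {
//     external_values.push_back(
//         xnn_external_value{static_cast<uint32_t>(io.first), io.second});
//   }
//   // xnn_setup_runtime(runtime_.get(), external_values.size(),
//   //                   external_values.data());
//   // xnn_invoke_runtime(runtime_.get());
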
TfLiteIntArray* Delegate::PrepareOpsToDelegate(TfLiteContext* context) {
  // Clear previous data, in case the delegate is reused without re-creation.
  static_unpacked_data_map_.clear();
  static_unpacked_data_.clear();
  static_unpack_nodes_.clear();
  static_sparse_weights_.clear();

  TfLiteIntArray* execution_plan = nullptr;
  if (context->GetExecutionPlan(context, &execution_plan) != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context, "Unable to get graph execution plan.");
    return nullptr;
  }

  // Mapping from a quasi-static (unpacked from static) tensor index to the
  // index of the node that produced it.
  std::unordered_map<int, int> quasi_static_tensors_producers;
  // Set of all quasi-static tensors in the execution plan.
  std::unordered_set<int> quasi_static_tensors;
  // Set of quasi-static tensors consumed by the delegated nodes.
  std::unordered_set<int> quasi_static_tensors_to_unpack;

  TfLiteIntArray* nodes_to_delegate =
      TfLiteIntArrayCreate(execution_plan->size);
  nodes_to_delegate->size = 0;
  for (int i = 0; i < execution_plan->size; ++i) {
    const int node_index = execution_plan->data[i];

    // Check if this TFLite node can be delegated to XNNPACK.
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    if (context->GetNodeAndRegistration(context, node_index, &node,
                                        &registration) != kTfLiteOk) {
      TF_LITE_KERNEL_LOG(context,
                         "Unable to get node and registration for node %d.",
                         node_index);
      continue;  // Soft error (skip this node).
    }

    // Prepare to unpack FP16/INT8 tensors.
    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
        node->inputs->size == 1 && node->outputs->size == 1) {
      const TfLiteTensor& input_tensor =
          context->tensors[node->inputs->data[0]];
      const TfLiteTensor& output_tensor =
          context->tensors[node->outputs->data[0]];

      bool is_supported_int8_tensor = input_tensor.type == kTfLiteInt8;
      if (is_supported_int8_tensor) {
        const auto* quant_params = static_cast<const TfLiteAffineQuantization*>(
            input_tensor.quantization.params);
        if (quant_params == nullptr) {
          is_supported_int8_tensor = false;
        }
      }
      if (input_tensor.sparsity == nullptr &&
          (input_tensor.allocation_type == kTfLiteMmapRo ||
           quasi_static_tensors.count(node->inputs->data[0]) != 0) &&
          (input_tensor.type == kTfLiteFloat16 || is_supported_int8_tensor) &&
          output_tensor.type == kTfLiteFloat32) {
        static_unpack_nodes_.insert(node_index);
        quasi_static_tensors_producers[node->outputs->data[0]] = node_index;
        quasi_static_tensors.insert(node->outputs->data[0]);

        if (input_tensor.allocation_type != kTfLiteMmapRo) {
          quasi_static_tensors_to_unpack.insert(node->inputs->data[0]);
        }

        // If the dequantized input is sparse, so is its output.
        if (static_sparse_weights_.count(node->inputs->data[0]) != 0) {
          static_sparse_weights_.insert(node->outputs->data[0]);
        }

        // Skip this node for now. If the output of the node is consumed only
        // by delegated nodes, it will be added to nodes_to_delegate at the
        // end.
        continue;
      }
    }

    // Prepare to unpack sparse tensors.
    // TODO(b/157729695): In the future, we also need to handle the case where
    // a sparse tensor is fed to a TFLite op directly, and no Densify() op is
    // inserted. For now this is not a problem because the Conv() op in tflite
    // can only consume dense tensors.
    if (registration->builtin_code == kTfLiteBuiltinDensify &&
        node->inputs->size == 1 && node->outputs->size == 1) {
      const TfLiteTensor& input_tensor =
          context->tensors[node->inputs->data[0]];
      const TfLiteTensor& output_tensor =
          context->tensors[node->outputs->data[0]];

      if (input_tensor.allocation_type == kTfLiteMmapRo &&
          input_tensor.sparsity != nullptr &&
          (input_tensor.type == kTfLiteFloat16 ||
           input_tensor.type == kTfLiteInt8 ||
           input_tensor.type == kTfLiteFloat32) &&
          output_tensor.type == input_tensor.type) {
        static_unpack_nodes_.insert(node_index);
        quasi_static_tensors_producers[node->outputs->data[0]] = node_index;
        quasi_static_tensors.insert(node->outputs->data[0]);
        static_sparse_weights_.insert(node->outputs->data[0]);

        // Skip this node for now. If the output of the node is consumed only
        // by delegated nodes, it will be added to nodes_to_delegate at the
        // end.
        continue;
      }
    }

    if (Subgraph::VisitNode(/*subgraph=*/nullptr, /*delegate=*/*this, context,
                            registration, node, node_index,
                            quasi_static_tensors,
                            std::vector<uint32_t>()) != kTfLiteOk) {
      // If a non-delegated node consumes the output of a node that unpacks
      // static data, the unpacking node shouldn't be delegated.
      for (int j = 0; j < node->inputs->size; j++) {
        const auto it =
            quasi_static_tensors_producers.find(node->inputs->data[j]);
        if (it != quasi_static_tensors_producers.end()) {
          static_unpack_nodes_.erase(it->second);
        }
      }

      // A non-delegatable node is not an error.
      continue;
    }

    for (int j = 0; j < node->inputs->size; j++) {
      if (quasi_static_tensors.count(node->inputs->data[j]) != 0) {
        quasi_static_tensors_to_unpack.insert(node->inputs->data[j]);
      }
    }

    nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
  }

  // Sort the quasi-static tensors to be unpacked by the index of the node
  // that produced them. This ensures that in situations where a quasi-static
  // tensor is produced from another quasi-static tensor, the tensors are
  // unpacked in the original execution plan order.
  std::vector<int> sorted_quasi_static_tensors_to_unpack(
      quasi_static_tensors_to_unpack.cbegin(),
      quasi_static_tensors_to_unpack.cend());
  std::sort(sorted_quasi_static_tensors_to_unpack.begin(),
            sorted_quasi_static_tensors_to_unpack.end(),
            [&quasi_static_tensors_producers](int t1, int t2) {
              return quasi_static_tensors_producers[t1] <
                     quasi_static_tensors_producers[t2];
            });

  // Unpack static data for all tensors that need it.
  for (int t : sorted_quasi_static_tensors_to_unpack) {
    const int producer_index = quasi_static_tensors_producers[t];
    // Look up the node that produces this quasi-static tensor.
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    if (context->GetNodeAndRegistration(context, producer_index, &node,
                                        &registration) != kTfLiteOk) {
      TF_LITE_KERNEL_LOG(context,
                         "Unable to get node and registration for node %d.",
                         producer_index);
      TfLiteIntArrayFree(nodes_to_delegate);
      return nullptr;  // Hard error.
    }

    if (node->inputs->size != 1) {
      TF_LITE_KERNEL_LOG(context, "unexpected number of inputs (%d) in node %d",
                         node->inputs->size, producer_index);
      TfLiteIntArrayFree(nodes_to_delegate);
      return nullptr;  // Hard error.
    }

    if (node->outputs->size != 1) {
      TF_LITE_KERNEL_LOG(context,
                         "unexpected number of outputs (%d) in node %d",
                         node->outputs->size, producer_index);
      TfLiteIntArrayFree(nodes_to_delegate);
      return nullptr;  // Hard error.
    }

    const TfLiteTensor& input_tensor = context->tensors[node->inputs->data[0]];

    // Consider the case when the input to the unpacking node is quasi-static.
    const auto static_unpacked_input_it_ =
        static_unpacked_data_map_.find(node->inputs->data[0]);
    if (static_unpacked_input_it_ == static_unpacked_data_map_.end()) {
      if (input_tensor.allocation_type != kTfLiteMmapRo) {
        TF_LITE_KERNEL_LOG(
            context,
            "unexpected allocation type (%d) in tensor %d in node %d (%d)",
            input_tensor.allocation_type, node->inputs->data[0], producer_index,
            registration->builtin_code);
        TfLiteIntArrayFree(nodes_to_delegate);
        return nullptr;  // Hard error.
      }
    }

    const TfLiteTensor& output_tensor = context->tensors[t];
    size_t tensor_elements = output_tensor.bytes;
    switch (output_tensor.type) {
      case kTfLiteFloat32:
        tensor_elements /= sizeof(float);
        break;
      case kTfLiteFloat16:
        tensor_elements /= sizeof(uint16_t);
        break;
      case kTfLiteInt8:
        tensor_elements /= sizeof(int8_t);
        break;
      default: {
        TF_LITE_KERNEL_LOG(context,
                           "unexpected datatype (%s) in tensor %d in node %d",
                           TfLiteTypeGetName(output_tensor.type),
                           node->outputs->data[0], producer_index);
        TfLiteIntArrayFree(nodes_to_delegate);
        return nullptr;  // Hard error.
      }
    }

    // Align the offset of the unpacked tensor to a multiple of
    // XNN_EXTRA_BYTES.
    while (static_unpacked_data_.size() % XNN_EXTRA_BYTES != 0) {
      static_unpacked_data_.push_back(0);
    }
    const size_t tensor_offset = static_unpacked_data_.size();
    static_unpacked_data_.resize(tensor_offset + context->tensors[t].bytes);

    char* unpacked_data = static_unpacked_data_.data() + tensor_offset;
    const char* packed_data =
        static_unpacked_input_it_ != static_unpacked_data_map_.end()
            ? static_unpacked_data_.data() + static_unpacked_input_it_->second
            : static_cast<const char*>(input_tensor.data.data);
    switch (registration->builtin_code) {
      case kTfLiteBuiltinDequantize: {
        // Such a condition has been checked when preparing to unpack FP16/INT8
        // tensors.
        TFLITE_DCHECK(input_tensor.sparsity == nullptr);
        // Actual data unpacking.
        switch (input_tensor.type) {
          case kTfLiteFloat16:
            DequantizeFloat16(reinterpret_cast<const uint16_t*>(packed_data),
                              reinterpret_cast<float*>(unpacked_data),
                              tensor_elements);
            break;
          case kTfLiteInt8: {
            TfLiteAffineQuantization* quant_params =
                static_cast<TfLiteAffineQuantization*>(
                    input_tensor.quantization.params);
            // Such conditions have been checked when preparing to unpack INT8
            // tensors.
            TFLITE_DCHECK(quant_params != nullptr);

            if (quant_params->scale->size == 1) {
              // Per-tensor quantization
              DequantizeInt8(reinterpret_cast<const int8_t*>(packed_data),
                             reinterpret_cast<float*>(unpacked_data),
                             GetTensorShape(&input_tensor),
                             input_tensor.params.zero_point,
                             input_tensor.params.scale);
            } else {
              // Per-channel quantization
              PerChannelDequantizeInt8(
                  reinterpret_cast<const int8_t*>(packed_data),
                  reinterpret_cast<float*>(unpacked_data),
                  GetTensorShape(&input_tensor), quant_params->zero_point->data,
                  quant_params->scale->data, quant_params->quantized_dimension);
            }
            break;
          }
          default:
            // This should not happen, as we only allow FP16/INT8 input_tensor
            // when preparing the unpacking.
            TFLITE_DCHECK(false);
        }
        break;
      }
      case kTfLiteBuiltinDensify: {
        // Such a condition has been checked when preparing to unpack sparse
        // tensors.
        TFLITE_DCHECK(input_tensor.sparsity != nullptr);
        const int dims_count = NumDimensions(&output_tensor);
        std::vector<int> vector_shape(dims_count);
        for (int i = 0; i < dims_count; i++) {
          vector_shape[i] = SizeOfDimension(&output_tensor, i);
        }

        switch (input_tensor.type) {
          case kTfLiteFloat32: {
            const size_t dense_size = context->tensors[t].bytes / sizeof(float);
            float* unpacked_fp32_data = reinterpret_cast<float*>(unpacked_data);
            tflite::internal::sparsity::FormatConverter<float> converter(
                vector_shape, *input_tensor.sparsity);
            converter.SparseToDense(
                static_cast<const float*>(input_tensor.data.data), dense_size,
                unpacked_fp32_data, context);
            break;
          }
          case kTfLiteFloat16: {
            const size_t dense_size =
                context->tensors[t].bytes / sizeof(Eigen::half);
            Eigen::half* unpacked_fp16_data =
                reinterpret_cast<Eigen::half*>(unpacked_data);
            tflite::internal::sparsity::FormatConverter<Eigen::half> converter(
                vector_shape, *input_tensor.sparsity);
            converter.SparseToDense(
                static_cast<const Eigen::half*>(input_tensor.data.data),
                dense_size, unpacked_fp16_data, context);
            break;
          }
          case kTfLiteInt8: {
            const size_t dense_size =
                context->tensors[t].bytes / sizeof(int8_t);
            int8_t* unpacked_int8_data =
                reinterpret_cast<int8_t*>(unpacked_data);
            tflite::internal::sparsity::FormatConverter<int8_t> converter(
                vector_shape, *input_tensor.sparsity);
            converter.SparseToDense(
                static_cast<const int8_t*>(input_tensor.data.data), dense_size,
                unpacked_int8_data, context);
            break;
          }
          default: {
            // This should not happen, as we only allow FP32/FP16/INT8
            // input_tensor when preparing the unpacking.
            TFLITE_DCHECK(false);
          }
        }
        break;
      }
      default:
        TF_LITE_KERNEL_LOG(context, "unexpected op registration %d at node %d",
                           registration->builtin_code, producer_index);
        TfLiteIntArrayFree(nodes_to_delegate);
        return nullptr;  // Hard error.
    }

    static_unpacked_data_map_[t] = tensor_offset;
  }

  // Add the nodes that unpack static data consumed by delegated nodes.
  // Note: this is done purely to avoid the overhead of running these nodes
  // again in the TFLite interpreter, which would allocate memory for their
  // outputs. We mark them as delegated, but the delegate simply ignores these
  // nodes as the static weights are already unpacked.
  for (int node_index : static_unpack_nodes_) {
    nodes_to_delegate->data[nodes_to_delegate->size++] = node_index;
  }
  std::sort(&nodes_to_delegate->data[0],
            &nodes_to_delegate->data[nodes_to_delegate->size]);

#ifdef XNNPACK_DELEGATE_TEST_MODE
  // In the test mode build (used by unit tests), the XNNPACK delegate claims
  // to support all operators in the execution plan to disable fallback to the
  // default TensorFlow Lite kernels. Thus, if any of the ops in the model are
  // not supported by the delegate, they will cause a failure in
  // ::tflite::Interpreter::ModifyGraphWithDelegate, to be caught in the unit
  // tests.
  nodes_to_delegate->size = execution_plan->size;
  std::copy(&execution_plan->data[0],
            &execution_plan->data[execution_plan->size],
            &nodes_to_delegate->data[0]);
#endif

  return nodes_to_delegate;
}

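// Illustrative note (a sketch, not upstream code): PrepareOpsToDelegate calls
// Subgraph::VisitNode with subgraph == nullptr, so each Visit*Node method
// above runs only its validation half; the xnn_define_* half is guarded by
// `if (subgraph != nullptr)` and executes later, when the same visitors are
// replayed against a real xnn_subgraph_t. A new visitor would follow the same
// two-phase shape:
//
//   static TfLiteStatus VisitExampleNode(xnn_subgraph_t subgraph, ...) {
//     // Phase 1: validate tensor types/allocations (runs in both passes).
//     if (subgraph != nullptr) {
//       // Phase 2: define the corresponding XNNPACK node.
//     }
//     return kTfLiteOk;
//   }
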
void* SubgraphInit(TfLiteContext* context, const char* buffer, size_t length) {
  const TfLiteDelegateParams* params =
      reinterpret_cast<const TfLiteDelegateParams*>(buffer);

  return static_cast<void*>(Subgraph::Create(
      context, params,
      *static_cast<::tflite::xnnpack::Delegate*>(params->delegate->data_)));
}

TfLiteStatus SubgraphPrepare(TfLiteContext* context, TfLiteNode* node) {
  if (node->user_data == nullptr) {
    return kTfLiteError;
  }

  return static_cast<Subgraph*>(node->user_data)->Prepare(context);
}

TfLiteStatus SubgraphInvoke(TfLiteContext* context, TfLiteNode* node) {
  if (node->user_data == nullptr) {
    return kTfLiteError;
  }

  return static_cast<Subgraph*>(node->user_data)->Invoke(context);
}

void SubgraphFree(TfLiteContext* context, void* buffer) {
  if (buffer != nullptr) {
    delete static_cast<Subgraph*>(buffer);
  }
}

const TfLiteRegistration kSubgraphRegistration = {
    /*.init=*/SubgraphInit,
    /*.free=*/SubgraphFree,
    /*.prepare=*/SubgraphPrepare,
    /*.invoke=*/SubgraphInvoke,
    /*.profiling_string=*/nullptr,
    /*.builtin_code=*/0,
    /*.custom_name=*/"TfLiteXNNPackDelegate",
    /*.version=*/2,
};

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  TfLiteIntArray* ops_to_replace =
      static_cast<::tflite::xnnpack::Delegate*>(delegate->data_)
          ->PrepareOpsToDelegate(context);
  if (ops_to_replace == nullptr) {
    return kTfLiteError;
  }

  const TfLiteStatus status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, kSubgraphRegistration, ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace
}  // namespace xnnpack
}  // namespace tflite

TfLiteXNNPackDelegateWeightsCache* TfLiteXNNPackDelegateWeightsCacheCreate() {
  xnn_status status = xnn_initialize(/*allocator=*/nullptr);
  if (status != xnn_status_success) {
    return nullptr;
  }

  xnn_weights_cache_t weights_cache = nullptr;
  if (xnn_create_weights_cache(&weights_cache) != xnn_status_success) {
    return nullptr;
  }
  return reinterpret_cast<TfLiteXNNPackDelegateWeightsCache*>(weights_cache);
}

TfLiteXNNPackDelegateWeightsCache*
TfLiteXNNPackDelegateWeightsCacheCreateWithSize(size_t size) {
  xnn_status status = xnn_initialize(/*allocator=*/nullptr);
  if (status != xnn_status_success) {
    return nullptr;
  }

  xnn_weights_cache_t weights_cache = nullptr;
  if (xnn_create_weights_cache_with_size(size, &weights_cache) !=
      xnn_status_success) {
    return nullptr;
  }
  return reinterpret_cast<TfLiteXNNPackDelegateWeightsCache*>(weights_cache);
}

bool TfLiteXNNPackDelegateWeightsCacheFinalizeSoft(
    TfLiteXNNPackDelegateWeightsCache* cache) {
  auto weights_cache = reinterpret_cast<xnn_weights_cache_t>(cache);
  xnn_status status = xnn_finalize_weights_cache(
      weights_cache, xnn_weights_cache_finalization_kind_soft);
  return status == xnn_status_success;
}

bool TfLiteXNNPackDelegateWeightsCacheFinalizeHard(
    TfLiteXNNPackDelegateWeightsCache* cache) {
  auto weights_cache = reinterpret_cast<xnn_weights_cache_t>(cache);
  xnn_status status = xnn_finalize_weights_cache(
      weights_cache, xnn_weights_cache_finalization_kind_hard);
  return status == xnn_status_success;
}

void TfLiteXNNPackDelegateWeightsCacheDelete(
    TfLiteXNNPackDelegateWeightsCache* cache) {
  if (cache == nullptr) {
    return;
  }
  auto weights_cache = reinterpret_cast<xnn_weights_cache_t>(cache);
  xnn_delete_weights_cache(weights_cache);
  xnn_deinitialize();
}

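// Illustrative usage sketch (an assumption about a typical client of the
// weights-cache C API above; the weights_cache field of
// TfLiteXNNPackDelegateOptions is assumed from the public header, and error
// handling is elided):
//
//   TfLiteXNNPackDelegateWeightsCache* cache =
//       TfLiteXNNPackDelegateWeightsCacheCreate();
//   TfLiteXNNPackDelegateOptions opts = TfLiteXNNPackDelegateOptionsDefault();
//   opts.weights_cache = cache;
//   // ... create delegates sharing `cache` and build all interpreters ...
//   TfLiteXNNPackDelegateWeightsCacheFinalizeHard(cache);  // no more inserts
//   // ... run inference ...
//   TfLiteXNNPackDelegateWeightsCacheDelete(cache);  // after delegates are freed
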
TfLiteXNNPackDelegateOptions TfLiteXNNPackDelegateOptionsDefault() {
  TfLiteXNNPackDelegateOptions options = {0};

  // Quantized inference is enabled by default on the Web platform.
#ifdef XNNPACK_DELEGATE_ENABLE_QS8
  options.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_QS8;
#endif
#ifdef XNNPACK_DELEGATE_ENABLE_QU8
  options.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_QU8;
#endif

  // Enable quantized inference for the delegate build used in unit tests.
#ifdef XNNPACK_DELEGATE_TEST_MODE
  options.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_QS8;
  options.flags |= TFLITE_XNNPACK_DELEGATE_FLAG_QU8;
#endif  // XNNPACK_DELEGATE_TEST_MODE

  return options;
}

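// Illustrative usage sketch (standard TFLite delegate wiring, not code from
// this file; the num_threads field is assumed from the public header):
//
//   TfLiteXNNPackDelegateOptions opts = TfLiteXNNPackDelegateOptionsDefault();
//   opts.num_threads = 4;
//   TfLiteDelegate* delegate = TfLiteXNNPackDelegateCreate(&opts);
//   if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
//     // fall back to the default kernels
//   }
//   // ... interpreter->Invoke() ...
//   TfLiteXNNPackDelegateDelete(delegate);  // after the interpreter is gone
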
TfLiteDelegate* TfLiteXNNPackDelegateCreate(
    const TfLiteXNNPackDelegateOptions* options) {
  xnn_status status = xnn_initialize(/*allocator=*/nullptr);
  if (status != xnn_status_success) {
    return nullptr;
  }

  xnn_workspace_t workspace = nullptr;
  if (xnn_create_workspace(&workspace) != xnn_status_success) {
    return nullptr;
  }

  auto* xnnpack_delegate = new ::tflite::xnnpack::Delegate(options, workspace);
  return xnnpack_delegate ? xnnpack_delegate->tflite_delegate() : nullptr;
}

void* TfLiteXNNPackDelegateGetThreadPool(TfLiteDelegate* delegate) {
  if (delegate == nullptr) {
    return nullptr;
  }

  return static_cast<void*>(
      static_cast<::tflite::xnnpack::Delegate*>(delegate->data_)->threadpool());
}

void TfLiteXNNPackDelegateDelete(TfLiteDelegate* delegate) {
  if (delegate != nullptr) {
    delete static_cast<::tflite::xnnpack::Delegate*>(delegate->data_);
  }
}