/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/coreml/coreml_delegate.h"

#import <Foundation/Foundation.h>

#include <stdlib.h>
#include <string.h>
#include <sys/utsname.h>
#include <limits>
#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/coreml/builders/op_validator.h"
#include "tensorflow/lite/delegates/coreml/builders/util.h"
#include "tensorflow/lite/delegates/coreml/coreml_delegate_kernel.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
namespace tflite {
namespace {
constexpr int kMinNodesPerCoreMlDelegate = 2;

using delegates::coreml::CoreMlDelegateKernel;

bool IsNodeSupportedByDelegate(const TfLiteRegistration* registration, const TfLiteNode* node,
                               TfLiteContext* context, const TfLiteCoreMlDelegateOptions* options) {
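  // Core ML is only available on iOS 11.0 and later; reject all nodes on older OS versions.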
  if (@available(iOS 11.0, *)) {
  } else {
    return false;
  }

  // For most ops, only version 1 is supported.
  if (registration->version > 1) {
    switch (registration->builtin_code) {
      case kTfLiteBuiltinDepthwiseConv2d:
        if (registration->version > 2) return false;
        break;
      // FullyConnected without bias is supported starting from version 6.
      case kTfLiteBuiltinFullyConnected:
        if (registration->version > 6) return false;
        break;
      default:
        return false;
    }
  }

  // The model should not be full-integer quantized. For ops supported by the Core ML delegate,
  // checking that the first input is float is sufficient to filter out full-integer quantized ops.
  int input_tensor_index = 0;
  // TransposeConv inputs are ordered as (output_shape, filters, input).
  if (registration->builtin_code == kTfLiteBuiltinTransposeConv) {
    input_tensor_index = 2;
  }
  if (GetInput(context, node, input_tensor_index)->type != kTfLiteFloat32) {
    return false;
  }

  // TODO(b/149179044): Add extra validation if this is not sufficient.

  // TODO(karimnosseir): Refactor this function.
  // TODO(karimnosseir): Add
  // 1) Checks for versioning.
  // 2) Checks for input constraints.
  // Follow the ordering of the TfLiteBuiltinOperator enum.
  switch (registration->builtin_code) {
    case kTfLiteBuiltinAdd: {
      return node->builtin_data != nullptr &&
             delegates::coreml::IsBinaryOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinAveragePool2d: {
      const auto* params = reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
      return params != nullptr && params->activation == kTfLiteActNone;
    }
    case kTfLiteBuiltinConcatenation: {
      return delegates::coreml::IsConcatenationOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinConv2d: {
      return delegates::coreml::IsConvolutionOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinDepthwiseConv2d: {
      return delegates::coreml::IsDepthwiseConvolutionOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinFullyConnected: {
      return delegates::coreml::IsFullyConnectedOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinHardSwish: {
      return true;
    }
    case kTfLiteBuiltinLogistic: {
      return true;
    }
    case kTfLiteBuiltinMaxPool2d: {
      const auto* params = reinterpret_cast<const TfLitePoolParams*>(node->builtin_data);
      return params != nullptr && params->activation == kTfLiteActNone;
    }
    case kTfLiteBuiltinMirrorPad: {
      return delegates::coreml::IsMirrorPadOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinMean: {
      return delegates::coreml::IsMeanOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinMul: {
      return node->builtin_data != nullptr &&
             delegates::coreml::IsBinaryOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinPad:
    case kTfLiteBuiltinPadv2: {
      return delegates::coreml::IsPadOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinRelu: {
      return true;
    }
    case kTfLiteBuiltinReluN1To1: {
      return true;
    }
    case kTfLiteBuiltinRelu6: {
      return true;
    }
    case kTfLiteBuiltinReshape: {
      return delegates::coreml::IsReshapeOpSupported(registration, node, context,
                                                     options->coreml_version);
    }
    case kTfLiteBuiltinResizeBilinear: {
      return delegates::coreml::IsResizeBilinearOpSupported(registration, node, context);
    }
    case kTfLiteBuiltinSoftmax: {
      // Only supported when beta is 1.0 for now.
      const auto* softmax_params = reinterpret_cast<const TfLiteSoftmaxParams*>(node->builtin_data);
      return softmax_params != nullptr && softmax_params->beta == 1.0;
    }
    case kTfLiteBuiltinTanh: {
      return true;
    }
    case kTfLiteBuiltinTransposeConv: {
      return delegates::coreml::IsTransposeConvolutionOpSupported(registration, node, context);
    }
    default:
      return false;
  }
  return false;
}

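// A TfLiteDelegate that owns the Core ML delegate options. The constructor normalizes the
// options: coreml_version is clamped to what the running OS supports, and non-positive
// partitioning limits are replaced with defaults.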
class CoreMlDelegate : public TfLiteDelegate {
 public:
  explicit CoreMlDelegate(const TfLiteCoreMlDelegateOptions* params)
      : params_(params != nullptr ? *params : TfLiteCoreMlDelegateOptions()) {
    {
      if (@available(iOS 13.0, *)) {
        if (params_.coreml_version != 2 && params_.coreml_version != 3) {
          NSLog(@"coreml_version must be 2 or 3. Setting to 3.");
          params_.coreml_version = 3;
        }
      } else if (@available(iOS 12.0, *)) {
        if (params_.coreml_version != 2) {
          NSLog(@"coreml_version must be 2 - using Core ML version 2.");
          params_.coreml_version = 2;
        }
      }
      if (params_.max_delegated_partitions <= 0) {
        params_.max_delegated_partitions = std::numeric_limits<int>::max();
      }
      if (params_.min_nodes_per_partition <= 0) {
        params_.min_nodes_per_partition = kMinNodesPerCoreMlDelegate;
      }
    }
  }

  TfLiteCoreMlDelegateOptions* params() { return &params_; }

  bool VerifyDelegate() { return true; }

 private:
  TfLiteCoreMlDelegateOptions params_;
};

TfLiteRegistration GetCoreMlKernelRegistration() {
  // This is the registration for the delegate node that gets added to the
  // TFLite graph in place of the subgraph it replaces. It is treated as an
  // op node, but in our case:
  // Init initializes the delegate kernel.
  // Invoke runs the delegate graph.
  // Prepare prepares the delegate kernel.
  // Free performs any cleanup needed by the delegate kernel.
  TfLiteRegistration kernel_registration{};
  kernel_registration.profiling_string = nullptr;
  kernel_registration.builtin_code = kTfLiteBuiltinDelegate;
  kernel_registration.custom_name = "TfLiteCoreMlDelegate";
  kernel_registration.free = [](TfLiteContext* context, void* buffer) -> void {
    delete reinterpret_cast<CoreMlDelegateKernel*>(buffer);
  };
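  // `buffer` carries the TfLiteDelegateParams for the replaced subgraph; the returned kernel
  // becomes node->user_data (nullptr signals an initialization failure).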
  kernel_registration.init = [](TfLiteContext* context, const char* buffer,
                                size_t length) -> void* {
    const auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
    const auto* coreml_options = (reinterpret_cast<CoreMlDelegate*>(params->delegate))->params();
    CoreMlDelegateKernel* coreml_kernel = new CoreMlDelegateKernel(coreml_options->coreml_version);
    if (coreml_kernel->Init(context, params) != kTfLiteOk) {
      delete coreml_kernel;
      return nullptr;
    }
    return coreml_kernel;
  };
  kernel_registration.invoke = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
    CoreMlDelegateKernel* kernel = reinterpret_cast<CoreMlDelegateKernel*>(node->user_data);
    if (!kernel) {
      TF_LITE_KERNEL_LOG(context, "CoreMl Kernel was not initialized");
      return kTfLiteError;
    }
    return kernel->Invoke(context, node);
  };
  kernel_registration.prepare = [](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
    CoreMlDelegateKernel* kernel = reinterpret_cast<CoreMlDelegateKernel*>(node->user_data);
    if (kernel == nullptr) {
      TF_LITE_KERNEL_LOG(context, "CoreMl Kernel was not initialized");
      return kTfLiteError;
    }
    return kernel->Prepare(context, node);
  };

  return kernel_registration;
}

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  const auto* params = reinterpret_cast<TfLiteCoreMlDelegateOptions*>(delegate->data_);

  delegates::IsNodeSupportedFn node_supported_fn = [=](TfLiteContext* context, TfLiteNode* node,
                                                       TfLiteRegistration* registration,
                                                       std::string* unsupported_details) -> bool {
    return IsNodeSupportedByDelegate(registration, node, context, params);
  };

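  // Partition the graph into sets of adjacent supported nodes, then delegate only the largest
  // partitions, subject to the configured partition count and minimum-size limits.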
  delegates::FP16GraphPartitionHelper partition_helper(context, node_supported_fn);
  TF_LITE_ENSURE_STATUS(partition_helper.Partition(nullptr));

  std::vector<int> delegated_nodes = partition_helper.GetNodesOfFirstNLargestPartitions(
      params->max_delegated_partitions, params->min_nodes_per_partition);
  TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
                  "CoreML delegate: %d nodes delegated out of %d nodes, "
                  "with %d partitions.\n",
                  static_cast<int>(delegated_nodes.size()), partition_helper.num_total_nodes(),
                  partition_helper.num_partitions());
  return context->ReplaceNodeSubsetsWithDelegateKernels(
      context, GetCoreMlKernelRegistration(), BuildTfLiteIntArray(delegated_nodes).get(), delegate);
}

TfLiteDelegate* CreateCoreMlDelegate(const TfLiteCoreMlDelegateOptions* options) {
  TfLiteDelegate* delegate = new CoreMlDelegate(options);
  if (!static_cast<CoreMlDelegate*>(delegate)->VerifyDelegate()) {
    delete delegate;
    return nullptr;
  }

  delegate->data_ = static_cast<tflite::CoreMlDelegate*>(delegate)->params();
  delegate->flags = kTfLiteDelegateFlagsNone;
  delegate->Prepare = &DelegatePrepare;
  delegate->CopyFromBufferHandle = nullptr;
  delegate->CopyToBufferHandle = nullptr;
  delegate->FreeBufferHandle = nullptr;

  return delegate;
}
}  // namespace
}  // namespace tflite

namespace {
// utsname.machine holds the device identifier; for example, the identifier for the iPhone XS is
// "iPhone11,2". Since the Neural Engine is only available on A12 and later chips, the major
// device version in the identifier is checked against these models:
// A12: iPhone XS (11,2), iPad Mini - 5th Gen (11,1)
// A12X: iPad Pro - 3rd Gen (8,1)
// For more information, see https://www.theiphonewiki.com/wiki/Models
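// For example, "iPhone11,2" (iPhone XS, A12) parses to major version 11 below and is treated as
// having a Neural Engine, while "iPhone10,3" (iPhone X, A11) parses to 10 and is not.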
bool IsNeuralEngineAvailable() {
  struct utsname system_info;
  uname(&system_info);

  if (strncmp("iPad", system_info.machine, 4) == 0) {
    const int major_version = atoi(system_info.machine + 4);
    return major_version >= 8;  // There are no devices between iPad8 and iPad11.
  } else if (strncmp("iPhone", system_info.machine, 6) == 0) {
    const int major_version = atoi(system_info.machine + 6);
    return major_version >= 11;
  }
  return false;
}

}  // namespace

TfLiteDelegate* TfLiteCoreMlDelegateCreate(const TfLiteCoreMlDelegateOptions* options) {
  if (@available(iOS 12.0, *)) {
    if (options->enabled_devices == TfLiteCoreMlDelegateDevicesWithNeuralEngine &&
        !IsNeuralEngineAvailable()) {
      NSLog(@"This device does not have Neural Engine, so Core ML delegate will not be enabled. "
             "If you want to run Core ML delegate anyway, set enabled_devices option to "
             "TfLiteCoreMlDelegateAllDevices (or enabledDevices to .allDevices in Swift).");
      return nullptr;
    }
    return tflite::CreateCoreMlDelegate(options);
  } else {
    NSLog(@"Core ML delegate is not supported in this iOS version. "
           "Minimum required iOS version is 12.0.");
    return nullptr;
  }
}

void TfLiteCoreMlDelegateDelete(TfLiteDelegate* delegate) { delete delegate; }
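
// A minimal usage sketch (illustrative only): it assumes a tflite::Interpreter named
// `interpreter` has already been built from a model. Option and API names below are the ones
// declared in coreml_delegate.h and the TFLite Interpreter API.
//
//   TfLiteCoreMlDelegateOptions options = {};
//   options.enabled_devices = TfLiteCoreMlDelegateAllDevices;
//   TfLiteDelegate* delegate = TfLiteCoreMlDelegateCreate(&options);
//   if (delegate != nullptr &&
//       interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
//     // Delegation failed; handle the error or continue with CPU execution.
//   }
//   // ... run inference with interpreter->Invoke() ...
//   // Destroy the delegate only after the interpreter that uses it has been destroyed.
//   TfLiteCoreMlDelegateDelete(delegate);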
318