xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/c/c_api.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/c/c_api.h"
16 
17 #include <memory>
18 #include <mutex>  // NOLINT
19 #include <utility>
20 
21 #include "tensorflow/lite/builtin_ops.h"
22 #include "tensorflow/lite/c/c_api_internal.h"
23 #include "tensorflow/lite/c/common_internal.h"
24 #include "tensorflow/lite/create_op_resolver.h"
25 #include "tensorflow/lite/delegates/interpreter_utils.h"
26 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
27 #include "tensorflow/lite/error_reporter.h"
28 #include "tensorflow/lite/interpreter.h"
29 #include "tensorflow/lite/kernels/internal/compatibility.h"
30 #include "tensorflow/lite/model.h"
31 #include "tensorflow/lite/version.h"
32 
33 namespace {
34 class CallbackErrorReporter : public tflite::ErrorReporter {
35  public:
CallbackErrorReporter(TfLiteErrorReporterCallback callback)36   explicit CallbackErrorReporter(TfLiteErrorReporterCallback callback)
37       : callback_(callback) {}
38 
Report(const char * format,va_list args)39   int Report(const char* format, va_list args) override {
40     callback_.error_reporter(callback_.user_data, format, args);
41     return 0;
42   }
43 
44  private:
45   TfLiteErrorReporterCallback callback_;
46 };
47 
48 }  // namespace
49 
50 extern "C" {
51 
52 // LINT.IfChange
53 
TfLiteVersion()54 const char* TfLiteVersion() { return TFLITE_VERSION_STRING; }
55 
TfLiteModelCreate(const void * model_data,size_t model_size)56 TfLiteModel* TfLiteModelCreate(const void* model_data, size_t model_size) {
57   auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
58       static_cast<const char*>(model_data), model_size);
59   std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
60   return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr;
61 }
62 
TfLiteModelCreateFromFile(const char * model_path)63 TfLiteModel* TfLiteModelCreateFromFile(const char* model_path) {
64   auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(model_path);
65   std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
66   return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr;
67 }
68 
TfLiteModelDelete(TfLiteModel * model)69 void TfLiteModelDelete(TfLiteModel* model) { delete model; }
70 
TfLiteRegistrationExternalCreate(const char * custom_name,const int version)71 TfLiteRegistrationExternal* TfLiteRegistrationExternalCreate(
72     const char* custom_name, const int version) {
73   return new TfLiteRegistrationExternal{custom_name, version};
74 }
75 
TfLiteRegistrationExternalDelete(TfLiteRegistrationExternal * reg)76 void TfLiteRegistrationExternalDelete(TfLiteRegistrationExternal* reg) {
77   delete reg;
78 }
79 
TfLiteRegistrationExternalSetInit(TfLiteRegistrationExternal * registration,void * (* init)(TfLiteOpaqueContext * context,const char * buffer,size_t length))80 void TfLiteRegistrationExternalSetInit(
81     TfLiteRegistrationExternal* registration,
82     void* (*init)(TfLiteOpaqueContext* context, const char* buffer,
83                   size_t length)) {
84   registration->init = init;
85 }
86 
TfLiteRegistrationExternalSetFree(TfLiteRegistrationExternal * registration,void (* free)(TfLiteOpaqueContext * context,void * data))87 void TfLiteRegistrationExternalSetFree(
88     TfLiteRegistrationExternal* registration,
89     void (*free)(TfLiteOpaqueContext* context, void* data)) {
90   registration->free = free;
91 }
92 
TfLiteRegistrationExternalSetPrepare(TfLiteRegistrationExternal * registration,TfLiteStatus (* prepare)(TfLiteOpaqueContext * context,TfLiteOpaqueNode * node))93 void TfLiteRegistrationExternalSetPrepare(
94     TfLiteRegistrationExternal* registration,
95     TfLiteStatus (*prepare)(TfLiteOpaqueContext* context,
96                             TfLiteOpaqueNode* node)) {
97   registration->prepare = prepare;
98 }
99 
TfLiteRegistrationExternalSetInvoke(TfLiteRegistrationExternal * registration,TfLiteStatus (* invoke)(TfLiteOpaqueContext * context,TfLiteOpaqueNode * node))100 void TfLiteRegistrationExternalSetInvoke(
101     TfLiteRegistrationExternal* registration,
102     TfLiteStatus (*invoke)(TfLiteOpaqueContext* context,
103                            TfLiteOpaqueNode* node)) {
104   registration->invoke = invoke;
105 }
106 
TfLiteInterpreterOptionsCreate()107 TfLiteInterpreterOptions* TfLiteInterpreterOptionsCreate() {
108   return new TfLiteInterpreterOptions{};
109 }
110 
TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions * options)111 void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions* options) {
112   delete options;
113 }
114 
TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions * options,int32_t num_threads)115 void TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions* options,
116                                            int32_t num_threads) {
117   options->num_threads = num_threads;
118 }
119 
TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions * options,TfLiteDelegate * delegate)120 void TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions* options,
121                                          TfLiteDelegate* delegate) {
122   options->delegates.push_back(delegate);
123 }
124 
TfLiteInterpreterOptionsSetErrorReporter(TfLiteInterpreterOptions * options,void (* reporter)(void * user_data,const char * format,va_list args),void * user_data)125 void TfLiteInterpreterOptionsSetErrorReporter(
126     TfLiteInterpreterOptions* options,
127     void (*reporter)(void* user_data, const char* format, va_list args),
128     void* user_data) {
129   options->error_reporter_callback.error_reporter = reporter;
130   options->error_reporter_callback.user_data = user_data;
131 }
132 
TfLiteInterpreterOptionsAddRegistrationExternal(TfLiteInterpreterOptions * options,TfLiteRegistrationExternal * registration)133 void TfLiteInterpreterOptionsAddRegistrationExternal(
134     TfLiteInterpreterOptions* options,
135     TfLiteRegistrationExternal* registration) {
136   options->op_registrations.push_back(registration);
137 }
138 
InitTfLiteRegistration(TfLiteRegistration * registration,TfLiteRegistrationExternal * registration_external)139 static void InitTfLiteRegistration(
140     TfLiteRegistration* registration,
141     TfLiteRegistrationExternal* registration_external) {
142   registration->custom_name = registration_external->custom_name;
143   registration->version = registration_external->version;
144   registration->registration_external = registration_external;
145 }
146 
TfLiteInterpreterCreate(const TfLiteModel * model,const TfLiteInterpreterOptions * optional_options)147 TfLiteInterpreter* TfLiteInterpreterCreate(
148     const TfLiteModel* model,
149     const TfLiteInterpreterOptions* optional_options) {
150   std::unique_ptr<tflite::MutableOpResolver> resolver =
151       tflite::CreateOpResolver();
152   return tflite::internal::InterpreterCreateWithOpResolver(
153       model, optional_options, resolver.get());
154 }
155 
TfLiteInterpreterDelete(TfLiteInterpreter * interpreter)156 void TfLiteInterpreterDelete(TfLiteInterpreter* interpreter) {
157   delete interpreter;
158 }
159 
TfLiteInterpreterGetInputTensorCount(const TfLiteInterpreter * interpreter)160 int32_t TfLiteInterpreterGetInputTensorCount(
161     const TfLiteInterpreter* interpreter) {
162   return static_cast<int32_t>(interpreter->impl->inputs().size());
163 }
164 
TfLiteInterpreterGetInputTensor(const TfLiteInterpreter * interpreter,int32_t input_index)165 TfLiteTensor* TfLiteInterpreterGetInputTensor(
166     const TfLiteInterpreter* interpreter, int32_t input_index) {
167   return interpreter->impl->tensor(interpreter->impl->inputs()[input_index]);
168 }
169 
TfLiteInterpreterResizeInputTensor(TfLiteInterpreter * interpreter,int32_t input_index,const int * input_dims,int32_t input_dims_size)170 TfLiteStatus TfLiteInterpreterResizeInputTensor(TfLiteInterpreter* interpreter,
171                                                 int32_t input_index,
172                                                 const int* input_dims,
173                                                 int32_t input_dims_size) {
174   std::vector<int> dims{input_dims, input_dims + input_dims_size};
175   return interpreter->impl->ResizeInputTensor(
176       interpreter->impl->inputs()[input_index], dims);
177 }
178 
TfLiteInterpreterAllocateTensors(TfLiteInterpreter * interpreter)179 TfLiteStatus TfLiteInterpreterAllocateTensors(TfLiteInterpreter* interpreter) {
180   return interpreter->impl->AllocateTensors();
181 }
182 
TfLiteInterpreterInvoke(TfLiteInterpreter * interpreter)183 TfLiteStatus TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) {
184   if (interpreter->enable_delegate_fallback) {
185     return tflite::delegates::InterpreterUtils::InvokeWithCPUFallback(
186         interpreter->impl.get());
187   } else {
188     return interpreter->impl->Invoke();
189   }
190 }
191 
TfLiteInterpreterGetOutputTensorCount(const TfLiteInterpreter * interpreter)192 int32_t TfLiteInterpreterGetOutputTensorCount(
193     const TfLiteInterpreter* interpreter) {
194   return static_cast<int32_t>(interpreter->impl->outputs().size());
195 }
196 
TfLiteInterpreterGetOutputTensor(const TfLiteInterpreter * interpreter,int32_t output_index)197 const TfLiteTensor* TfLiteInterpreterGetOutputTensor(
198     const TfLiteInterpreter* interpreter, int32_t output_index) {
199   return interpreter->impl->tensor(interpreter->impl->outputs()[output_index]);
200 }
201 
TfLiteTensorType(const TfLiteTensor * tensor)202 TfLiteType TfLiteTensorType(const TfLiteTensor* tensor) { return tensor->type; }
203 
TfLiteTensorNumDims(const TfLiteTensor * tensor)204 int32_t TfLiteTensorNumDims(const TfLiteTensor* tensor) {
205   return tensor->dims->size;
206 }
207 
TfLiteTensorDim(const TfLiteTensor * tensor,int32_t dim_index)208 int32_t TfLiteTensorDim(const TfLiteTensor* tensor, int32_t dim_index) {
209   return tensor->dims->data[dim_index];
210 }
211 
TfLiteTensorByteSize(const TfLiteTensor * tensor)212 size_t TfLiteTensorByteSize(const TfLiteTensor* tensor) {
213   return tensor->bytes;
214 }
215 
TfLiteTensorData(const TfLiteTensor * tensor)216 void* TfLiteTensorData(const TfLiteTensor* tensor) { return tensor->data.raw; }
217 
TfLiteTensorName(const TfLiteTensor * tensor)218 const char* TfLiteTensorName(const TfLiteTensor* tensor) {
219   return tensor->name;
220 }
221 
TfLiteTensorQuantizationParams(const TfLiteTensor * tensor)222 TfLiteQuantizationParams TfLiteTensorQuantizationParams(
223     const TfLiteTensor* tensor) {
224   return tensor->params;
225 }
226 
TfLiteTensorCopyFromBuffer(TfLiteTensor * tensor,const void * input_data,size_t input_data_size)227 TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor* tensor,
228                                         const void* input_data,
229                                         size_t input_data_size) {
230   if (tensor->bytes != input_data_size) {
231     return kTfLiteError;
232   }
233   memcpy(tensor->data.raw, input_data, input_data_size);
234   return kTfLiteOk;
235 }
236 
TfLiteTensorCopyToBuffer(const TfLiteTensor * tensor,void * output_data,size_t output_data_size)237 TfLiteStatus TfLiteTensorCopyToBuffer(const TfLiteTensor* tensor,
238                                       void* output_data,
239                                       size_t output_data_size) {
240   if (tensor->bytes != output_data_size) {
241     return kTfLiteError;
242   }
243   memcpy(output_data, tensor->data.raw, output_data_size);
244   return kTfLiteOk;
245 }
246 
247 // LINT.ThenChange(//tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs)
248 
249 }  // extern "C"
250 
251 namespace tflite {
252 namespace internal {
253 
254 // Implementation of CallbackOpResolver class which is defined in
255 // c_api_internal.h. CallbackOpResolver is a (C++) `tflite::OpResolver` that
256 // forwards the methods to (C ABI) callback functions from a
257 // `TfLiteOpResolverCallbacks` struct.
258 
259 // FindOp for builtin op query.
FindOp(tflite::BuiltinOperator op,int version) const260 const TfLiteRegistration* CallbackOpResolver::FindOp(tflite::BuiltinOperator op,
261                                                      int version) const {
262   // Use Registration V2 API to find op.
263   if (op_resolver_callbacks_.find_builtin_op) {
264     return op_resolver_callbacks_.find_builtin_op(
265         op_resolver_callbacks_.user_data,
266         static_cast<TfLiteBuiltinOperator>(op), version);
267   }
268   if (op_resolver_callbacks_.find_builtin_op_v1) {
269     // Check if cached Registration is available.
270     std::lock_guard<std::mutex> lock(mutex_);
271     for (const auto& created_registration : temporary_builtin_registrations_) {
272       if (created_registration->builtin_code == op &&
273           created_registration->version == version) {
274         return created_registration.get();
275       }
276     }
277     // Get a Registration V1 object and create a Registration V2 object.
278     const TfLiteRegistration_V1* reg_v1 =
279         op_resolver_callbacks_.find_builtin_op_v1(
280             op_resolver_callbacks_.user_data,
281             static_cast<TfLiteBuiltinOperator>(op), version);
282     if (reg_v1) {
283       TfLiteRegistration* new_registration = new TfLiteRegistration();
284       memcpy(new_registration, reg_v1, sizeof(TfLiteRegistration_V1));
285       new_registration->registration_external = nullptr;
286       temporary_builtin_registrations_.push_back(
287           std::unique_ptr<TfLiteRegistration>(new_registration));
288       return new_registration;
289     }
290   }
291   return nullptr;
292 }
293 
294 // FindOp for custom op query.
FindOp(const char * op,int version) const295 const TfLiteRegistration* CallbackOpResolver::FindOp(const char* op,
296                                                      int version) const {
297   // Use Registration V2 API to find op.
298   if (op_resolver_callbacks_.find_custom_op) {
299     return op_resolver_callbacks_.find_custom_op(
300         op_resolver_callbacks_.user_data, op, version);
301   }
302   if (op_resolver_callbacks_.find_custom_op_v1) {
303     // Check if cached Registration is available.
304     std::lock_guard<std::mutex> lock(mutex_);
305     for (const auto& created_registration : temporary_custom_registrations_) {
306       if (strcmp(created_registration->custom_name, op) == 0 &&
307           created_registration->version == version) {
308         return created_registration.get();
309       }
310     }
311     // Get a Registration V1 object and create a Registration V2 object.
312     const TfLiteRegistration_V1* reg_v1 =
313         op_resolver_callbacks_.find_custom_op_v1(
314             op_resolver_callbacks_.user_data, op, version);
315     if (reg_v1) {
316       TfLiteRegistration* new_registration = new TfLiteRegistration();
317       memcpy(new_registration, reg_v1, sizeof(TfLiteRegistration_V1));
318       new_registration->registration_external = nullptr;
319       temporary_custom_registrations_.push_back(
320           std::unique_ptr<TfLiteRegistration>(new_registration));
321       return new_registration;
322     }
323   }
324   return nullptr;
325 }
326 
InterpreterCreateWithOpResolver(const TfLiteModel * model,const TfLiteInterpreterOptions * optional_options,tflite::MutableOpResolver * mutable_resolver)327 TfLiteInterpreter* InterpreterCreateWithOpResolver(
328     const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options,
329     tflite::MutableOpResolver* mutable_resolver) {
330   TFLITE_DCHECK_NE(mutable_resolver, nullptr);
331   if (!model || !model->impl) {
332     return nullptr;
333   }
334 
335   std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
336   if (optional_options &&
337       optional_options->error_reporter_callback.error_reporter != nullptr) {
338     optional_error_reporter = std::make_unique<CallbackErrorReporter>(
339         optional_options->error_reporter_callback);
340   }
341 
342   // By default, we use the provided mutable_op_resolver, adding any builtin or
343   // custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or
344   // `TfLiteInterpreterOptionsAddCustomOp`.
345   tflite::OpResolver* op_resolver = mutable_resolver;
346   if (optional_options) {
347     mutable_resolver->AddAll(optional_options->mutable_op_resolver);
348     for (auto* registration_external : optional_options->op_registrations) {
349       TfLiteRegistration registration{};
350       InitTfLiteRegistration(&registration, registration_external);
351       mutable_resolver->AddCustom(registration_external->custom_name,
352                                   &registration,
353                                   registration_external->version);
354     }
355   }
356   // However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with
357   // a non-null callback parameter, then we instead use a
358   // `CallbackOpResolver` that will forward to the callbacks provided there.
359   CallbackOpResolver callback_op_resolver;
360   if (optional_options &&
361       (optional_options->op_resolver_callbacks.find_builtin_op != nullptr ||
362        optional_options->op_resolver_callbacks.find_custom_op != nullptr ||
363        optional_options->op_resolver_callbacks.find_builtin_op_v1 != nullptr ||
364        optional_options->op_resolver_callbacks.find_custom_op_v1 != nullptr)) {
365     callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks);
366     op_resolver = &callback_op_resolver;
367   }
368 
369   tflite::ErrorReporter* error_reporter = optional_error_reporter
370                                               ? optional_error_reporter.get()
371                                               : tflite::DefaultErrorReporter();
372   tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver,
373                                      error_reporter);
374 
375   std::unique_ptr<tflite::Interpreter> interpreter;
376   if (builder(&interpreter) != kTfLiteOk) {
377     return nullptr;
378   }
379 
380   if (optional_options) {
381     if (optional_options->num_threads !=
382         TfLiteInterpreterOptions::kDefaultNumThreads) {
383       interpreter->SetNumThreads(optional_options->num_threads);
384     }
385 
386     if (optional_options->use_nnapi) {
387       if (interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate()) !=
388           kTfLiteOk) {
389         return nullptr;
390       }
391     }
392 
393     for (auto* delegate : optional_options->delegates) {
394       if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
395         return nullptr;
396       }
397     }
398   }
399 
400   bool enable_delegate_fallback =
401       optional_options != nullptr && optional_options->enable_delegate_fallback;
402 
403   return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter),
404                                std::move(interpreter),
405                                enable_delegate_fallback};
406 }
407 
408 }  // namespace internal
409 }  // namespace tflite
410