1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/c/c_api.h"
16
17 #include <memory>
18 #include <mutex> // NOLINT
19 #include <utility>
20
21 #include "tensorflow/lite/builtin_ops.h"
22 #include "tensorflow/lite/c/c_api_internal.h"
23 #include "tensorflow/lite/c/common_internal.h"
24 #include "tensorflow/lite/create_op_resolver.h"
25 #include "tensorflow/lite/delegates/interpreter_utils.h"
26 #include "tensorflow/lite/delegates/nnapi/nnapi_delegate.h"
27 #include "tensorflow/lite/error_reporter.h"
28 #include "tensorflow/lite/interpreter.h"
29 #include "tensorflow/lite/kernels/internal/compatibility.h"
30 #include "tensorflow/lite/model.h"
31 #include "tensorflow/lite/version.h"
32
33 namespace {
34 class CallbackErrorReporter : public tflite::ErrorReporter {
35 public:
CallbackErrorReporter(TfLiteErrorReporterCallback callback)36 explicit CallbackErrorReporter(TfLiteErrorReporterCallback callback)
37 : callback_(callback) {}
38
Report(const char * format,va_list args)39 int Report(const char* format, va_list args) override {
40 callback_.error_reporter(callback_.user_data, format, args);
41 return 0;
42 }
43
44 private:
45 TfLiteErrorReporterCallback callback_;
46 };
47
48 } // namespace
49
50 extern "C" {
51
52 // LINT.IfChange
53
TfLiteVersion()54 const char* TfLiteVersion() { return TFLITE_VERSION_STRING; }
55
TfLiteModelCreate(const void * model_data,size_t model_size)56 TfLiteModel* TfLiteModelCreate(const void* model_data, size_t model_size) {
57 auto model = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
58 static_cast<const char*>(model_data), model_size);
59 std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
60 return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr;
61 }
62
TfLiteModelCreateFromFile(const char * model_path)63 TfLiteModel* TfLiteModelCreateFromFile(const char* model_path) {
64 auto model = tflite::FlatBufferModel::VerifyAndBuildFromFile(model_path);
65 std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
66 return shared_model ? new TfLiteModel{std::move(shared_model)} : nullptr;
67 }
68
TfLiteModelDelete(TfLiteModel * model)69 void TfLiteModelDelete(TfLiteModel* model) { delete model; }
70
TfLiteRegistrationExternalCreate(const char * custom_name,const int version)71 TfLiteRegistrationExternal* TfLiteRegistrationExternalCreate(
72 const char* custom_name, const int version) {
73 return new TfLiteRegistrationExternal{custom_name, version};
74 }
75
TfLiteRegistrationExternalDelete(TfLiteRegistrationExternal * reg)76 void TfLiteRegistrationExternalDelete(TfLiteRegistrationExternal* reg) {
77 delete reg;
78 }
79
TfLiteRegistrationExternalSetInit(TfLiteRegistrationExternal * registration,void * (* init)(TfLiteOpaqueContext * context,const char * buffer,size_t length))80 void TfLiteRegistrationExternalSetInit(
81 TfLiteRegistrationExternal* registration,
82 void* (*init)(TfLiteOpaqueContext* context, const char* buffer,
83 size_t length)) {
84 registration->init = init;
85 }
86
TfLiteRegistrationExternalSetFree(TfLiteRegistrationExternal * registration,void (* free)(TfLiteOpaqueContext * context,void * data))87 void TfLiteRegistrationExternalSetFree(
88 TfLiteRegistrationExternal* registration,
89 void (*free)(TfLiteOpaqueContext* context, void* data)) {
90 registration->free = free;
91 }
92
TfLiteRegistrationExternalSetPrepare(TfLiteRegistrationExternal * registration,TfLiteStatus (* prepare)(TfLiteOpaqueContext * context,TfLiteOpaqueNode * node))93 void TfLiteRegistrationExternalSetPrepare(
94 TfLiteRegistrationExternal* registration,
95 TfLiteStatus (*prepare)(TfLiteOpaqueContext* context,
96 TfLiteOpaqueNode* node)) {
97 registration->prepare = prepare;
98 }
99
TfLiteRegistrationExternalSetInvoke(TfLiteRegistrationExternal * registration,TfLiteStatus (* invoke)(TfLiteOpaqueContext * context,TfLiteOpaqueNode * node))100 void TfLiteRegistrationExternalSetInvoke(
101 TfLiteRegistrationExternal* registration,
102 TfLiteStatus (*invoke)(TfLiteOpaqueContext* context,
103 TfLiteOpaqueNode* node)) {
104 registration->invoke = invoke;
105 }
106
TfLiteInterpreterOptionsCreate()107 TfLiteInterpreterOptions* TfLiteInterpreterOptionsCreate() {
108 return new TfLiteInterpreterOptions{};
109 }
110
TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions * options)111 void TfLiteInterpreterOptionsDelete(TfLiteInterpreterOptions* options) {
112 delete options;
113 }
114
TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions * options,int32_t num_threads)115 void TfLiteInterpreterOptionsSetNumThreads(TfLiteInterpreterOptions* options,
116 int32_t num_threads) {
117 options->num_threads = num_threads;
118 }
119
TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions * options,TfLiteDelegate * delegate)120 void TfLiteInterpreterOptionsAddDelegate(TfLiteInterpreterOptions* options,
121 TfLiteDelegate* delegate) {
122 options->delegates.push_back(delegate);
123 }
124
TfLiteInterpreterOptionsSetErrorReporter(TfLiteInterpreterOptions * options,void (* reporter)(void * user_data,const char * format,va_list args),void * user_data)125 void TfLiteInterpreterOptionsSetErrorReporter(
126 TfLiteInterpreterOptions* options,
127 void (*reporter)(void* user_data, const char* format, va_list args),
128 void* user_data) {
129 options->error_reporter_callback.error_reporter = reporter;
130 options->error_reporter_callback.user_data = user_data;
131 }
132
TfLiteInterpreterOptionsAddRegistrationExternal(TfLiteInterpreterOptions * options,TfLiteRegistrationExternal * registration)133 void TfLiteInterpreterOptionsAddRegistrationExternal(
134 TfLiteInterpreterOptions* options,
135 TfLiteRegistrationExternal* registration) {
136 options->op_registrations.push_back(registration);
137 }
138
InitTfLiteRegistration(TfLiteRegistration * registration,TfLiteRegistrationExternal * registration_external)139 static void InitTfLiteRegistration(
140 TfLiteRegistration* registration,
141 TfLiteRegistrationExternal* registration_external) {
142 registration->custom_name = registration_external->custom_name;
143 registration->version = registration_external->version;
144 registration->registration_external = registration_external;
145 }
146
TfLiteInterpreterCreate(const TfLiteModel * model,const TfLiteInterpreterOptions * optional_options)147 TfLiteInterpreter* TfLiteInterpreterCreate(
148 const TfLiteModel* model,
149 const TfLiteInterpreterOptions* optional_options) {
150 std::unique_ptr<tflite::MutableOpResolver> resolver =
151 tflite::CreateOpResolver();
152 return tflite::internal::InterpreterCreateWithOpResolver(
153 model, optional_options, resolver.get());
154 }
155
TfLiteInterpreterDelete(TfLiteInterpreter * interpreter)156 void TfLiteInterpreterDelete(TfLiteInterpreter* interpreter) {
157 delete interpreter;
158 }
159
TfLiteInterpreterGetInputTensorCount(const TfLiteInterpreter * interpreter)160 int32_t TfLiteInterpreterGetInputTensorCount(
161 const TfLiteInterpreter* interpreter) {
162 return static_cast<int32_t>(interpreter->impl->inputs().size());
163 }
164
TfLiteInterpreterGetInputTensor(const TfLiteInterpreter * interpreter,int32_t input_index)165 TfLiteTensor* TfLiteInterpreterGetInputTensor(
166 const TfLiteInterpreter* interpreter, int32_t input_index) {
167 return interpreter->impl->tensor(interpreter->impl->inputs()[input_index]);
168 }
169
TfLiteInterpreterResizeInputTensor(TfLiteInterpreter * interpreter,int32_t input_index,const int * input_dims,int32_t input_dims_size)170 TfLiteStatus TfLiteInterpreterResizeInputTensor(TfLiteInterpreter* interpreter,
171 int32_t input_index,
172 const int* input_dims,
173 int32_t input_dims_size) {
174 std::vector<int> dims{input_dims, input_dims + input_dims_size};
175 return interpreter->impl->ResizeInputTensor(
176 interpreter->impl->inputs()[input_index], dims);
177 }
178
TfLiteInterpreterAllocateTensors(TfLiteInterpreter * interpreter)179 TfLiteStatus TfLiteInterpreterAllocateTensors(TfLiteInterpreter* interpreter) {
180 return interpreter->impl->AllocateTensors();
181 }
182
TfLiteInterpreterInvoke(TfLiteInterpreter * interpreter)183 TfLiteStatus TfLiteInterpreterInvoke(TfLiteInterpreter* interpreter) {
184 if (interpreter->enable_delegate_fallback) {
185 return tflite::delegates::InterpreterUtils::InvokeWithCPUFallback(
186 interpreter->impl.get());
187 } else {
188 return interpreter->impl->Invoke();
189 }
190 }
191
TfLiteInterpreterGetOutputTensorCount(const TfLiteInterpreter * interpreter)192 int32_t TfLiteInterpreterGetOutputTensorCount(
193 const TfLiteInterpreter* interpreter) {
194 return static_cast<int32_t>(interpreter->impl->outputs().size());
195 }
196
TfLiteInterpreterGetOutputTensor(const TfLiteInterpreter * interpreter,int32_t output_index)197 const TfLiteTensor* TfLiteInterpreterGetOutputTensor(
198 const TfLiteInterpreter* interpreter, int32_t output_index) {
199 return interpreter->impl->tensor(interpreter->impl->outputs()[output_index]);
200 }
201
TfLiteTensorType(const TfLiteTensor * tensor)202 TfLiteType TfLiteTensorType(const TfLiteTensor* tensor) { return tensor->type; }
203
TfLiteTensorNumDims(const TfLiteTensor * tensor)204 int32_t TfLiteTensorNumDims(const TfLiteTensor* tensor) {
205 return tensor->dims->size;
206 }
207
TfLiteTensorDim(const TfLiteTensor * tensor,int32_t dim_index)208 int32_t TfLiteTensorDim(const TfLiteTensor* tensor, int32_t dim_index) {
209 return tensor->dims->data[dim_index];
210 }
211
TfLiteTensorByteSize(const TfLiteTensor * tensor)212 size_t TfLiteTensorByteSize(const TfLiteTensor* tensor) {
213 return tensor->bytes;
214 }
215
TfLiteTensorData(const TfLiteTensor * tensor)216 void* TfLiteTensorData(const TfLiteTensor* tensor) { return tensor->data.raw; }
217
TfLiteTensorName(const TfLiteTensor * tensor)218 const char* TfLiteTensorName(const TfLiteTensor* tensor) {
219 return tensor->name;
220 }
221
TfLiteTensorQuantizationParams(const TfLiteTensor * tensor)222 TfLiteQuantizationParams TfLiteTensorQuantizationParams(
223 const TfLiteTensor* tensor) {
224 return tensor->params;
225 }
226
TfLiteTensorCopyFromBuffer(TfLiteTensor * tensor,const void * input_data,size_t input_data_size)227 TfLiteStatus TfLiteTensorCopyFromBuffer(TfLiteTensor* tensor,
228 const void* input_data,
229 size_t input_data_size) {
230 if (tensor->bytes != input_data_size) {
231 return kTfLiteError;
232 }
233 memcpy(tensor->data.raw, input_data, input_data_size);
234 return kTfLiteOk;
235 }
236
TfLiteTensorCopyToBuffer(const TfLiteTensor * tensor,void * output_data,size_t output_data_size)237 TfLiteStatus TfLiteTensorCopyToBuffer(const TfLiteTensor* tensor,
238 void* output_data,
239 size_t output_data_size) {
240 if (tensor->bytes != output_data_size) {
241 return kTfLiteError;
242 }
243 memcpy(output_data, tensor->data.raw, output_data_size);
244 return kTfLiteOk;
245 }
246
247 // LINT.ThenChange(//tensorflow/lite/experimental/examples/unity/TensorFlowLitePlugin/Assets/TensorFlowLite/SDK/Scripts/Interpreter.cs)
248
249 } // extern "C"
250
251 namespace tflite {
252 namespace internal {
253
254 // Implementation of CallbackOpResolver class which is defined in
255 // c_api_internal.h. CallbackOpResolver is a (C++) `tflite::OpResolver` that
256 // forwards the methods to (C ABI) callback functions from a
257 // `TfLiteOpResolverCallbacks` struct.
258
259 // FindOp for builtin op query.
FindOp(tflite::BuiltinOperator op,int version) const260 const TfLiteRegistration* CallbackOpResolver::FindOp(tflite::BuiltinOperator op,
261 int version) const {
262 // Use Registration V2 API to find op.
263 if (op_resolver_callbacks_.find_builtin_op) {
264 return op_resolver_callbacks_.find_builtin_op(
265 op_resolver_callbacks_.user_data,
266 static_cast<TfLiteBuiltinOperator>(op), version);
267 }
268 if (op_resolver_callbacks_.find_builtin_op_v1) {
269 // Check if cached Registration is available.
270 std::lock_guard<std::mutex> lock(mutex_);
271 for (const auto& created_registration : temporary_builtin_registrations_) {
272 if (created_registration->builtin_code == op &&
273 created_registration->version == version) {
274 return created_registration.get();
275 }
276 }
277 // Get a Registration V1 object and create a Registration V2 object.
278 const TfLiteRegistration_V1* reg_v1 =
279 op_resolver_callbacks_.find_builtin_op_v1(
280 op_resolver_callbacks_.user_data,
281 static_cast<TfLiteBuiltinOperator>(op), version);
282 if (reg_v1) {
283 TfLiteRegistration* new_registration = new TfLiteRegistration();
284 memcpy(new_registration, reg_v1, sizeof(TfLiteRegistration_V1));
285 new_registration->registration_external = nullptr;
286 temporary_builtin_registrations_.push_back(
287 std::unique_ptr<TfLiteRegistration>(new_registration));
288 return new_registration;
289 }
290 }
291 return nullptr;
292 }
293
294 // FindOp for custom op query.
FindOp(const char * op,int version) const295 const TfLiteRegistration* CallbackOpResolver::FindOp(const char* op,
296 int version) const {
297 // Use Registration V2 API to find op.
298 if (op_resolver_callbacks_.find_custom_op) {
299 return op_resolver_callbacks_.find_custom_op(
300 op_resolver_callbacks_.user_data, op, version);
301 }
302 if (op_resolver_callbacks_.find_custom_op_v1) {
303 // Check if cached Registration is available.
304 std::lock_guard<std::mutex> lock(mutex_);
305 for (const auto& created_registration : temporary_custom_registrations_) {
306 if (strcmp(created_registration->custom_name, op) == 0 &&
307 created_registration->version == version) {
308 return created_registration.get();
309 }
310 }
311 // Get a Registration V1 object and create a Registration V2 object.
312 const TfLiteRegistration_V1* reg_v1 =
313 op_resolver_callbacks_.find_custom_op_v1(
314 op_resolver_callbacks_.user_data, op, version);
315 if (reg_v1) {
316 TfLiteRegistration* new_registration = new TfLiteRegistration();
317 memcpy(new_registration, reg_v1, sizeof(TfLiteRegistration_V1));
318 new_registration->registration_external = nullptr;
319 temporary_custom_registrations_.push_back(
320 std::unique_ptr<TfLiteRegistration>(new_registration));
321 return new_registration;
322 }
323 }
324 return nullptr;
325 }
326
InterpreterCreateWithOpResolver(const TfLiteModel * model,const TfLiteInterpreterOptions * optional_options,tflite::MutableOpResolver * mutable_resolver)327 TfLiteInterpreter* InterpreterCreateWithOpResolver(
328 const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options,
329 tflite::MutableOpResolver* mutable_resolver) {
330 TFLITE_DCHECK_NE(mutable_resolver, nullptr);
331 if (!model || !model->impl) {
332 return nullptr;
333 }
334
335 std::unique_ptr<tflite::ErrorReporter> optional_error_reporter;
336 if (optional_options &&
337 optional_options->error_reporter_callback.error_reporter != nullptr) {
338 optional_error_reporter = std::make_unique<CallbackErrorReporter>(
339 optional_options->error_reporter_callback);
340 }
341
342 // By default, we use the provided mutable_op_resolver, adding any builtin or
343 // custom ops registered with `TfLiteInterpreterOptionsAddBuiltinOp` and/or
344 // `TfLiteInterpreterOptionsAddCustomOp`.
345 tflite::OpResolver* op_resolver = mutable_resolver;
346 if (optional_options) {
347 mutable_resolver->AddAll(optional_options->mutable_op_resolver);
348 for (auto* registration_external : optional_options->op_registrations) {
349 TfLiteRegistration registration{};
350 InitTfLiteRegistration(®istration, registration_external);
351 mutable_resolver->AddCustom(registration_external->custom_name,
352 ®istration,
353 registration_external->version);
354 }
355 }
356 // However, if `TfLiteInterpreterOptionsSetOpResolver` has been called with
357 // a non-null callback parameter, then we instead use a
358 // `CallbackOpResolver` that will forward to the callbacks provided there.
359 CallbackOpResolver callback_op_resolver;
360 if (optional_options &&
361 (optional_options->op_resolver_callbacks.find_builtin_op != nullptr ||
362 optional_options->op_resolver_callbacks.find_custom_op != nullptr ||
363 optional_options->op_resolver_callbacks.find_builtin_op_v1 != nullptr ||
364 optional_options->op_resolver_callbacks.find_custom_op_v1 != nullptr)) {
365 callback_op_resolver.SetCallbacks(optional_options->op_resolver_callbacks);
366 op_resolver = &callback_op_resolver;
367 }
368
369 tflite::ErrorReporter* error_reporter = optional_error_reporter
370 ? optional_error_reporter.get()
371 : tflite::DefaultErrorReporter();
372 tflite::InterpreterBuilder builder(model->impl->GetModel(), *op_resolver,
373 error_reporter);
374
375 std::unique_ptr<tflite::Interpreter> interpreter;
376 if (builder(&interpreter) != kTfLiteOk) {
377 return nullptr;
378 }
379
380 if (optional_options) {
381 if (optional_options->num_threads !=
382 TfLiteInterpreterOptions::kDefaultNumThreads) {
383 interpreter->SetNumThreads(optional_options->num_threads);
384 }
385
386 if (optional_options->use_nnapi) {
387 if (interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate()) !=
388 kTfLiteOk) {
389 return nullptr;
390 }
391 }
392
393 for (auto* delegate : optional_options->delegates) {
394 if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
395 return nullptr;
396 }
397 }
398 }
399
400 bool enable_delegate_fallback =
401 optional_options != nullptr && optional_options->enable_delegate_fallback;
402
403 return new TfLiteInterpreter{model->impl, std::move(optional_error_reporter),
404 std::move(interpreter),
405 enable_delegate_fallback};
406 }
407
408 } // namespace internal
409 } // namespace tflite
410