1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/lite/c/c_api.h"

#include <stdarg.h>
#include <stdint.h>

#include <array>
#include <cmath>
#include <cstring>
#include <fstream>
#include <ios>
#include <string>
#include <vector>

#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/c/c_api_opaque.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/testing/util.h"
33
34 namespace {
35
// The library must always report a non-empty version string.
TEST(CAPI, Version) { EXPECT_STRNE(TfLiteVersion(), ""); }
37
// End-to-end smoke test: load the `add` model from disk, resize and allocate
// tensors, run inference, and check the output. Per the expected values
// below, each output element is 3x the corresponding input element.
TEST(CApiSimple, Smoke) {
  TfLiteModel* model =
      TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
  ASSERT_NE(model, nullptr);

  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
  ASSERT_NE(options, nullptr);
  TfLiteInterpreterOptionsSetNumThreads(options, 2);

  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
  ASSERT_NE(interpreter, nullptr);

  // The options can be released as soon as creation returns.
  TfLiteInterpreterOptionsDelete(options);

  ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
  ASSERT_EQ(TfLiteInterpreterGetInputTensorCount(interpreter), 1);
  ASSERT_EQ(TfLiteInterpreterGetOutputTensorCount(interpreter), 1);

  // Resize the sole input to a 1-D tensor of two floats and re-allocate.
  const std::array<int, 1> new_shape = {2};
  ASSERT_EQ(TfLiteInterpreterResizeInputTensor(
                interpreter, 0, new_shape.data(), new_shape.size()),
            kTfLiteOk);
  ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);

  TfLiteTensor* in_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
  ASSERT_NE(in_tensor, nullptr);
  EXPECT_EQ(TfLiteTensorType(in_tensor), kTfLiteFloat32);
  EXPECT_EQ(TfLiteTensorNumDims(in_tensor), 1);
  EXPECT_EQ(TfLiteTensorDim(in_tensor, 0), 2);
  EXPECT_EQ(TfLiteTensorByteSize(in_tensor), sizeof(float) * 2);
  EXPECT_NE(TfLiteTensorData(in_tensor), nullptr);
  EXPECT_STREQ(TfLiteTensorName(in_tensor), "input");

  // The float model carries no quantization metadata.
  const TfLiteQuantizationParams in_quant =
      TfLiteTensorQuantizationParams(in_tensor);
  EXPECT_EQ(in_quant.scale, 0.f);
  EXPECT_EQ(in_quant.zero_point, 0);

  const std::array<float, 2> in_values = {1.f, 3.f};
  ASSERT_EQ(TfLiteTensorCopyFromBuffer(in_tensor, in_values.data(),
                                       in_values.size() * sizeof(float)),
            kTfLiteOk);

  ASSERT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);

  const TfLiteTensor* out_tensor =
      TfLiteInterpreterGetOutputTensor(interpreter, 0);
  ASSERT_NE(out_tensor, nullptr);
  EXPECT_EQ(TfLiteTensorType(out_tensor), kTfLiteFloat32);
  EXPECT_EQ(TfLiteTensorNumDims(out_tensor), 1);
  EXPECT_EQ(TfLiteTensorDim(out_tensor, 0), 2);
  EXPECT_EQ(TfLiteTensorByteSize(out_tensor), sizeof(float) * 2);
  EXPECT_NE(TfLiteTensorData(out_tensor), nullptr);
  EXPECT_STREQ(TfLiteTensorName(out_tensor), "output");

  const TfLiteQuantizationParams out_quant =
      TfLiteTensorQuantizationParams(out_tensor);
  EXPECT_EQ(out_quant.scale, 0.f);
  EXPECT_EQ(out_quant.zero_point, 0);

  std::array<float, 2> out_values;
  ASSERT_EQ(TfLiteTensorCopyToBuffer(out_tensor, out_values.data(),
                                     out_values.size() * sizeof(float)),
            kTfLiteOk);
  EXPECT_EQ(out_values[0], 3.f);
  EXPECT_EQ(out_values[1], 9.f);

  TfLiteInterpreterDelete(interpreter);
  TfLiteModelDelete(model);
}
109
// Verifies the quantization metadata exposed by the C API on the quantized
// add model, and checks both the raw quantized and the dequantized results.
//
// Fix: float values are compared with EXPECT_FLOAT_EQ (4-ULP tolerance)
// instead of exact EXPECT_EQ, per GoogleTest's floating-point guidance —
// exact equality on computed floats (scale * (q - zero_point)) is fragile.
TEST(CApiSimple, QuantizationParams) {
  TfLiteModel* model = TfLiteModelCreateFromFile(
      "tensorflow/lite/testdata/add_quantized.bin");
  ASSERT_NE(model, nullptr);

  // Null options: the interpreter is created with default settings.
  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, nullptr);
  ASSERT_NE(interpreter, nullptr);

  const std::array<int, 1> input_dims = {2};
  ASSERT_EQ(TfLiteInterpreterResizeInputTensor(
                interpreter, 0, input_dims.data(), input_dims.size()),
            kTfLiteOk);
  ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);

  TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
  ASSERT_NE(input_tensor, nullptr);
  EXPECT_EQ(TfLiteTensorType(input_tensor), kTfLiteUInt8);
  EXPECT_EQ(TfLiteTensorNumDims(input_tensor), 1);
  EXPECT_EQ(TfLiteTensorDim(input_tensor, 0), 2);

  TfLiteQuantizationParams input_params =
      TfLiteTensorQuantizationParams(input_tensor);
  EXPECT_FLOAT_EQ(input_params.scale, 0.003922f);
  EXPECT_EQ(input_params.zero_point, 0);

  const std::array<uint8_t, 2> input = {1, 3};
  ASSERT_EQ(TfLiteTensorCopyFromBuffer(input_tensor, input.data(),
                                       input.size() * sizeof(uint8_t)),
            kTfLiteOk);

  ASSERT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);

  const TfLiteTensor* output_tensor =
      TfLiteInterpreterGetOutputTensor(interpreter, 0);
  ASSERT_NE(output_tensor, nullptr);

  TfLiteQuantizationParams output_params =
      TfLiteTensorQuantizationParams(output_tensor);
  EXPECT_FLOAT_EQ(output_params.scale, 0.003922f);
  EXPECT_EQ(output_params.zero_point, 0);

  // Quantized outputs are exact integers, so exact comparison is fine here.
  std::array<uint8_t, 2> output;
  ASSERT_EQ(TfLiteTensorCopyToBuffer(output_tensor, output.data(),
                                     output.size() * sizeof(uint8_t)),
            kTfLiteOk);
  EXPECT_EQ(output[0], 3);
  EXPECT_EQ(output[1], 9);

  // Dequantize manually: real_value = scale * (quantized - zero_point).
  const float dequantizedOutput0 =
      output_params.scale * (output[0] - output_params.zero_point);
  const float dequantizedOutput1 =
      output_params.scale * (output[1] - output_params.zero_point);
  EXPECT_FLOAT_EQ(dequantizedOutput0, 0.011766f);
  EXPECT_FLOAT_EQ(dequantizedOutput1, 0.035298f);

  TfLiteInterpreterDelete(interpreter);
  TfLiteModelDelete(model);
}
168
// A delegate added via the options must have its Prepare callback run during
// interpreter creation, and a no-op delegate must not break execution.
//
// Fix: assert that model and interpreter creation succeeded before using
// them — previously a failed creation would crash the test on a null
// pointer instead of failing cleanly.
TEST(CApiSimple, Delegate) {
  TfLiteModel* model =
      TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
  ASSERT_NE(model, nullptr);

  // Create and install a delegate instance whose Prepare just records that
  // it ran; `data_` carries the flag into the capture-less lambda.
  bool delegate_prepared = false;
  TfLiteDelegate delegate = TfLiteDelegateCreate();
  delegate.data_ = &delegate_prepared;
  delegate.Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) {
    *static_cast<bool*>(delegate->data_) = true;
    return kTfLiteOk;
  };
  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
  TfLiteInterpreterOptionsAddDelegate(options, &delegate);
  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
  ASSERT_NE(interpreter, nullptr);

  // The delegate should have been applied.
  EXPECT_TRUE(delegate_prepared);

  // Subsequent execution should behave properly (the delegate is a no-op).
  TfLiteInterpreterOptionsDelete(options);
  EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);
  TfLiteInterpreterDelete(interpreter);
  TfLiteModelDelete(model);
}
194
// Interpreter creation must fail — and return null — when a delegate's
// Prepare callback reports an error.
TEST(CApiSimple, DelegateFails) {
  TfLiteModel* model =
      TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");

  // A delegate whose preparation unconditionally fails.
  TfLiteDelegate failing_delegate = TfLiteDelegateCreate();
  failing_delegate.Prepare = [](TfLiteContext* context,
                                TfLiteDelegate* delegate) {
    return kTfLiteError;
  };

  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
  TfLiteInterpreterOptionsAddDelegate(options, &failing_delegate);
  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);

  // Delegate preparation failed, so no interpreter should be produced.
  EXPECT_EQ(interpreter, nullptr);

  TfLiteInterpreterOptionsDelete(options);
  TfLiteModelDelete(model);
}
214
// Errors raised by the interpreter must be routed to the error reporter
// installed through the interpreter options.
//
// Fix: assert the interpreter was actually created before invoking it —
// previously a failed creation would crash the test on a null pointer.
TEST(CApiSimple, ErrorReporter) {
  TfLiteModel* model =
      TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();

  // Install a custom error reporter into the interpreter by way of options.
  // The callback must be capture-less so it converts to a function pointer;
  // state travels through `user_data`.
  tflite::TestErrorReporter reporter;
  TfLiteInterpreterOptionsSetErrorReporter(
      options,
      [](void* user_data, const char* format, va_list args) {
        reinterpret_cast<tflite::TestErrorReporter*>(user_data)->Report(format,
                                                                        args);
      },
      &reporter);
  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);

  // The options can be deleted immediately after interpreter creation.
  TfLiteInterpreterOptionsDelete(options);
  ASSERT_NE(interpreter, nullptr);

  // Invoke the interpreter before tensor allocation: this must fail.
  EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteError);

  // The error should propagate to the custom error reporter, exactly once.
  EXPECT_EQ(reporter.error_messages(),
            "Invoke called on model that is not ready.");
  EXPECT_EQ(reporter.num_calls(), 1);

  TfLiteInterpreterDelete(interpreter);
  TfLiteModelDelete(model);
}
245
// A model created from an in-memory buffer of valid FlatBuffer bytes must
// load successfully.
//
// Fix: open the file in binary mode — the model is a binary FlatBuffer, and
// text-mode reads can mangle its bytes (e.g. CRLF translation) on some
// platforms — and verify the file was opened and fully read.
TEST(CApiSimple, ValidModel) {
  std::ifstream model_file("tensorflow/lite/testdata/add.bin",
                           std::ios::binary);
  ASSERT_TRUE(model_file.is_open());

  // Determine the file size, then read the whole file into a buffer.
  model_file.seekg(0, std::ios_base::end);
  std::vector<char> model_buffer(model_file.tellg());

  model_file.seekg(0, std::ios_base::beg);
  model_file.read(model_buffer.data(), model_buffer.size());
  ASSERT_TRUE(model_file.good());

  TfLiteModel* model =
      TfLiteModelCreate(model_buffer.data(), model_buffer.size());
  ASSERT_NE(model, nullptr);
  TfLiteModelDelete(model);
}
260
// Loading a well-formed model file directly from disk must succeed.
TEST(CApiSimple, ValidModelFromFile) {
  TfLiteModel* const model =
      TfLiteModelCreateFromFile("tensorflow/lite/testdata/add.bin");
  ASSERT_NE(nullptr, model);
  TfLiteModelDelete(model);
}
267
// Creating a model from a buffer of garbage bytes must fail and return null.
TEST(CApiSimple, InvalidModel) {
  const std::string junk(20, 'c');
  TfLiteModel* model = TfLiteModelCreate(junk.data(), junk.size());
  ASSERT_EQ(model, nullptr);
}
274
// Creating a model from a nonexistent path must fail and return null.
TEST(CApiSimple, InvalidModelFromFile) {
  TfLiteModel* const missing =
      TfLiteModelCreateFromFile("invalid/path/foo.tflite");
  ASSERT_EQ(nullptr, missing);
}
279
// Per-node user data for the test "Sinh" custom op; allocated in
// FlexSinhInit, consumed in FlexSinhEval, and released in FlexSinhFree.
struct SinhParams {
  // When true, FlexSinhEval computes cosh instead of sinh (set when the
  // op's custom_options payload is the string "use_cosh").
  bool use_cosh_instead = false;
};
283
FlexSinhInit(TfLiteOpaqueContext * context,const char * buffer,size_t length)284 void* FlexSinhInit(TfLiteOpaqueContext* context, const char* buffer,
285 size_t length) {
286 auto sinh_params = new SinhParams;
287 // The buffer that is passed into here is the custom_options
288 // field from the flatbuffer (tensorflow/lite/schema/schema.fbs)
289 // `Operator` for this node.
290 // Typically it should be stored as a FlexBuffer, but for this test
291 // we assume that it is just a string.
292 if (std::string(buffer, length) == "use_cosh") {
293 sinh_params->use_cosh_instead = true;
294 }
295 return sinh_params;
296 }
297
FlexSinhFree(TfLiteOpaqueContext * context,void * data)298 void FlexSinhFree(TfLiteOpaqueContext* context, void* data) {
299 delete static_cast<SinhParams*>(data);
300 }
301
// Prepare hook for the "Sinh" custom op. This test kernel needs no shape
// inference or temporaries, so preparation unconditionally succeeds.
TfLiteStatus FlexSinhPrepare(TfLiteOpaqueContext* context,
                             TfLiteOpaqueNode* node) {
  return kTfLiteOk;
}
306
FlexSinhEval(TfLiteOpaqueContext * context,TfLiteOpaqueNode * node)307 TfLiteStatus FlexSinhEval(TfLiteOpaqueContext* context,
308 TfLiteOpaqueNode* node) {
309 auto sinh_params =
310 static_cast<SinhParams*>(TfLiteOpaqueNodeGetUserData(node));
311 const TfLiteOpaqueTensor* input = TfLiteOpaqueNodeGetInput(context, node, 0);
312 size_t input_bytes = TfLiteOpaqueTensorByteSize(input);
313 void* data_ptr = TfLiteOpaqueTensorData(input);
314 float input_value;
315 memcpy(&input_value, data_ptr, input_bytes);
316
317 TfLiteOpaqueTensor* output = TfLiteOpaqueNodeGetOutput(context, node, 0);
318 float output_value = sinh_params->use_cosh_instead ? std::cosh(input_value)
319 : std::sinh(input_value);
320 TfLiteOpaqueTensorCopyFromBuffer(output, &output_value, sizeof(output_value));
321 return kTfLiteOk;
322 }
323
// Registers a custom "Sinh" op implementation via TfLiteRegistrationExternal
// and runs a model containing that op, checking the numeric result.
//
// Fix: assert that registration/interpreter/tensor creation succeeded and
// that the tensor buffer copies returned kTfLiteOk — previously any failure
// crashed the test on a null pointer or was silently ignored.
TEST(CApiSimple, CustomOpSupport) {
  TfLiteModel* model = TfLiteModelCreateFromFile(
      "tensorflow/lite/testdata/custom_sinh.bin");
  ASSERT_NE(model, nullptr);

  // Wire the FlexSinh* callbacks into an external registration for "Sinh".
  TfLiteRegistrationExternal* reg = TfLiteRegistrationExternalCreate("Sinh", 1);
  ASSERT_NE(reg, nullptr);
  TfLiteRegistrationExternalSetInit(reg, &FlexSinhInit);
  TfLiteRegistrationExternalSetFree(reg, &FlexSinhFree);
  TfLiteRegistrationExternalSetPrepare(reg, &FlexSinhPrepare);
  TfLiteRegistrationExternalSetInvoke(reg, &FlexSinhEval);

  TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
  TfLiteInterpreterOptionsAddRegistrationExternal(options, reg);

  TfLiteInterpreter* interpreter = TfLiteInterpreterCreate(model, options);
  ASSERT_NE(interpreter, nullptr);

  TfLiteInterpreterOptionsDelete(options);
  ASSERT_EQ(TfLiteInterpreterAllocateTensors(interpreter), kTfLiteOk);
  TfLiteTensor* input_tensor = TfLiteInterpreterGetInputTensor(interpreter, 0);
  ASSERT_NE(input_tensor, nullptr);
  float input_value = 1.0f;
  ASSERT_EQ(
      TfLiteTensorCopyFromBuffer(input_tensor, &input_value, sizeof(float)),
      kTfLiteOk);

  EXPECT_EQ(TfLiteInterpreterInvoke(interpreter), kTfLiteOk);

  const TfLiteTensor* output_tensor =
      TfLiteInterpreterGetOutputTensor(interpreter, 0);
  ASSERT_NE(output_tensor, nullptr);
  float output_value;
  ASSERT_EQ(
      TfLiteTensorCopyToBuffer(output_tensor, &output_value, sizeof(float)),
      kTfLiteOk);
  EXPECT_EQ(output_value, std::sinh(1.0f));

  TfLiteInterpreterDelete(interpreter);
  TfLiteModelDelete(model);
  TfLiteRegistrationExternalDelete(reg);
}
358
find_builtin_op_add(void * user_data,TfLiteBuiltinOperator op,int version)359 const TfLiteRegistration* find_builtin_op_add(void* user_data,
360 TfLiteBuiltinOperator op,
361 int version) {
362 static TfLiteRegistration registration{/*init=*/nullptr,
363 /*free=*/nullptr,
364 /*prepare=*/nullptr,
365 /*invoke=*/nullptr,
366 /*profiling_string=*/nullptr,
367 /*builtin_code=*/kTfLiteBuiltinAdd,
368 /*custom_name=*/nullptr,
369 /*version=*/1};
370 if (op == kTfLiteBuiltinAdd && version == 1) {
371 return ®istration;
372 }
373 return nullptr;
374 }
375
find_custom_op_sinh(void * user_data,const char * op,int version)376 const TfLiteRegistration* find_custom_op_sinh(void* user_data, const char* op,
377 int version) {
378 static TfLiteRegistration registration{/*init=*/nullptr,
379 /*free=*/nullptr,
380 /*prepare=*/nullptr,
381 /*invoke=*/nullptr,
382 /*profiling_string=*/nullptr,
383 /*builtin_code=*/kTfLiteBuiltinCustom,
384 /*custom_name=*/"Sinh",
385 /*version=*/1};
386 if (strcmp(op, "Sinh") == 0 && version == 1) {
387 return ®istration;
388 }
389 return nullptr;
390 }
391
// Verifies that CallbackOpResolver forwards built-in and custom lookups to
// the registered callbacks, and returns null for ops the callbacks reject.
//
// Fix: `custom_name` is a `const char*`, so the original
// EXPECT_EQ(reg_sinh->custom_name, "Sinh") compared pointer identity, not
// string contents; EXPECT_STREQ compares the actual C-string.
TEST(CApiSimple, CallbackOpResolver) {
  tflite::internal::CallbackOpResolver resolver;
  struct TfLiteOpResolverCallbacks callbacks {};
  callbacks.find_builtin_op = find_builtin_op_add;
  callbacks.find_custom_op = find_custom_op_sinh;

  resolver.SetCallbacks(callbacks);
  auto reg_add = resolver.FindOp(
      static_cast<::tflite::BuiltinOperator>(kTfLiteBuiltinAdd), 1);
  ASSERT_NE(reg_add, nullptr);
  EXPECT_EQ(reg_add->builtin_code, kTfLiteBuiltinAdd);
  EXPECT_EQ(reg_add->version, 1);
  EXPECT_EQ(reg_add->registration_external, nullptr);

  // An op the builtin callback doesn't know about resolves to null.
  EXPECT_EQ(
      resolver.FindOp(
          static_cast<::tflite::BuiltinOperator>(kTfLiteBuiltinConv2d), 1),
      nullptr);

  auto reg_sinh = resolver.FindOp("Sinh", 1);
  ASSERT_NE(reg_sinh, nullptr);
  EXPECT_EQ(reg_sinh->builtin_code, kTfLiteBuiltinCustom);
  EXPECT_STREQ(reg_sinh->custom_name, "Sinh");
  EXPECT_EQ(reg_sinh->version, 1);
  EXPECT_EQ(reg_sinh->registration_external, nullptr);

  // An op the custom callback doesn't know about resolves to null.
  EXPECT_EQ(resolver.FindOp("Cosh", 1), nullptr);
}
420
dummy_find_builtin_op_v1(void * user_data,TfLiteBuiltinOperator op,int version)421 const TfLiteRegistration_V1* dummy_find_builtin_op_v1(void* user_data,
422 TfLiteBuiltinOperator op,
423 int version) {
424 static TfLiteRegistration_V1 registration_v1{
425 nullptr, nullptr, nullptr, nullptr,
426 nullptr, kTfLiteBuiltinAdd, nullptr, 1};
427 if (op == kTfLiteBuiltinAdd) {
428 return ®istration_v1;
429 }
430 return nullptr;
431 }
432
dummy_find_custom_op_v1(void * user_data,const char * op,int version)433 const TfLiteRegistration_V1* dummy_find_custom_op_v1(void* user_data,
434 const char* op,
435 int version) {
436 static TfLiteRegistration_V1 registration_v1{
437 nullptr, nullptr, nullptr, nullptr, nullptr, kTfLiteBuiltinCustom,
438 "Sinh", 1};
439 if (strcmp(op, "Sinh") == 0) {
440 return ®istration_v1;
441 }
442 return nullptr;
443 }
444
// Verifies CallbackOpResolver with the legacy (TfLiteRegistration_V1)
// callbacks, including that repeated lookups hit the resolver's cache and
// keep returning consistent registrations.
//
// Fix: `custom_name` is a `const char*`, so the original
// EXPECT_EQ(reg_sinh->custom_name, "Sinh") compared pointer identity, not
// string contents; EXPECT_STREQ compares the actual C-string (two spots).
TEST(CApiSimple, CallbackOpResolver_V1) {
  tflite::internal::CallbackOpResolver resolver;
  struct TfLiteOpResolverCallbacks callbacks {};
  callbacks.find_builtin_op_v1 = dummy_find_builtin_op_v1;
  callbacks.find_custom_op_v1 = dummy_find_custom_op_v1;

  resolver.SetCallbacks(callbacks);
  auto reg_add = resolver.FindOp(
      static_cast<::tflite::BuiltinOperator>(kTfLiteBuiltinAdd), 1);
  ASSERT_NE(reg_add, nullptr);
  EXPECT_EQ(reg_add->builtin_code, kTfLiteBuiltinAdd);
  EXPECT_EQ(reg_add->version, 1);
  EXPECT_EQ(reg_add->registration_external, nullptr);

  EXPECT_EQ(
      resolver.FindOp(
          static_cast<::tflite::BuiltinOperator>(kTfLiteBuiltinConv2d), 1),
      nullptr);

  // Query kTfLiteBuiltinAdd multiple times to check if caching logic works.
  for (int i = 0; i < 10; ++i) {
    auto reg_add = resolver.FindOp(
        static_cast<::tflite::BuiltinOperator>(kTfLiteBuiltinAdd), 1);
    ASSERT_NE(reg_add, nullptr);
    EXPECT_EQ(reg_add->builtin_code, kTfLiteBuiltinAdd);
    EXPECT_EQ(reg_add->version, 1);
    EXPECT_EQ(reg_add->registration_external, nullptr);
  }

  auto reg_sinh = resolver.FindOp("Sinh", 1);
  ASSERT_NE(reg_sinh, nullptr);
  EXPECT_EQ(reg_sinh->builtin_code, kTfLiteBuiltinCustom);
  EXPECT_STREQ(reg_sinh->custom_name, "Sinh");
  EXPECT_EQ(reg_sinh->version, 1);
  EXPECT_EQ(reg_sinh->registration_external, nullptr);

  EXPECT_EQ(resolver.FindOp("Cosh", 1), nullptr);

  // Query "Sinh" multiple times to check if caching logic works.
  for (int i = 0; i < 10; ++i) {
    auto reg_sinh = resolver.FindOp("Sinh", 1);
    ASSERT_NE(reg_sinh, nullptr);
    EXPECT_EQ(reg_sinh->builtin_code, kTfLiteBuiltinCustom);
    EXPECT_STREQ(reg_sinh->custom_name, "Sinh");
    EXPECT_EQ(reg_sinh->version, 1);
    EXPECT_EQ(reg_sinh->registration_external, nullptr);
  }
}
493
494 } // namespace
495