1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_C_C_API_INTERNAL_H_ 16 #define TENSORFLOW_LITE_C_C_API_INTERNAL_H_ 17 18 #include <stdarg.h> 19 20 #include <memory> 21 #include <mutex> // NOLINT 22 #include <vector> 23 24 #include "tensorflow/lite/builtin_ops.h" 25 #include "tensorflow/lite/core/api/error_reporter.h" 26 #include "tensorflow/lite/core/api/op_resolver.h" 27 #include "tensorflow/lite/interpreter.h" 28 #include "tensorflow/lite/model.h" 29 #include "tensorflow/lite/mutable_op_resolver.h" 30 #include "tensorflow/lite/signature_runner.h" 31 32 // Internal structures and subroutines used by the C API. These are likely to 33 // change and should not be depended on directly by any C API clients. 34 // 35 // NOTE: This header does not follow C conventions and does not define a C API. 36 // It is effectively an (internal) implementation detail of the C API. 37 38 struct TfLiteModel { 39 // Sharing is safe as FlatBufferModel is const. 40 std::shared_ptr<const tflite::FlatBufferModel> impl; 41 }; 42 43 // The `TfLiteOpResolver` struct is an abstract callback interface that 44 // contains function pointers for callbacks that return a 45 // `TfLiteRegistration` given an op code or custom op name. This mechanism is 46 // used to map ops referenced in the flatbuffer model to executable function 47 // pointers (`TfLiteRegistration`s). 48 // This struct mirrors the tflite::OpResolver C++ abstract base class. 49 struct TfLiteOpResolverCallbacks { 50 // Opaque data that gets passed down to the callback functions. 51 void* user_data = nullptr; 52 53 // Callback that finds the op registration for a builtin operator by enum 54 // code. The `user_data` parameter will be set to the 55 // `op_resolver_user_data` value that was passed to 56 // `TfLiteInterpreterOptionsSetOpResolver`. 57 const TfLiteRegistration* (*find_builtin_op)(void* user_data, 58 TfLiteBuiltinOperator op, 59 int version); 60 // Callback that finds the op registration of a custom operator by op name. 61 // The `user_data` parameter will be set to the `op_resolver_user_data` value 62 // that was passed to `TfLiteInterpreterOptionsSetOpResolver`. 63 const TfLiteRegistration* (*find_custom_op)(void* user_data, const char* op, 64 int version); 65 66 // `find_builtin_op` which returns `TfLiteRegistration_V1`. 67 const TfLiteRegistration_V1* (*find_builtin_op_v1)(void* user_data, 68 TfLiteBuiltinOperator op, 69 int version); 70 // `find_custom_op` which returns `TfLiteRegistration_V1`. 71 const TfLiteRegistration_V1* (*find_custom_op_v1)(void* user_data, 72 const char* op, 73 int version); 74 }; 75 76 // This struct mirrors the tflite::ErrorResolver C++ abstract base class. 77 struct TfLiteErrorReporterCallback { 78 // Opaque data that gets passed down to the callback function. 79 void* user_data = nullptr; 80 81 // Callback function that reports an error. 82 void (*error_reporter)(void* user_data, const char* format, 83 va_list args) = nullptr; 84 }; 85 86 struct TfLiteInterpreterOptions { 87 enum { 88 kDefaultNumThreads = -1, 89 }; 90 int num_threads = kDefaultNumThreads; 91 92 tflite::MutableOpResolver mutable_op_resolver; 93 94 TfLiteOpResolverCallbacks op_resolver_callbacks = {}; 95 96 std::vector<TfLiteDelegate*> delegates; 97 98 TfLiteErrorReporterCallback error_reporter_callback; 99 100 bool use_nnapi = false; 101 102 // Determines whether to allow automatic fallback to CPU. 103 // If true, and if one or more delegates were set, 104 // then if Invoke with delegates fails, it will be 105 // automatically retried without delegates. 106 bool enable_delegate_fallback = false; 107 108 // TfLiteRegistrationExternal objects owned by caller of 109 // `TfLiteInterpreterOptionsAddRegistrationExternal` API. 110 std::vector<TfLiteRegistrationExternal*> op_registrations; 111 }; 112 113 struct TfLiteInterpreter { 114 // Taking a reference to the (const) model data avoids lifetime-related issues 115 // and complexity with the TfLiteModel's existence. 116 std::shared_ptr<const tflite::FlatBufferModel> model; 117 118 // The interpreter does not take ownership of the provided ErrorReporter 119 // instance, so we ensure its validity here. Note that the interpreter may use 120 // the reporter in its destructor, so the reporter should be declared first. 121 std::unique_ptr<tflite::ErrorReporter> optional_error_reporter; 122 123 std::unique_ptr<tflite::Interpreter> impl; 124 125 bool enable_delegate_fallback; 126 }; 127 128 struct TfLiteSignatureRunner { 129 // The tflite::SignatureRunner runner object that this points to is owned by 130 // the interpreter. So this pointer will become invalid when the interpreter 131 // is destroyed. 132 tflite::SignatureRunner* impl; 133 }; 134 135 namespace tflite { 136 namespace internal { 137 138 /// `CallbackOpResolver` is a (C++) `tflite::OpResolver` that forwards the 139 /// methods to (C ABI) callback functions from a `TfLiteOpResolverCallbacks` 140 /// struct. 141 /// 142 /// The SetCallbacks method must be called before calling any of the FindOp 143 /// methods. 144 class CallbackOpResolver : public ::tflite::OpResolver { 145 public: CallbackOpResolver()146 CallbackOpResolver() {} SetCallbacks(const struct TfLiteOpResolverCallbacks & op_resolver_callbacks)147 void SetCallbacks( 148 const struct TfLiteOpResolverCallbacks& op_resolver_callbacks) { 149 op_resolver_callbacks_ = op_resolver_callbacks; 150 } 151 const TfLiteRegistration* FindOp(tflite::BuiltinOperator op, 152 int version) const override; 153 154 const TfLiteRegistration* FindOp(const char* op, int version) const override; 155 156 private: 157 CallbackOpResolver(const CallbackOpResolver&) = delete; 158 CallbackOpResolver& operator=(const CallbackOpResolver&) = delete; 159 160 struct TfLiteOpResolverCallbacks op_resolver_callbacks_ = {}; 161 162 // mutable objects to store temporary `TfLiteRegistration`. 163 mutable std::mutex mutex_; 164 mutable std::vector<std::unique_ptr<TfLiteRegistration>> 165 temporary_builtin_registrations_; // GUARDED_BY(mutex_) 166 mutable std::vector<std::unique_ptr<TfLiteRegistration>> 167 temporary_custom_registrations_; // GUARDED_BY(mutex_) 168 }; 169 170 // This adds the builtin and/or custom operators specified in options in 171 // `optional_options` (if any) to `mutable_resolver`, and then returns a newly 172 // created TfLiteInterpreter using `mutable_op_resolver` as the default 173 // OpResolver, and using any other options in `optional_options`, and using 174 // the provided `model`. 175 // 176 // * `model` must be a valid model instance. The caller retains ownership of the 177 // object, and can destroy it immediately after creating the interpreter; the 178 // interpreter will maintain its own reference to the underlying model data. 179 // * `optional_options` may be null. The caller retains ownership of the object, 180 // and can safely destroy it immediately after creating the interpreter. 181 // * `mutable_resolver` must not be null. The caller retains ownership of the 182 // MutableOpResolver object, and can safely destroy it immediately after 183 // creating the interpreter. 184 // 185 // NOTE: The client *must* explicitly allocate tensors before attempting to 186 // access input tensor data or invoke the interpreter. 187 188 TfLiteInterpreter* InterpreterCreateWithOpResolver( 189 const TfLiteModel* model, const TfLiteInterpreterOptions* optional_options, 190 tflite::MutableOpResolver* mutable_resolver); 191 192 } // namespace internal 193 } // namespace tflite 194 195 #endif // TENSORFLOW_LITE_C_C_API_INTERNAL_H_ 196