/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_

#include <sys/mman.h>

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"

// If compiled with -DTFLITE_USE_C_API, this file will use the TF Lite C API
// rather than the TF Lite C++ API.
// TODO(b/168025296): eliminate the '#if TFLITE_USE_C_API' directives here and
// elsewhere and instead use the C API unconditionally, once we have a suitable
// replacement for the features of tflite::support::TfLiteInterpreterWrapper.
#if TFLITE_USE_C_API
#include "tensorflow/lite/c/c_api.h"
#else
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#endif

namespace tflite {
namespace task {
namespace core {

// TfLiteEngine encapsulates logic for TFLite model initialization, inference
// and error reporting.
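//
// A minimal usage sketch (non-normative: the model path is a placeholder,
// error handling is elided, and inference is shown through the C++
// Interpreter, i.e. without -DTFLITE_USE_C_API):
//
//   TfLiteEngine engine;
//   if (engine.BuildModelFromFile("/path/to/model.tflite").ok() &&
//       engine.InitInterpreter().ok()) {
//     TfLiteTensor* input = TfLiteEngine::GetInput(engine.interpreter(), 0);
//     // ... populate `input`, then run inference and read the result:
//     engine.interpreter()->Invoke();
//     const TfLiteTensor* output =
//         TfLiteEngine::GetOutput(engine.interpreter(), 0);
//   }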
class TfLiteEngine {
 public:
  // Types.
  using InterpreterWrapper = tflite::support::TfLiteInterpreterWrapper;
#if TFLITE_USE_C_API
  using Model = struct TfLiteModel;
  using Interpreter = struct TfLiteInterpreter;
  using ModelDeleter = void (*)(Model*);
  using InterpreterDeleter = InterpreterWrapper::InterpreterDeleter;
#else
  using Model = tflite::FlatBufferModel;
  using Interpreter = tflite::Interpreter;
  using ModelDeleter = std::default_delete<Model>;
  using InterpreterDeleter = std::default_delete<Interpreter>;
#endif

  // Constructors.
  explicit TfLiteEngine(
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
  // Model is neither copyable nor movable.
  TfLiteEngine(const TfLiteEngine&) = delete;
  TfLiteEngine& operator=(const TfLiteEngine&) = delete;

  // Accessors.
  static int32_t InputCount(const Interpreter* interpreter) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensorCount(interpreter);
#else
    return interpreter->inputs().size();
#endif
  }
  static int32_t OutputCount(const Interpreter* interpreter) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetOutputTensorCount(interpreter);
#else
    return interpreter->outputs().size();
#endif
  }
  static TfLiteTensor* GetInput(Interpreter* interpreter, int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->inputs()[index]);
#endif
  }
  // Same as above, but const.
  static const TfLiteTensor* GetInput(const Interpreter* interpreter,
                                      int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->inputs()[index]);
#endif
  }
  static TfLiteTensor* GetOutput(Interpreter* interpreter, int index) {
#if TFLITE_USE_C_API
    // We need a const_cast here, because the TF Lite C API only has a
    // non-const version of GetOutputTensor (in part because C doesn't support
    // overloading on const).
    return const_cast<TfLiteTensor*>(
        TfLiteInterpreterGetOutputTensor(interpreter, index));
#else
    return interpreter->tensor(interpreter->outputs()[index]);
#endif
  }
  // Same as above, but const.
  static const TfLiteTensor* GetOutput(const Interpreter* interpreter,
                                       int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetOutputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->outputs()[index]);
#endif
  }

  std::vector<TfLiteTensor*> GetInputs();
  std::vector<const TfLiteTensor*> GetOutputs();
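
  // Illustration of typical tensor access through the static helpers above
  // (a sketch only; it assumes input tensor 0 holds float32 data):
  //
  //   TfLiteTensor* input = TfLiteEngine::GetInput(interpreter, 0);
  //   if (input->type == kTfLiteFloat32) {
  //     input->data.f[0] = 1.0f;  // raw float buffer of the C tensor struct
  //   }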

  const Model* model() const { return model_.get(); }
  Interpreter* interpreter() { return interpreter_.get(); }
  const Interpreter* interpreter() const { return interpreter_.get(); }
  InterpreterWrapper* interpreter_wrapper() { return &interpreter_; }
  const tflite::metadata::ModelMetadataExtractor* metadata_extractor() const {
    return model_metadata_extractor_.get();
  }

  // Builds the TF Lite FlatBufferModel (model_) from the raw FlatBuffer data
  // whose ownership remains with the caller, and which must outlive the
  // current object. This performs extra verification on the input data using
  // tflite::Verify.
  absl::Status BuildModelFromFlatBuffer(const char* buffer_data,
                                        size_t buffer_size);

  // Builds the TF Lite model from a given file.
  absl::Status BuildModelFromFile(const std::string& file_name);

  // Builds the TF Lite model from a given file descriptor using mmap(2).
  absl::Status BuildModelFromFileDescriptor(int file_descriptor);

  // Builds the TFLite model from the provided ExternalFile proto, which must
  // outlive the current object.
  absl::Status BuildModelFromExternalFileProto(
      const ExternalFile* external_file);

  // Initializes the interpreter with the encapsulated model.
  // Note: setting num_threads to -1 lets the TFLite runtime choose the value.
  absl::Status InitInterpreter(int num_threads = 1);

  // Same as above, but allows specifying `compute_settings` for acceleration.
  absl::Status InitInterpreter(
      const tflite::proto::ComputeSettings& compute_settings,
      int num_threads = 1);

  // Cancels the ongoing `Invoke()` call, if any and if possible. This method
  // can be called from a different thread than the one where `Invoke()` is
  // running.
  void Cancel() {
#if TFLITE_USE_C_API
    // NOP.
#else
    interpreter_.Cancel();
#endif
  }

 protected:
  // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the
  // error into a string so that it can be used to complement
  // tensorflow::Status error messages.
  struct ErrorReporter : public tflite::ErrorReporter {
    // Last error message captured by this error reporter.
    char error_message[256];
    int Report(const char* format, va_list args) override;
  };
  // Custom error reporter capturing low-level TF Lite error messages.
  ErrorReporter error_reporter_;

 private:
  // Builds the model from the buffer and stores it in 'model_'.
  void BuildModelFromBuffer(const char* buffer_data, size_t buffer_size);

  // Gets the buffer from the file handler; verifies and builds the model from
  // the buffer; if successful, sets 'model_metadata_extractor_' to be a
  // TF Lite Metadata extractor for the model; and calculates an appropriate
  // return Status.
  absl::Status InitializeFromModelFileHandler();

  // TF Lite model and interpreter for actual inference.
  std::unique_ptr<Model, ModelDeleter> model_;

  // Interpreter wrapper built from the model.
  InterpreterWrapper interpreter_;

  // TFLite Metadata extractor built from the model.
  std::unique_ptr<tflite::metadata::ModelMetadataExtractor>
      model_metadata_extractor_;

  // Mechanism used by TF Lite to map Ops referenced in the FlatBuffer model
  // to actual implementations. Defaults to the TF Lite BuiltinOpResolver.
  std::unique_ptr<tflite::OpResolver> resolver_;

  // ExternalFile and corresponding ExternalFileHandler for models loaded from
  // disk or file descriptor.
  ExternalFile external_file_;
  std::unique_ptr<ExternalFileHandler> model_file_handler_;
};

}  // namespace core
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_