/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_

#include <sys/mman.h>

#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow_lite_support/cc/port/tflite_wrapper.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"

// If compiled with -DTFLITE_USE_C_API, this file will use the TF Lite C API
// rather than the TF Lite C++ API.
// TODO(b/168025296): eliminate the '#if TFLITE_USE_C_API' directives here and
// elsewhere and instead use the C API unconditionally, once we have a suitable
// replacement for the features of tflite::support::TfLiteInterpreterWrapper.
#if TFLITE_USE_C_API
#include "tensorflow/lite/c/c_api.h"
#else
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/model.h"
#endif

namespace tflite {
namespace task {
namespace core {

// TfLiteEngine encapsulates logic for TFLite model initialization, inference
// and error reporting.
class TfLiteEngine {
 public:
  // Types.
  using InterpreterWrapper = tflite::support::TfLiteInterpreterWrapper;
#if TFLITE_USE_C_API
  using Model = struct TfLiteModel;
  using Interpreter = struct TfLiteInterpreter;
  using ModelDeleter = void (*)(Model*);
  using InterpreterDeleter = InterpreterWrapper::InterpreterDeleter;
#else
  using Model = tflite::FlatBufferModel;
  using Interpreter = tflite::Interpreter;
  using ModelDeleter = std::default_delete<Model>;
  using InterpreterDeleter = std::default_delete<Interpreter>;
#endif

  // Constructors.
  explicit TfLiteEngine(
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
  // TfLiteEngine is neither copyable nor movable.
  TfLiteEngine(const TfLiteEngine&) = delete;
  TfLiteEngine& operator=(const TfLiteEngine&) = delete;

  // Accessors.
  static int32_t InputCount(const Interpreter* interpreter) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensorCount(interpreter);
#else
    return interpreter->inputs().size();
#endif
  }
  static int32_t OutputCount(const Interpreter* interpreter) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetOutputTensorCount(interpreter);
#else
    return interpreter->outputs().size();
#endif
  }
  static TfLiteTensor* GetInput(Interpreter* interpreter, int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->inputs()[index]);
#endif
  }
  // Same as above, but const.
  static const TfLiteTensor* GetInput(const Interpreter* interpreter,
                                      int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetInputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->inputs()[index]);
#endif
  }
  static TfLiteTensor* GetOutput(Interpreter* interpreter, int index) {
#if TFLITE_USE_C_API
    // We need a const_cast here, because the TF Lite C API only has a const
    // version of GetOutputTensor (in part because C doesn't support
    // overloading on const).
    return const_cast<TfLiteTensor*>(
        TfLiteInterpreterGetOutputTensor(interpreter, index));
#else
    return interpreter->tensor(interpreter->outputs()[index]);
#endif
  }
  // Same as above, but const.
  static const TfLiteTensor* GetOutput(const Interpreter* interpreter,
                                       int index) {
#if TFLITE_USE_C_API
    return TfLiteInterpreterGetOutputTensor(interpreter, index);
#else
    return interpreter->tensor(interpreter->outputs()[index]);
#endif
  }

  std::vector<TfLiteTensor*> GetInputs();
  std::vector<const TfLiteTensor*> GetOutputs();

  const Model* model() const { return model_.get(); }
  Interpreter* interpreter() { return interpreter_.get(); }
  const Interpreter* interpreter() const { return interpreter_.get(); }
  InterpreterWrapper* interpreter_wrapper() { return &interpreter_; }
  const tflite::metadata::ModelMetadataExtractor* metadata_extractor() const {
    return model_metadata_extractor_.get();
  }

  // Builds the TF Lite FlatBufferModel (model_) from the raw FlatBuffer data
  // whose ownership remains with the caller, and which must outlive the
  // current object. This performs extra verification on the input data using
  // tflite::Verify.
  absl::Status BuildModelFromFlatBuffer(const char* buffer_data,
                                        size_t buffer_size);
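  // A minimal usage sketch (`ReadModelFileSomehow` is a hypothetical helper,
  // not part of this library); note that the caller, not the engine, keeps
  // the backing data alive for the engine's lifetime:
  //
  //   std::string flatbuffer = ReadModelFileSomehow();
  //   TfLiteEngine engine;
  //   absl::Status status =
  //       engine.BuildModelFromFlatBuffer(flatbuffer.data(), flatbuffer.size());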

  // Builds the TF Lite model from a given file.
  absl::Status BuildModelFromFile(const std::string& file_name);

  // Builds the TF Lite model from a given file descriptor using mmap(2).
  absl::Status BuildModelFromFileDescriptor(int file_descriptor);
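  // For example (a sketch assuming POSIX open(2); error handling omitted):
  //
  //   int fd = open("model.tflite", O_RDONLY);
  //   absl::Status status = engine.BuildModelFromFileDescriptor(fd);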

  // Builds the TFLite model from the provided ExternalFile proto, which must
  // outlive the current object.
  absl::Status BuildModelFromExternalFileProto(
      const ExternalFile* external_file);

  // Initializes the interpreter with the encapsulated model.
  // Note: setting num_threads to -1 lets the TFLite runtime set an
  // appropriate value.
  absl::Status InitInterpreter(int num_threads = 1);

  // Same as above, but allows specifying `compute_settings` for acceleration.
  absl::Status InitInterpreter(
      const tflite::proto::ComputeSettings& compute_settings,
      int num_threads = 1);
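  // For instance, to let the TFLite runtime choose the number of threads
  // (a sketch, continuing the examples above):
  //
  //   absl::Status status = engine.InitInterpreter(/*num_threads=*/-1);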

  // Cancels the ongoing `Invoke()` call, if any and if possible. This method
  // can be called from a different thread than the one where `Invoke()` is
  // running.
  void Cancel() {
#if TFLITE_USE_C_API
    // NOP.
#else
    interpreter_.Cancel();
#endif
  }
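  // A sketch of cross-thread cancellation, assuming a C++-API build (where
  // Cancel() is not a no-op) and a build where the wrapper exposes
  // InvokeWithoutFallback(), as the default TfLiteInterpreterWrapper does:
  //
  //   std::thread worker([&engine] {
  //     engine.interpreter_wrapper()->InvokeWithoutFallback().IgnoreError();
  //   });
  //   engine.Cancel();  // Best effort; Invoke() may return a cancelled status.
  //   worker.join();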

 protected:
  // TF Lite's DefaultErrorReporter() outputs to stderr. This one captures the
  // error into a string so that it can be used to complement absl::Status
  // error messages.
  struct ErrorReporter : public tflite::ErrorReporter {
    // Last error message captured by this error reporter.
    char error_message[256];
    int Report(const char* format, va_list args) override;
  };
  // Custom error reporter capturing low-level TF Lite error messages.
  ErrorReporter error_reporter_;

 private:
  // Builds the model from the buffer and stores it in 'model_'.
  void BuildModelFromBuffer(const char* buffer_data, size_t buffer_size);

  // Gets the buffer from the file handler; verifies and builds the model
  // from the buffer; if successful, sets 'model_metadata_extractor_' to be
  // a TF Lite Metadata extractor for the model; and returns an appropriate
  // Status.
  absl::Status InitializeFromModelFileHandler();

  // TF Lite model and interpreter for actual inference.
  std::unique_ptr<Model, ModelDeleter> model_;

  // Interpreter wrapper built from the model.
  InterpreterWrapper interpreter_;

  // TFLite Metadata extractor built from the model.
  std::unique_ptr<tflite::metadata::ModelMetadataExtractor>
      model_metadata_extractor_;

  // Mechanism used by TF Lite to map Ops referenced in the FlatBuffer model
  // to their actual implementations. Defaults to the TF Lite
  // BuiltinOpResolver.
  std::unique_ptr<tflite::OpResolver> resolver_;

  // ExternalFile and corresponding ExternalFileHandler for models loaded from
  // disk or file descriptor.
  ExternalFile external_file_;
  std::unique_ptr<ExternalFileHandler> model_file_handler_;
};
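
// Example end-to-end usage (a minimal sketch, not part of this header's API;
// assumes "model.tflite" exists and that the caller populates the input
// tensor appropriately for the model):
//
//   tflite::task::core::TfLiteEngine engine;
//   absl::Status status = engine.BuildModelFromFile("model.tflite");
//   if (!status.ok()) { /* handle error */ }
//   status = engine.InitInterpreter(/*num_threads=*/2);
//   if (!status.ok()) { /* handle error */ }
//   TfLiteTensor* input =
//       tflite::task::core::TfLiteEngine::GetInput(engine.interpreter(), 0);
//   // ... populate `input`, run inference via the interpreter wrapper (e.g.
//   // engine.interpreter_wrapper()->InvokeWithoutFallback(), assuming the
//   // wrapper exposes it in this build), then read results with GetOutput().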

}  // namespace core
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_CORE_TFLITE_ENGINE_H_