/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.
18*993b0882SAndroid Build Coastguard Worker 19*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 20*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 21*993b0882SAndroid Build Coastguard Worker 22*993b0882SAndroid Build Coastguard Worker #include <cstdint> 23*993b0882SAndroid Build Coastguard Worker #include <memory> 24*993b0882SAndroid Build Coastguard Worker 25*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h" 26*993b0882SAndroid Build Coastguard Worker #include "utils/tensor-view.h" 27*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter.h" 28*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 29*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/register.h" 30*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h" 31*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/mutable_op_resolver.h" 32*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/op_resolver.h" 33*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/string_util.h" 34*993b0882SAndroid Build Coastguard Worker 35*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 { 36*993b0882SAndroid Build Coastguard Worker 37*993b0882SAndroid Build Coastguard Worker // Creates a TF.Lite Op resolver in default configuration, with ops for 38*993b0882SAndroid Build Coastguard Worker // Annotator and Actions models. 39*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::OpResolver> BuildOpResolver(); 40*993b0882SAndroid Build Coastguard Worker 41*993b0882SAndroid Build Coastguard Worker // Like above, but allows passage of a function that can register additional 42*993b0882SAndroid Build Coastguard Worker // ops. 
43*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::OpResolver> BuildOpResolver( 44*993b0882SAndroid Build Coastguard Worker const std::function<void(tflite::MutableOpResolver*)>& customize_fn); 45*993b0882SAndroid Build Coastguard Worker 46*993b0882SAndroid Build Coastguard Worker std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromModelSpec( 47*993b0882SAndroid Build Coastguard Worker const tflite::Model*); 48*993b0882SAndroid Build Coastguard Worker std::unique_ptr<const tflite::FlatBufferModel> TfLiteModelFromBuffer( 49*993b0882SAndroid Build Coastguard Worker const flatbuffers::Vector<uint8_t>*); 50*993b0882SAndroid Build Coastguard Worker 51*993b0882SAndroid Build Coastguard Worker // Executor for the text selection prediction and classification models. 52*993b0882SAndroid Build Coastguard Worker class TfLiteModelExecutor { 53*993b0882SAndroid Build Coastguard Worker public: FromModelSpec(const tflite::Model * model_spec)54*993b0882SAndroid Build Coastguard Worker static std::unique_ptr<TfLiteModelExecutor> FromModelSpec( 55*993b0882SAndroid Build Coastguard Worker const tflite::Model* model_spec) { 56*993b0882SAndroid Build Coastguard Worker auto model = TfLiteModelFromModelSpec(model_spec); 57*993b0882SAndroid Build Coastguard Worker if (!model) { 58*993b0882SAndroid Build Coastguard Worker return nullptr; 59*993b0882SAndroid Build Coastguard Worker } 60*993b0882SAndroid Build Coastguard Worker return std::unique_ptr<TfLiteModelExecutor>( 61*993b0882SAndroid Build Coastguard Worker new TfLiteModelExecutor(std::move(model))); 62*993b0882SAndroid Build Coastguard Worker } 63*993b0882SAndroid Build Coastguard Worker FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)64*993b0882SAndroid Build Coastguard Worker static std::unique_ptr<TfLiteModelExecutor> FromBuffer( 65*993b0882SAndroid Build Coastguard Worker const flatbuffers::Vector<uint8_t>* model_spec_buffer) { 66*993b0882SAndroid Build Coastguard Worker auto model = 
TfLiteModelFromBuffer(model_spec_buffer); 67*993b0882SAndroid Build Coastguard Worker if (!model) { 68*993b0882SAndroid Build Coastguard Worker return nullptr; 69*993b0882SAndroid Build Coastguard Worker } 70*993b0882SAndroid Build Coastguard Worker return std::unique_ptr<TfLiteModelExecutor>( 71*993b0882SAndroid Build Coastguard Worker new TfLiteModelExecutor(std::move(model))); 72*993b0882SAndroid Build Coastguard Worker } 73*993b0882SAndroid Build Coastguard Worker 74*993b0882SAndroid Build Coastguard Worker // Creates an Interpreter for the model that serves as a scratch-pad for the 75*993b0882SAndroid Build Coastguard Worker // inference. The Interpreter is NOT thread-safe. 76*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::Interpreter> CreateInterpreter() const; 77*993b0882SAndroid Build Coastguard Worker 78*993b0882SAndroid Build Coastguard Worker template <typename T> SetInput(const int input_index,const TensorView<T> & input_data,tflite::Interpreter * interpreter)79*993b0882SAndroid Build Coastguard Worker void SetInput(const int input_index, const TensorView<T>& input_data, 80*993b0882SAndroid Build Coastguard Worker tflite::Interpreter* interpreter) const { 81*993b0882SAndroid Build Coastguard Worker input_data.copy_to(interpreter->typed_input_tensor<T>(input_index), 82*993b0882SAndroid Build Coastguard Worker input_data.size()); 83*993b0882SAndroid Build Coastguard Worker } 84*993b0882SAndroid Build Coastguard Worker 85*993b0882SAndroid Build Coastguard Worker template <typename T> SetInput(const int input_index,const std::vector<T> & input_data,tflite::Interpreter * interpreter)86*993b0882SAndroid Build Coastguard Worker void SetInput(const int input_index, const std::vector<T>& input_data, 87*993b0882SAndroid Build Coastguard Worker tflite::Interpreter* interpreter) const { 88*993b0882SAndroid Build Coastguard Worker std::copy(input_data.begin(), input_data.end(), 89*993b0882SAndroid Build Coastguard Worker 
interpreter->typed_input_tensor<T>(input_index)); 90*993b0882SAndroid Build Coastguard Worker } 91*993b0882SAndroid Build Coastguard Worker 92*993b0882SAndroid Build Coastguard Worker template <typename T> SetInput(const int input_index,const T input_value,tflite::Interpreter * interpreter)93*993b0882SAndroid Build Coastguard Worker void SetInput(const int input_index, const T input_value, 94*993b0882SAndroid Build Coastguard Worker tflite::Interpreter* interpreter) const { 95*993b0882SAndroid Build Coastguard Worker TfLiteTensor* input_tensor = 96*993b0882SAndroid Build Coastguard Worker interpreter->tensor(interpreter->inputs()[input_index]); 97*993b0882SAndroid Build Coastguard Worker switch (input_tensor->type) { 98*993b0882SAndroid Build Coastguard Worker case kTfLiteFloat32: 99*993b0882SAndroid Build Coastguard Worker *tflite::GetTensorData<float>(input_tensor) = input_value; 100*993b0882SAndroid Build Coastguard Worker break; 101*993b0882SAndroid Build Coastguard Worker case kTfLiteInt32: 102*993b0882SAndroid Build Coastguard Worker *tflite::GetTensorData<int32_t>(input_tensor) = input_value; 103*993b0882SAndroid Build Coastguard Worker break; 104*993b0882SAndroid Build Coastguard Worker case kTfLiteUInt8: 105*993b0882SAndroid Build Coastguard Worker *tflite::GetTensorData<uint8_t>(input_tensor) = input_value; 106*993b0882SAndroid Build Coastguard Worker break; 107*993b0882SAndroid Build Coastguard Worker case kTfLiteInt64: 108*993b0882SAndroid Build Coastguard Worker *tflite::GetTensorData<int64_t>(input_tensor) = input_value; 109*993b0882SAndroid Build Coastguard Worker break; 110*993b0882SAndroid Build Coastguard Worker case kTfLiteBool: 111*993b0882SAndroid Build Coastguard Worker *tflite::GetTensorData<bool>(input_tensor) = input_value; 112*993b0882SAndroid Build Coastguard Worker break; 113*993b0882SAndroid Build Coastguard Worker case kTfLiteInt16: 114*993b0882SAndroid Build Coastguard Worker *tflite::GetTensorData<int16_t>(input_tensor) = 
input_value; 115*993b0882SAndroid Build Coastguard Worker break; 116*993b0882SAndroid Build Coastguard Worker case kTfLiteInt8: 117*993b0882SAndroid Build Coastguard Worker *tflite::GetTensorData<int8_t>(input_tensor) = input_value; 118*993b0882SAndroid Build Coastguard Worker break; 119*993b0882SAndroid Build Coastguard Worker default: 120*993b0882SAndroid Build Coastguard Worker break; 121*993b0882SAndroid Build Coastguard Worker } 122*993b0882SAndroid Build Coastguard Worker } 123*993b0882SAndroid Build Coastguard Worker 124*993b0882SAndroid Build Coastguard Worker template <typename T> OutputView(const int output_index,const tflite::Interpreter * interpreter)125*993b0882SAndroid Build Coastguard Worker TensorView<T> OutputView(const int output_index, 126*993b0882SAndroid Build Coastguard Worker const tflite::Interpreter* interpreter) const { 127*993b0882SAndroid Build Coastguard Worker const TfLiteTensor* output_tensor = 128*993b0882SAndroid Build Coastguard Worker interpreter->tensor(interpreter->outputs()[output_index]); 129*993b0882SAndroid Build Coastguard Worker return TensorView<T>(interpreter->typed_output_tensor<T>(output_index), 130*993b0882SAndroid Build Coastguard Worker std::vector<int>(output_tensor->dims->data, 131*993b0882SAndroid Build Coastguard Worker output_tensor->dims->data + 132*993b0882SAndroid Build Coastguard Worker output_tensor->dims->size)); 133*993b0882SAndroid Build Coastguard Worker } 134*993b0882SAndroid Build Coastguard Worker 135*993b0882SAndroid Build Coastguard Worker template <typename T> Output(const int output_index,const tflite::Interpreter * interpreter)136*993b0882SAndroid Build Coastguard Worker std::vector<T> Output(const int output_index, 137*993b0882SAndroid Build Coastguard Worker const tflite::Interpreter* interpreter) const { 138*993b0882SAndroid Build Coastguard Worker TensorView<T> output_view = OutputView<T>(output_index, interpreter); 139*993b0882SAndroid Build Coastguard Worker return 
std::vector<T>(output_view.data(), 140*993b0882SAndroid Build Coastguard Worker output_view.data() + output_view.size()); 141*993b0882SAndroid Build Coastguard Worker } 142*993b0882SAndroid Build Coastguard Worker 143*993b0882SAndroid Build Coastguard Worker protected: 144*993b0882SAndroid Build Coastguard Worker explicit TfLiteModelExecutor( 145*993b0882SAndroid Build Coastguard Worker std::unique_ptr<const tflite::FlatBufferModel> model); 146*993b0882SAndroid Build Coastguard Worker TfLiteModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model, 147*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::OpResolver> resolver); 148*993b0882SAndroid Build Coastguard Worker 149*993b0882SAndroid Build Coastguard Worker std::unique_ptr<const tflite::FlatBufferModel> model_; 150*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::OpResolver> resolver_; 151*993b0882SAndroid Build Coastguard Worker }; 152*993b0882SAndroid Build Coastguard Worker 153*993b0882SAndroid Build Coastguard Worker template <> 154*993b0882SAndroid Build Coastguard Worker void TfLiteModelExecutor::SetInput(const int input_index, 155*993b0882SAndroid Build Coastguard Worker const std::vector<std::string>& input_data, 156*993b0882SAndroid Build Coastguard Worker tflite::Interpreter* interpreter) const; 157*993b0882SAndroid Build Coastguard Worker 158*993b0882SAndroid Build Coastguard Worker template <> 159*993b0882SAndroid Build Coastguard Worker std::vector<tflite::StringRef> TfLiteModelExecutor::Output( 160*993b0882SAndroid Build Coastguard Worker const int output_index, const tflite::Interpreter* interpreter) const; 161*993b0882SAndroid Build Coastguard Worker 162*993b0882SAndroid Build Coastguard Worker template <> 163*993b0882SAndroid Build Coastguard Worker std::vector<std::string> TfLiteModelExecutor::Output( 164*993b0882SAndroid Build Coastguard Worker const int output_index, const tflite::Interpreter* interpreter) const; 165*993b0882SAndroid Build Coastguard 
Worker 166*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3 167*993b0882SAndroid Build Coastguard Worker 168*993b0882SAndroid Build Coastguard Worker #endif // LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_ 169