/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains classes that can execute different models/parts of a model.
18*993b0882SAndroid Build Coastguard Worker 19*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ 20*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ 21*993b0882SAndroid Build Coastguard Worker 22*993b0882SAndroid Build Coastguard Worker #include <memory> 23*993b0882SAndroid Build Coastguard Worker 24*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h" 25*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h" 26*993b0882SAndroid Build Coastguard Worker #include "utils/tensor-view.h" 27*993b0882SAndroid Build Coastguard Worker #include "utils/tflite-model-executor.h" 28*993b0882SAndroid Build Coastguard Worker 29*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 { 30*993b0882SAndroid Build Coastguard Worker 31*993b0882SAndroid Build Coastguard Worker // Executor for the text selection prediction and classification models. 32*993b0882SAndroid Build Coastguard Worker class ModelExecutor : public TfLiteModelExecutor { 33*993b0882SAndroid Build Coastguard Worker public: FromModelSpec(const tflite::Model * model_spec)34*993b0882SAndroid Build Coastguard Worker static std::unique_ptr<ModelExecutor> FromModelSpec( 35*993b0882SAndroid Build Coastguard Worker const tflite::Model* model_spec) { 36*993b0882SAndroid Build Coastguard Worker auto model = TfLiteModelFromModelSpec(model_spec); 37*993b0882SAndroid Build Coastguard Worker if (!model) { 38*993b0882SAndroid Build Coastguard Worker return nullptr; 39*993b0882SAndroid Build Coastguard Worker } 40*993b0882SAndroid Build Coastguard Worker return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); 41*993b0882SAndroid Build Coastguard Worker } 42*993b0882SAndroid Build Coastguard Worker FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)43*993b0882SAndroid Build Coastguard Worker static std::unique_ptr<ModelExecutor> FromBuffer( 44*993b0882SAndroid 
Build Coastguard Worker const flatbuffers::Vector<uint8_t>* model_spec_buffer) { 45*993b0882SAndroid Build Coastguard Worker auto model = TfLiteModelFromBuffer(model_spec_buffer); 46*993b0882SAndroid Build Coastguard Worker if (!model) { 47*993b0882SAndroid Build Coastguard Worker return nullptr; 48*993b0882SAndroid Build Coastguard Worker } 49*993b0882SAndroid Build Coastguard Worker return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model))); 50*993b0882SAndroid Build Coastguard Worker } 51*993b0882SAndroid Build Coastguard Worker 52*993b0882SAndroid Build Coastguard Worker TensorView<float> ComputeLogits(const TensorView<float>& features, 53*993b0882SAndroid Build Coastguard Worker tflite::Interpreter* interpreter) const; 54*993b0882SAndroid Build Coastguard Worker 55*993b0882SAndroid Build Coastguard Worker protected: ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)56*993b0882SAndroid Build Coastguard Worker explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model) 57*993b0882SAndroid Build Coastguard Worker : TfLiteModelExecutor(std::move(model)) {} 58*993b0882SAndroid Build Coastguard Worker 59*993b0882SAndroid Build Coastguard Worker static constexpr int kInputIndexFeatures = 0; 60*993b0882SAndroid Build Coastguard Worker static constexpr int kOutputIndexLogits = 0; 61*993b0882SAndroid Build Coastguard Worker }; 62*993b0882SAndroid Build Coastguard Worker 63*993b0882SAndroid Build Coastguard Worker // Executor for embedding sparse features into a dense vector. 64*993b0882SAndroid Build Coastguard Worker class EmbeddingExecutor { 65*993b0882SAndroid Build Coastguard Worker public: ~EmbeddingExecutor()66*993b0882SAndroid Build Coastguard Worker virtual ~EmbeddingExecutor() {} 67*993b0882SAndroid Build Coastguard Worker 68*993b0882SAndroid Build Coastguard Worker // Embeds the sparse_features into a dense embedding and adds (+) it 69*993b0882SAndroid Build Coastguard Worker // element-wise to the dest vector. 
70*993b0882SAndroid Build Coastguard Worker virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, 71*993b0882SAndroid Build Coastguard Worker int dest_size) const = 0; 72*993b0882SAndroid Build Coastguard Worker 73*993b0882SAndroid Build Coastguard Worker // Returns true when the model is ready to be used, false otherwise. IsReady()74*993b0882SAndroid Build Coastguard Worker virtual bool IsReady() const { return true; } 75*993b0882SAndroid Build Coastguard Worker }; 76*993b0882SAndroid Build Coastguard Worker 77*993b0882SAndroid Build Coastguard Worker class TFLiteEmbeddingExecutor : public EmbeddingExecutor { 78*993b0882SAndroid Build Coastguard Worker public: 79*993b0882SAndroid Build Coastguard Worker static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer( 80*993b0882SAndroid Build Coastguard Worker const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size, 81*993b0882SAndroid Build Coastguard Worker int quantization_bits, 82*993b0882SAndroid Build Coastguard Worker const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr); 83*993b0882SAndroid Build Coastguard Worker 84*993b0882SAndroid Build Coastguard Worker // Embeds the sparse_features into a dense embedding and adds (+) it 85*993b0882SAndroid Build Coastguard Worker // element-wise to the dest vector. 86*993b0882SAndroid Build Coastguard Worker bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, 87*993b0882SAndroid Build Coastguard Worker int dest_size) const; 88*993b0882SAndroid Build Coastguard Worker 89*993b0882SAndroid Build Coastguard Worker // Auxiliary function for computing prefixes used in implementation of 90*993b0882SAndroid Build Coastguard Worker // efficient mask indexing data structure. 
91*993b0882SAndroid Build Coastguard Worker void ComputePrefixCounts(); 92*993b0882SAndroid Build Coastguard Worker 93*993b0882SAndroid Build Coastguard Worker // Function implementing mask indexing based on efficient data structure 94*993b0882SAndroid Build Coastguard Worker int PruneBucketId(int bucket_id) const; 95*993b0882SAndroid Build Coastguard Worker 96*993b0882SAndroid Build Coastguard Worker protected: 97*993b0882SAndroid Build Coastguard Worker explicit TFLiteEmbeddingExecutor( 98*993b0882SAndroid Build Coastguard Worker std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits, 99*993b0882SAndroid Build Coastguard Worker int num_buckets, int bytes_per_embedding, int output_embedding_size, 100*993b0882SAndroid Build Coastguard Worker const TfLiteTensor* scales, const TfLiteTensor* embeddings, 101*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::Interpreter> interpreter, 102*993b0882SAndroid Build Coastguard Worker const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr); 103*993b0882SAndroid Build Coastguard Worker 104*993b0882SAndroid Build Coastguard Worker std::unique_ptr<TfLiteModelExecutor> executor_; 105*993b0882SAndroid Build Coastguard Worker 106*993b0882SAndroid Build Coastguard Worker int quantization_bits_; 107*993b0882SAndroid Build Coastguard Worker int num_buckets_ = -1; 108*993b0882SAndroid Build Coastguard Worker int bytes_per_embedding_ = -1; 109*993b0882SAndroid Build Coastguard Worker int output_embedding_size_ = -1; 110*993b0882SAndroid Build Coastguard Worker const TfLiteTensor* scales_ = nullptr; 111*993b0882SAndroid Build Coastguard Worker const TfLiteTensor* embeddings_ = nullptr; 112*993b0882SAndroid Build Coastguard Worker 113*993b0882SAndroid Build Coastguard Worker // NOTE: This interpreter is used in a read-only way (as a storage for the 114*993b0882SAndroid Build Coastguard Worker // model params), thus is still thread-safe. 
115*993b0882SAndroid Build Coastguard Worker std::unique_ptr<tflite::Interpreter> interpreter_; 116*993b0882SAndroid Build Coastguard Worker 117*993b0882SAndroid Build Coastguard Worker std::vector<uint64> pruning_mask_; 118*993b0882SAndroid Build Coastguard Worker std::vector<uint16> prefix_counts_; 119*993b0882SAndroid Build Coastguard Worker int full_num_buckets_ = -1; 120*993b0882SAndroid Build Coastguard Worker 121*993b0882SAndroid Build Coastguard Worker // Index of row of embedding table corresponding to all pruned buckets. 122*993b0882SAndroid Build Coastguard Worker int pruned_row_bucket_id_ = -1; 123*993b0882SAndroid Build Coastguard Worker }; 124*993b0882SAndroid Build Coastguard Worker 125*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3 126*993b0882SAndroid Build Coastguard Worker 127*993b0882SAndroid Build Coastguard Worker #endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_ 128