xref: /aosp_15_r20/external/libtextclassifier/native/annotator/model-executor.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker // Contains classes that can execute different models/parts of a model.
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
20*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
21*993b0882SAndroid Build Coastguard Worker 
22*993b0882SAndroid Build Coastguard Worker #include <memory>
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
25*993b0882SAndroid Build Coastguard Worker #include "utils/base/logging.h"
26*993b0882SAndroid Build Coastguard Worker #include "utils/tensor-view.h"
27*993b0882SAndroid Build Coastguard Worker #include "utils/tflite-model-executor.h"
28*993b0882SAndroid Build Coastguard Worker 
29*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
30*993b0882SAndroid Build Coastguard Worker 
31*993b0882SAndroid Build Coastguard Worker // Executor for the text selection prediction and classification models.
32*993b0882SAndroid Build Coastguard Worker class ModelExecutor : public TfLiteModelExecutor {
33*993b0882SAndroid Build Coastguard Worker  public:
FromModelSpec(const tflite::Model * model_spec)34*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ModelExecutor> FromModelSpec(
35*993b0882SAndroid Build Coastguard Worker       const tflite::Model* model_spec) {
36*993b0882SAndroid Build Coastguard Worker     auto model = TfLiteModelFromModelSpec(model_spec);
37*993b0882SAndroid Build Coastguard Worker     if (!model) {
38*993b0882SAndroid Build Coastguard Worker       return nullptr;
39*993b0882SAndroid Build Coastguard Worker     }
40*993b0882SAndroid Build Coastguard Worker     return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
41*993b0882SAndroid Build Coastguard Worker   }
42*993b0882SAndroid Build Coastguard Worker 
FromBuffer(const flatbuffers::Vector<uint8_t> * model_spec_buffer)43*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ModelExecutor> FromBuffer(
44*993b0882SAndroid Build Coastguard Worker       const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
45*993b0882SAndroid Build Coastguard Worker     auto model = TfLiteModelFromBuffer(model_spec_buffer);
46*993b0882SAndroid Build Coastguard Worker     if (!model) {
47*993b0882SAndroid Build Coastguard Worker       return nullptr;
48*993b0882SAndroid Build Coastguard Worker     }
49*993b0882SAndroid Build Coastguard Worker     return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
50*993b0882SAndroid Build Coastguard Worker   }
51*993b0882SAndroid Build Coastguard Worker 
52*993b0882SAndroid Build Coastguard Worker   TensorView<float> ComputeLogits(const TensorView<float>& features,
53*993b0882SAndroid Build Coastguard Worker                                   tflite::Interpreter* interpreter) const;
54*993b0882SAndroid Build Coastguard Worker 
55*993b0882SAndroid Build Coastguard Worker  protected:
ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)56*993b0882SAndroid Build Coastguard Worker   explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
57*993b0882SAndroid Build Coastguard Worker       : TfLiteModelExecutor(std::move(model)) {}
58*993b0882SAndroid Build Coastguard Worker 
59*993b0882SAndroid Build Coastguard Worker   static constexpr int kInputIndexFeatures = 0;
60*993b0882SAndroid Build Coastguard Worker   static constexpr int kOutputIndexLogits = 0;
61*993b0882SAndroid Build Coastguard Worker };
62*993b0882SAndroid Build Coastguard Worker 
63*993b0882SAndroid Build Coastguard Worker // Executor for embedding sparse features into a dense vector.
64*993b0882SAndroid Build Coastguard Worker class EmbeddingExecutor {
65*993b0882SAndroid Build Coastguard Worker  public:
~EmbeddingExecutor()66*993b0882SAndroid Build Coastguard Worker   virtual ~EmbeddingExecutor() {}
67*993b0882SAndroid Build Coastguard Worker 
68*993b0882SAndroid Build Coastguard Worker   // Embeds the sparse_features into a dense embedding and adds (+) it
69*993b0882SAndroid Build Coastguard Worker   // element-wise to the dest vector.
70*993b0882SAndroid Build Coastguard Worker   virtual bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
71*993b0882SAndroid Build Coastguard Worker                             int dest_size) const = 0;
72*993b0882SAndroid Build Coastguard Worker 
73*993b0882SAndroid Build Coastguard Worker   // Returns true when the model is ready to be used, false otherwise.
IsReady()74*993b0882SAndroid Build Coastguard Worker   virtual bool IsReady() const { return true; }
75*993b0882SAndroid Build Coastguard Worker };
76*993b0882SAndroid Build Coastguard Worker 
77*993b0882SAndroid Build Coastguard Worker class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
78*993b0882SAndroid Build Coastguard Worker  public:
79*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
80*993b0882SAndroid Build Coastguard Worker       const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
81*993b0882SAndroid Build Coastguard Worker       int quantization_bits,
82*993b0882SAndroid Build Coastguard Worker       const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
83*993b0882SAndroid Build Coastguard Worker 
84*993b0882SAndroid Build Coastguard Worker   // Embeds the sparse_features into a dense embedding and adds (+) it
85*993b0882SAndroid Build Coastguard Worker   // element-wise to the dest vector.
86*993b0882SAndroid Build Coastguard Worker   bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
87*993b0882SAndroid Build Coastguard Worker                     int dest_size) const;
88*993b0882SAndroid Build Coastguard Worker 
89*993b0882SAndroid Build Coastguard Worker   // Auxiliary function for computing prefixes used in implementation of
90*993b0882SAndroid Build Coastguard Worker   // efficient mask indexing data structure.
91*993b0882SAndroid Build Coastguard Worker   void ComputePrefixCounts();
92*993b0882SAndroid Build Coastguard Worker 
93*993b0882SAndroid Build Coastguard Worker   // Function implementing mask indexing based on efficient data structure
94*993b0882SAndroid Build Coastguard Worker   int PruneBucketId(int bucket_id) const;
95*993b0882SAndroid Build Coastguard Worker 
96*993b0882SAndroid Build Coastguard Worker  protected:
97*993b0882SAndroid Build Coastguard Worker   explicit TFLiteEmbeddingExecutor(
98*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
99*993b0882SAndroid Build Coastguard Worker       int num_buckets, int bytes_per_embedding, int output_embedding_size,
100*993b0882SAndroid Build Coastguard Worker       const TfLiteTensor* scales, const TfLiteTensor* embeddings,
101*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<tflite::Interpreter> interpreter,
102*993b0882SAndroid Build Coastguard Worker       const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);
103*993b0882SAndroid Build Coastguard Worker 
104*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<TfLiteModelExecutor> executor_;
105*993b0882SAndroid Build Coastguard Worker 
106*993b0882SAndroid Build Coastguard Worker   int quantization_bits_;
107*993b0882SAndroid Build Coastguard Worker   int num_buckets_ = -1;
108*993b0882SAndroid Build Coastguard Worker   int bytes_per_embedding_ = -1;
109*993b0882SAndroid Build Coastguard Worker   int output_embedding_size_ = -1;
110*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor* scales_ = nullptr;
111*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor* embeddings_ = nullptr;
112*993b0882SAndroid Build Coastguard Worker 
113*993b0882SAndroid Build Coastguard Worker   // NOTE: This interpreter is used in a read-only way (as a storage for the
114*993b0882SAndroid Build Coastguard Worker   // model params), thus is still thread-safe.
115*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<tflite::Interpreter> interpreter_;
116*993b0882SAndroid Build Coastguard Worker 
117*993b0882SAndroid Build Coastguard Worker   std::vector<uint64> pruning_mask_;
118*993b0882SAndroid Build Coastguard Worker   std::vector<uint16> prefix_counts_;
119*993b0882SAndroid Build Coastguard Worker   int full_num_buckets_ = -1;
120*993b0882SAndroid Build Coastguard Worker 
121*993b0882SAndroid Build Coastguard Worker   // Index of row of embedding table corresponding to all pruned buckets.
122*993b0882SAndroid Build Coastguard Worker   int pruned_row_bucket_id_ = -1;
123*993b0882SAndroid Build Coastguard Worker };
124*993b0882SAndroid Build Coastguard Worker 
125*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
126*993b0882SAndroid Build Coastguard Worker 
127*993b0882SAndroid Build Coastguard Worker #endif  // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
128