xref: /aosp_15_r20/external/libtextclassifier/native/actions/actions-suggestions.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26 
27 #include "actions/actions_model_generated.h"
28 #include "actions/conversation_intent_detection/conversation-intent-detection.h"
29 #include "actions/feature-processor.h"
30 #include "actions/grammar-actions.h"
31 #include "actions/ranker.h"
32 #include "actions/regex-actions.h"
33 #include "actions/sensitive-classifier-base.h"
34 #include "actions/types.h"
35 #include "annotator/annotator.h"
36 #include "annotator/model-executor.h"
37 #include "annotator/types.h"
38 #include "utils/flatbuffers/flatbuffers.h"
39 #include "utils/flatbuffers/mutable.h"
40 #include "utils/i18n/locale.h"
41 #include "utils/memory/mmap.h"
42 #include "utils/tflite-model-executor.h"
43 #include "utils/utf8/unilib.h"
44 #include "utils/variant.h"
45 #include "utils/zlib/zlib.h"
46 #include "absl/container/flat_hash_map.h"
47 #include "absl/container/flat_hash_set.h"
48 #include "absl/random/random.h"
49 
50 namespace libtextclassifier3 {
51 
52 // Class for predicting actions following a conversation.
53 class ActionsSuggestions {
54  public:
55   // Creates ActionsSuggestions from given data buffer with model.
56   static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
57       const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
58       const std::string& triggering_preconditions_overlay = "");
59 
60   // Creates ActionsSuggestions from model in the ScopedMmap object and takes
61   // ownership of it.
62   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
63       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
64       const UniLib* unilib = nullptr,
65       const std::string& triggering_preconditions_overlay = "");
66   // Same as above, but also takes ownership of the unilib.
67   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
68       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
69       std::unique_ptr<UniLib> unilib,
70       const std::string& triggering_preconditions_overlay);
71 
72   // Creates ActionsSuggestions from model given as a file descriptor, offset
73   // and size in it. If offset and size are less than 0, will ignore them and
74   // will just use the fd.
75   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
76       const int fd, const int offset, const int size,
77       const UniLib* unilib = nullptr,
78       const std::string& triggering_preconditions_overlay = "");
79   // Same as above, but also takes ownership of the unilib.
80   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
81       const int fd, const int offset, const int size,
82       std::unique_ptr<UniLib> unilib,
83       const std::string& triggering_preconditions_overlay = "");
84 
85   // Creates ActionsSuggestions from model given as a file descriptor.
86   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
87       const int fd, const UniLib* unilib = nullptr,
88       const std::string& triggering_preconditions_overlay = "");
89   // Same as above, but also takes ownership of the unilib.
90   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
91       const int fd, std::unique_ptr<UniLib> unilib,
92       const std::string& triggering_preconditions_overlay);
93 
94   // Creates ActionsSuggestions from model given as a POSIX path.
95   static std::unique_ptr<ActionsSuggestions> FromPath(
96       const std::string& path, const UniLib* unilib = nullptr,
97       const std::string& triggering_preconditions_overlay = "");
98   // Same as above, but also takes ownership of unilib.
99   static std::unique_ptr<ActionsSuggestions> FromPath(
100       const std::string& path, std::unique_ptr<UniLib> unilib,
101       const std::string& triggering_preconditions_overlay);
102 
103   ActionsSuggestionsResponse SuggestActions(
104       const Conversation& conversation,
105       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
106 
107   ActionsSuggestionsResponse SuggestActions(
108       const Conversation& conversation, const Annotator* annotator,
109       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
110 
111   bool InitializeConversationIntentDetection(
112       const std::string& serialized_config);
113 
114   const ActionsModel* model() const;
115   const reflection::Schema* entity_data_schema() const;
116 
117   static constexpr int kLocalUserId = 0;
118 
119  protected:
120   // Exposed for testing.
121   bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
122 
123   // Embeds the tokens per message separately. Each message is padded to the
124   // maximum length with the padding token.
125   bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
126                              std::vector<float>* embeddings,
127                              int* max_num_tokens_per_message) const;
128 
129   // Concatenates the embedded message tokens - separated by start and end
130   // token between messages.
131   // If the total token count is greater than the maximum length, tokens at the
132   // start are dropped to fit into the limit.
133   // If the total token count is smaller than the minimum length, padding tokens
134   // are added to the end.
135   // Messages are assumed to be ordered by recency - most recent is last.
136   bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
137                              std::vector<float>* embeddings,
138                              int* total_token_count) const;
139 
140   const ActionsModel* model_;
141 
142   // Feature extractor and options.
143   std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
144   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
145   std::vector<float> embedded_padding_token_;
146   std::vector<float> embedded_start_token_;
147   std::vector<float> embedded_end_token_;
148   int token_embedding_size_;
149 
150  private:
151   // Checks that model contains all required fields, and initializes internal
152   // datastructures.
153   bool ValidateAndInitialize();
154 
155   void SetOrCreateUnilib(const UniLib* unilib);
156 
157   // Prepare preconditions.
158   // Takes values from flag provided data, but falls back to model provided
159   // values for parameters that are not explicitly provided.
160   bool InitializeTriggeringPreconditions();
161 
162   // Tokenizes a conversation and produces the tokens per message.
163   std::vector<std::vector<Token>> Tokenize(
164       const std::vector<std::string>& context) const;
165 
166   bool AllocateInput(const int conversation_length, const int max_tokens,
167                      const int total_token_count,
168                      tflite::Interpreter* interpreter) const;
169 
170   bool SetupModelInput(const std::vector<std::string>& context,
171                        const std::vector<int>& user_ids,
172                        const std::vector<float>& time_diffs,
173                        const int num_suggestions,
174                        const ActionSuggestionOptions& options,
175                        tflite::Interpreter* interpreter) const;
176 
177   void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
178                                             ActionSuggestion* suggestion) const;
179 
180   void PopulateTextReplies(
181       const tflite::Interpreter* interpreter, int suggestion_index,
182       int score_index, const std::string& type, float priority_score,
183       const absl::flat_hash_set<std::string>& blocklist,
184       const absl::flat_hash_map<std::string, std::vector<std::string>>&
185           concept_mappings,
186       ActionsSuggestionsResponse* response) const;
187 
188   void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
189                                 int suggestion_index, int score_index,
190                                 const ActionSuggestionSpec* task_spec,
191                                 ActionsSuggestionsResponse* response) const;
192 
193   bool ReadModelOutput(tflite::Interpreter* interpreter,
194                        const ActionSuggestionOptions& options,
195                        ActionsSuggestionsResponse* response) const;
196 
197   bool SuggestActionsFromModel(
198       const Conversation& conversation, const int num_messages,
199       const ActionSuggestionOptions& options,
200       ActionsSuggestionsResponse* response,
201       std::unique_ptr<tflite::Interpreter>* interpreter) const;
202 
203   Status SuggestActionsFromConversationIntentDetection(
204       const Conversation& conversation, const ActionSuggestionOptions& options,
205       std::vector<ActionSuggestion>* actions) const;
206 
207   // Creates options for annotation of a message.
208   AnnotationOptions AnnotationOptionsForMessage(
209       const ConversationMessage& message) const;
210 
211   void SuggestActionsFromAnnotations(
212       const Conversation& conversation,
213       std::vector<ActionSuggestion>* actions) const;
214 
215   void SuggestActionsFromAnnotation(
216       const int message_index, const ActionSuggestionAnnotation& annotation,
217       std::vector<ActionSuggestion>* actions) const;
218 
219   // Run annotator on the messages of a conversation.
220   Conversation AnnotateConversation(const Conversation& conversation,
221                                     const Annotator* annotator) const;
222 
223   // Deduplicates equivalent annotations - annotations that have the same type
224   // and same span text.
225   // Returns the indices of the deduplicated annotations.
226   std::vector<int> DeduplicateAnnotations(
227       const std::vector<ActionSuggestionAnnotation>& annotations) const;
228 
229   bool SuggestActionsFromLua(
230       const Conversation& conversation,
231       const TfLiteModelExecutor* model_executor,
232       const tflite::Interpreter* interpreter,
233       const reflection::Schema* annotation_entity_data_schema,
234       std::vector<ActionSuggestion>* actions) const;
235 
236   bool GatherActionsSuggestions(const Conversation& conversation,
237                                 const Annotator* annotator,
238                                 const ActionSuggestionOptions& options,
239                                 ActionsSuggestionsResponse* response) const;
240 
241   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
242 
243   // Tensorflow Lite models.
244   std::unique_ptr<const TfLiteModelExecutor> model_executor_;
245 
246   // Regex rules model.
247   std::unique_ptr<RegexActions> regex_actions_;
248 
249   // The grammar rules model.
250   std::unique_ptr<GrammarActions> grammar_actions_;
251 
252   std::unique_ptr<UniLib> owned_unilib_;
253   const UniLib* unilib_;
254 
255   // Locales supported by the model.
256   std::vector<Locale> locales_;
257 
258   // Annotation entities used by the model.
259   std::unordered_set<std::string> annotation_entity_types_;
260 
261   // Builder for creating extra data.
262   const reflection::Schema* entity_data_schema_;
263   std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
264   std::unique_ptr<ActionsSuggestionsRanker> ranker_;
265 
266   std::string lua_bytecode_;
267 
268   // Triggering preconditions. These parameters can be backed by the model and
269   // (partially) be provided by flags.
270   TriggeringPreconditionsT preconditions_;
271   std::string triggering_preconditions_overlay_buffer_;
272   const TriggeringPreconditions* triggering_preconditions_overlay_;
273 
274   // Low confidence input ngram classifier.
275   std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
276 
277   // Conversation intent detection model for additional actions.
278   std::unique_ptr<const ConversationIntentDetection>
279       conversation_intent_detection_;
280 
281   // Used for randomly selecting candidates.
282   mutable absl::BitGen bit_gen_;
283 };
284 
285 // Interprets the buffer as a Model flatbuffer and returns it for reading.
286 const ActionsModel* ViewActionsModel(const void* buffer, int size);
287 
288 // Opens model from given path and runs a function, passing the loaded Model
289 // flatbuffer as an argument.
290 //
291 // This is mainly useful if we don't want to pay the cost for the model
292 // initialization because we'll be only reading some flatbuffer values from the
293 // file.
294 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)295 ReturnType VisitActionsModel(const std::string& path, Func function) {
296   ScopedMmap mmap(path);
297   if (!mmap.handle().ok()) {
298     function(/*model=*/nullptr);
299   }
300   const ActionsModel* model =
301       ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
302   return function(model);
303 }
304 
// Names of the action types that can be suggested.
//
// Every accessor hands out a reference to a string that is allocated on first
// use and intentionally never freed, so the reference stays valid for the rest
// of the program.
class ActionsSuggestionsTypes {
 public:
  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& ViewCalendar() {
    static const std::string* const kType = new std::string("view_calendar");
    return *kType;
  }
  static const std::string& ViewMap() {
    static const std::string* const kType = new std::string("view_map");
    return *kType;
  }
  static const std::string& TrackFlight() {
    static const std::string* const kType = new std::string("track_flight");
    return *kType;
  }
  static const std::string& OpenUrl() {
    static const std::string* const kType = new std::string("open_url");
    return *kType;
  }
  static const std::string& SendSms() {
    static const std::string* const kType = new std::string("send_sms");
    return *kType;
  }
  static const std::string& CallPhone() {
    static const std::string* const kType = new std::string("call_phone");
    return *kType;
  }
  static const std::string& SendEmail() {
    static const std::string* const kType = new std::string("send_email");
    return *kType;
  }
  static const std::string& ShareLocation() {
    static const std::string* const kType = new std::string("share_location");
    return *kType;
  }
  static const std::string& CreateReminder() {
    static const std::string* const kType = new std::string("create_reminder");
    return *kType;
  }
  static const std::string& TextReply() {
    static const std::string* const kType = new std::string("text_reply");
    return *kType;
  }
  static const std::string& AddContact() {
    static const std::string* const kType = new std::string("add_contact");
    return *kType;
  }
  static const std::string& Copy() {
    static const std::string* const kType = new std::string("copy");
    return *kType;
  }
};
370 
371 }  // namespace libtextclassifier3
372 
373 #endif  // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
374