1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18 #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <unordered_map>
24 #include <unordered_set>
25 #include <vector>
26
27 #include "actions/actions_model_generated.h"
28 #include "actions/conversation_intent_detection/conversation-intent-detection.h"
29 #include "actions/feature-processor.h"
30 #include "actions/grammar-actions.h"
31 #include "actions/ranker.h"
32 #include "actions/regex-actions.h"
33 #include "actions/sensitive-classifier-base.h"
34 #include "actions/types.h"
35 #include "annotator/annotator.h"
36 #include "annotator/model-executor.h"
37 #include "annotator/types.h"
38 #include "utils/flatbuffers/flatbuffers.h"
39 #include "utils/flatbuffers/mutable.h"
40 #include "utils/i18n/locale.h"
41 #include "utils/memory/mmap.h"
42 #include "utils/tflite-model-executor.h"
43 #include "utils/utf8/unilib.h"
44 #include "utils/variant.h"
45 #include "utils/zlib/zlib.h"
46 #include "absl/container/flat_hash_map.h"
47 #include "absl/container/flat_hash_set.h"
48 #include "absl/random/random.h"
49
50 namespace libtextclassifier3 {
51
52 // Class for predicting actions following a conversation.
53 class ActionsSuggestions {
54 public:
55 // Creates ActionsSuggestions from given data buffer with model.
56 static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
57 const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
58 const std::string& triggering_preconditions_overlay = "");
59
60 // Creates ActionsSuggestions from model in the ScopedMmap object and takes
61 // ownership of it.
62 static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
63 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
64 const UniLib* unilib = nullptr,
65 const std::string& triggering_preconditions_overlay = "");
66 // Same as above, but also takes ownership of the unilib.
67 static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
68 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
69 std::unique_ptr<UniLib> unilib,
70 const std::string& triggering_preconditions_overlay);
71
72 // Creates ActionsSuggestions from model given as a file descriptor, offset
73 // and size in it. If offset and size are less than 0, will ignore them and
74 // will just use the fd.
75 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
76 const int fd, const int offset, const int size,
77 const UniLib* unilib = nullptr,
78 const std::string& triggering_preconditions_overlay = "");
79 // Same as above, but also takes ownership of the unilib.
80 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
81 const int fd, const int offset, const int size,
82 std::unique_ptr<UniLib> unilib,
83 const std::string& triggering_preconditions_overlay = "");
84
85 // Creates ActionsSuggestions from model given as a file descriptor.
86 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
87 const int fd, const UniLib* unilib = nullptr,
88 const std::string& triggering_preconditions_overlay = "");
89 // Same as above, but also takes ownership of the unilib.
90 static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
91 const int fd, std::unique_ptr<UniLib> unilib,
92 const std::string& triggering_preconditions_overlay);
93
94 // Creates ActionsSuggestions from model given as a POSIX path.
95 static std::unique_ptr<ActionsSuggestions> FromPath(
96 const std::string& path, const UniLib* unilib = nullptr,
97 const std::string& triggering_preconditions_overlay = "");
98 // Same as above, but also takes ownership of unilib.
99 static std::unique_ptr<ActionsSuggestions> FromPath(
100 const std::string& path, std::unique_ptr<UniLib> unilib,
101 const std::string& triggering_preconditions_overlay);
102
103 ActionsSuggestionsResponse SuggestActions(
104 const Conversation& conversation,
105 const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
106
107 ActionsSuggestionsResponse SuggestActions(
108 const Conversation& conversation, const Annotator* annotator,
109 const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
110
111 bool InitializeConversationIntentDetection(
112 const std::string& serialized_config);
113
114 const ActionsModel* model() const;
115 const reflection::Schema* entity_data_schema() const;
116
117 static constexpr int kLocalUserId = 0;
118
119 protected:
120 // Exposed for testing.
121 bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
122
123 // Embeds the tokens per message separately. Each message is padded to the
124 // maximum length with the padding token.
125 bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
126 std::vector<float>* embeddings,
127 int* max_num_tokens_per_message) const;
128
129 // Concatenates the embedded message tokens - separated by start and end
130 // token between messages.
131 // If the total token count is greater than the maximum length, tokens at the
132 // start are dropped to fit into the limit.
133 // If the total token count is smaller than the minimum length, padding tokens
134 // are added to the end.
135 // Messages are assumed to be ordered by recency - most recent is last.
136 bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
137 std::vector<float>* embeddings,
138 int* total_token_count) const;
139
140 const ActionsModel* model_;
141
142 // Feature extractor and options.
143 std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
144 std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
145 std::vector<float> embedded_padding_token_;
146 std::vector<float> embedded_start_token_;
147 std::vector<float> embedded_end_token_;
148 int token_embedding_size_;
149
150 private:
151 // Checks that model contains all required fields, and initializes internal
152 // datastructures.
153 bool ValidateAndInitialize();
154
155 void SetOrCreateUnilib(const UniLib* unilib);
156
157 // Prepare preconditions.
158 // Takes values from flag provided data, but falls back to model provided
159 // values for parameters that are not explicitly provided.
160 bool InitializeTriggeringPreconditions();
161
162 // Tokenizes a conversation and produces the tokens per message.
163 std::vector<std::vector<Token>> Tokenize(
164 const std::vector<std::string>& context) const;
165
166 bool AllocateInput(const int conversation_length, const int max_tokens,
167 const int total_token_count,
168 tflite::Interpreter* interpreter) const;
169
170 bool SetupModelInput(const std::vector<std::string>& context,
171 const std::vector<int>& user_ids,
172 const std::vector<float>& time_diffs,
173 const int num_suggestions,
174 const ActionSuggestionOptions& options,
175 tflite::Interpreter* interpreter) const;
176
177 void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
178 ActionSuggestion* suggestion) const;
179
180 void PopulateTextReplies(
181 const tflite::Interpreter* interpreter, int suggestion_index,
182 int score_index, const std::string& type, float priority_score,
183 const absl::flat_hash_set<std::string>& blocklist,
184 const absl::flat_hash_map<std::string, std::vector<std::string>>&
185 concept_mappings,
186 ActionsSuggestionsResponse* response) const;
187
188 void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
189 int suggestion_index, int score_index,
190 const ActionSuggestionSpec* task_spec,
191 ActionsSuggestionsResponse* response) const;
192
193 bool ReadModelOutput(tflite::Interpreter* interpreter,
194 const ActionSuggestionOptions& options,
195 ActionsSuggestionsResponse* response) const;
196
197 bool SuggestActionsFromModel(
198 const Conversation& conversation, const int num_messages,
199 const ActionSuggestionOptions& options,
200 ActionsSuggestionsResponse* response,
201 std::unique_ptr<tflite::Interpreter>* interpreter) const;
202
203 Status SuggestActionsFromConversationIntentDetection(
204 const Conversation& conversation, const ActionSuggestionOptions& options,
205 std::vector<ActionSuggestion>* actions) const;
206
207 // Creates options for annotation of a message.
208 AnnotationOptions AnnotationOptionsForMessage(
209 const ConversationMessage& message) const;
210
211 void SuggestActionsFromAnnotations(
212 const Conversation& conversation,
213 std::vector<ActionSuggestion>* actions) const;
214
215 void SuggestActionsFromAnnotation(
216 const int message_index, const ActionSuggestionAnnotation& annotation,
217 std::vector<ActionSuggestion>* actions) const;
218
219 // Run annotator on the messages of a conversation.
220 Conversation AnnotateConversation(const Conversation& conversation,
221 const Annotator* annotator) const;
222
223 // Deduplicates equivalent annotations - annotations that have the same type
224 // and same span text.
225 // Returns the indices of the deduplicated annotations.
226 std::vector<int> DeduplicateAnnotations(
227 const std::vector<ActionSuggestionAnnotation>& annotations) const;
228
229 bool SuggestActionsFromLua(
230 const Conversation& conversation,
231 const TfLiteModelExecutor* model_executor,
232 const tflite::Interpreter* interpreter,
233 const reflection::Schema* annotation_entity_data_schema,
234 std::vector<ActionSuggestion>* actions) const;
235
236 bool GatherActionsSuggestions(const Conversation& conversation,
237 const Annotator* annotator,
238 const ActionSuggestionOptions& options,
239 ActionsSuggestionsResponse* response) const;
240
241 std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
242
243 // Tensorflow Lite models.
244 std::unique_ptr<const TfLiteModelExecutor> model_executor_;
245
246 // Regex rules model.
247 std::unique_ptr<RegexActions> regex_actions_;
248
249 // The grammar rules model.
250 std::unique_ptr<GrammarActions> grammar_actions_;
251
252 std::unique_ptr<UniLib> owned_unilib_;
253 const UniLib* unilib_;
254
255 // Locales supported by the model.
256 std::vector<Locale> locales_;
257
258 // Annotation entities used by the model.
259 std::unordered_set<std::string> annotation_entity_types_;
260
261 // Builder for creating extra data.
262 const reflection::Schema* entity_data_schema_;
263 std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
264 std::unique_ptr<ActionsSuggestionsRanker> ranker_;
265
266 std::string lua_bytecode_;
267
268 // Triggering preconditions. These parameters can be backed by the model and
269 // (partially) be provided by flags.
270 TriggeringPreconditionsT preconditions_;
271 std::string triggering_preconditions_overlay_buffer_;
272 const TriggeringPreconditions* triggering_preconditions_overlay_;
273
274 // Low confidence input ngram classifier.
275 std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
276
277 // Conversation intent detection model for additional actions.
278 std::unique_ptr<const ConversationIntentDetection>
279 conversation_intent_detection_;
280
281 // Used for randomly selecting candidates.
282 mutable absl::BitGen bit_gen_;
283 };
284
285 // Interprets the buffer as a Model flatbuffer and returns it for reading.
286 const ActionsModel* ViewActionsModel(const void* buffer, int size);
287
288 // Opens model from given path and runs a function, passing the loaded Model
289 // flatbuffer as an argument.
290 //
291 // This is mainly useful if we don't want to pay the cost for the model
292 // initialization because we'll be only reading some flatbuffer values from the
293 // file.
294 template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)295 ReturnType VisitActionsModel(const std::string& path, Func function) {
296 ScopedMmap mmap(path);
297 if (!mmap.handle().ok()) {
298 function(/*model=*/nullptr);
299 }
300 const ActionsModel* model =
301 ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
302 return function(model);
303 }
304
// String constants identifying action types.
//
// Each accessor returns a reference to a std::string that is allocated on the
// first call and intentionally never deleted, so the same object is returned
// by every call and the reference stays valid for the rest of the program.
class ActionsSuggestionsTypes {
 public:
  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& ViewCalendar() {
    static const std::string* const kType = new std::string("view_calendar");
    return *kType;
  }
  static const std::string& ViewMap() {
    static const std::string* const kType = new std::string("view_map");
    return *kType;
  }
  static const std::string& TrackFlight() {
    static const std::string* const kType = new std::string("track_flight");
    return *kType;
  }
  static const std::string& OpenUrl() {
    static const std::string* const kType = new std::string("open_url");
    return *kType;
  }
  static const std::string& SendSms() {
    static const std::string* const kType = new std::string("send_sms");
    return *kType;
  }
  static const std::string& CallPhone() {
    static const std::string* const kType = new std::string("call_phone");
    return *kType;
  }
  static const std::string& SendEmail() {
    static const std::string* const kType = new std::string("send_email");
    return *kType;
  }
  static const std::string& ShareLocation() {
    static const std::string* const kType = new std::string("share_location");
    return *kType;
  }
  static const std::string& CreateReminder() {
    static const std::string* const kType = new std::string("create_reminder");
    return *kType;
  }
  static const std::string& TextReply() {
    static const std::string* const kType = new std::string("text_reply");
    return *kType;
  }
  static const std::string& AddContact() {
    static const std::string* const kType = new std::string("add_contact");
    return *kType;
  }
  static const std::string& Copy() {
    static const std::string* const kType = new std::string("copy");
    return *kType;
  }
};
370
371 } // namespace libtextclassifier3
372
373 #endif // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
374