xref: /aosp_15_r20/external/libtextclassifier/native/actions/actions-suggestions.h (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #ifndef LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
18*993b0882SAndroid Build Coastguard Worker #define LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
19*993b0882SAndroid Build Coastguard Worker 
20*993b0882SAndroid Build Coastguard Worker #include <map>
21*993b0882SAndroid Build Coastguard Worker #include <memory>
22*993b0882SAndroid Build Coastguard Worker #include <string>
23*993b0882SAndroid Build Coastguard Worker #include <unordered_map>
24*993b0882SAndroid Build Coastguard Worker #include <unordered_set>
25*993b0882SAndroid Build Coastguard Worker #include <vector>
26*993b0882SAndroid Build Coastguard Worker 
27*993b0882SAndroid Build Coastguard Worker #include "actions/actions_model_generated.h"
28*993b0882SAndroid Build Coastguard Worker #include "actions/conversation_intent_detection/conversation-intent-detection.h"
29*993b0882SAndroid Build Coastguard Worker #include "actions/feature-processor.h"
30*993b0882SAndroid Build Coastguard Worker #include "actions/grammar-actions.h"
31*993b0882SAndroid Build Coastguard Worker #include "actions/ranker.h"
32*993b0882SAndroid Build Coastguard Worker #include "actions/regex-actions.h"
33*993b0882SAndroid Build Coastguard Worker #include "actions/sensitive-classifier-base.h"
34*993b0882SAndroid Build Coastguard Worker #include "actions/types.h"
35*993b0882SAndroid Build Coastguard Worker #include "annotator/annotator.h"
36*993b0882SAndroid Build Coastguard Worker #include "annotator/model-executor.h"
37*993b0882SAndroid Build Coastguard Worker #include "annotator/types.h"
38*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/flatbuffers.h"
39*993b0882SAndroid Build Coastguard Worker #include "utils/flatbuffers/mutable.h"
40*993b0882SAndroid Build Coastguard Worker #include "utils/i18n/locale.h"
41*993b0882SAndroid Build Coastguard Worker #include "utils/memory/mmap.h"
42*993b0882SAndroid Build Coastguard Worker #include "utils/tflite-model-executor.h"
43*993b0882SAndroid Build Coastguard Worker #include "utils/utf8/unilib.h"
44*993b0882SAndroid Build Coastguard Worker #include "utils/variant.h"
45*993b0882SAndroid Build Coastguard Worker #include "utils/zlib/zlib.h"
46*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_map.h"
47*993b0882SAndroid Build Coastguard Worker #include "absl/container/flat_hash_set.h"
48*993b0882SAndroid Build Coastguard Worker #include "absl/random/random.h"
49*993b0882SAndroid Build Coastguard Worker 
50*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
51*993b0882SAndroid Build Coastguard Worker 
52*993b0882SAndroid Build Coastguard Worker // Class for predicting actions following a conversation.
53*993b0882SAndroid Build Coastguard Worker class ActionsSuggestions {
54*993b0882SAndroid Build Coastguard Worker  public:
55*993b0882SAndroid Build Coastguard Worker   // Creates ActionsSuggestions from given data buffer with model.
56*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
57*993b0882SAndroid Build Coastguard Worker       const uint8_t* buffer, const int size, const UniLib* unilib = nullptr,
58*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay = "");
59*993b0882SAndroid Build Coastguard Worker 
60*993b0882SAndroid Build Coastguard Worker   // Creates ActionsSuggestions from model in the ScopedMmap object and takes
61*993b0882SAndroid Build Coastguard Worker   // ownership of it.
62*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
63*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
64*993b0882SAndroid Build Coastguard Worker       const UniLib* unilib = nullptr,
65*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay = "");
66*993b0882SAndroid Build Coastguard Worker   // Same as above, but also takes ownership of the unilib.
67*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
68*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
69*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<UniLib> unilib,
70*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay);
71*993b0882SAndroid Build Coastguard Worker 
72*993b0882SAndroid Build Coastguard Worker   // Creates ActionsSuggestions from model given as a file descriptor, offset
73*993b0882SAndroid Build Coastguard Worker   // and size in it. If offset and size are less than 0, will ignore them and
74*993b0882SAndroid Build Coastguard Worker   // will just use the fd.
75*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
76*993b0882SAndroid Build Coastguard Worker       const int fd, const int offset, const int size,
77*993b0882SAndroid Build Coastguard Worker       const UniLib* unilib = nullptr,
78*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay = "");
79*993b0882SAndroid Build Coastguard Worker   // Same as above, but also takes ownership of the unilib.
80*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
81*993b0882SAndroid Build Coastguard Worker       const int fd, const int offset, const int size,
82*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<UniLib> unilib,
83*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay = "");
84*993b0882SAndroid Build Coastguard Worker 
85*993b0882SAndroid Build Coastguard Worker   // Creates ActionsSuggestions from model given as a file descriptor.
86*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
87*993b0882SAndroid Build Coastguard Worker       const int fd, const UniLib* unilib = nullptr,
88*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay = "");
89*993b0882SAndroid Build Coastguard Worker   // Same as above, but also takes ownership of the unilib.
90*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
91*993b0882SAndroid Build Coastguard Worker       const int fd, std::unique_ptr<UniLib> unilib,
92*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay);
93*993b0882SAndroid Build Coastguard Worker 
94*993b0882SAndroid Build Coastguard Worker   // Creates ActionsSuggestions from model given as a POSIX path.
95*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromPath(
96*993b0882SAndroid Build Coastguard Worker       const std::string& path, const UniLib* unilib = nullptr,
97*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay = "");
98*993b0882SAndroid Build Coastguard Worker   // Same as above, but also takes ownership of unilib.
99*993b0882SAndroid Build Coastguard Worker   static std::unique_ptr<ActionsSuggestions> FromPath(
100*993b0882SAndroid Build Coastguard Worker       const std::string& path, std::unique_ptr<UniLib> unilib,
101*993b0882SAndroid Build Coastguard Worker       const std::string& triggering_preconditions_overlay);
102*993b0882SAndroid Build Coastguard Worker 
103*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse SuggestActions(
104*993b0882SAndroid Build Coastguard Worker       const Conversation& conversation,
105*993b0882SAndroid Build Coastguard Worker       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
106*993b0882SAndroid Build Coastguard Worker 
107*993b0882SAndroid Build Coastguard Worker   ActionsSuggestionsResponse SuggestActions(
108*993b0882SAndroid Build Coastguard Worker       const Conversation& conversation, const Annotator* annotator,
109*993b0882SAndroid Build Coastguard Worker       const ActionSuggestionOptions& options = ActionSuggestionOptions()) const;
110*993b0882SAndroid Build Coastguard Worker 
111*993b0882SAndroid Build Coastguard Worker   bool InitializeConversationIntentDetection(
112*993b0882SAndroid Build Coastguard Worker       const std::string& serialized_config);
113*993b0882SAndroid Build Coastguard Worker 
114*993b0882SAndroid Build Coastguard Worker   const ActionsModel* model() const;
115*993b0882SAndroid Build Coastguard Worker   const reflection::Schema* entity_data_schema() const;
116*993b0882SAndroid Build Coastguard Worker 
117*993b0882SAndroid Build Coastguard Worker   static constexpr int kLocalUserId = 0;
118*993b0882SAndroid Build Coastguard Worker 
119*993b0882SAndroid Build Coastguard Worker  protected:
120*993b0882SAndroid Build Coastguard Worker   // Exposed for testing.
121*993b0882SAndroid Build Coastguard Worker   bool EmbedTokenId(const int32 token_id, std::vector<float>* embedding) const;
122*993b0882SAndroid Build Coastguard Worker 
123*993b0882SAndroid Build Coastguard Worker   // Embeds the tokens per message separately. Each message is padded to the
124*993b0882SAndroid Build Coastguard Worker   // maximum length with the padding token.
125*993b0882SAndroid Build Coastguard Worker   bool EmbedTokensPerMessage(const std::vector<std::vector<Token>>& tokens,
126*993b0882SAndroid Build Coastguard Worker                              std::vector<float>* embeddings,
127*993b0882SAndroid Build Coastguard Worker                              int* max_num_tokens_per_message) const;
128*993b0882SAndroid Build Coastguard Worker 
129*993b0882SAndroid Build Coastguard Worker   // Concatenates the embedded message tokens - separated by start and end
130*993b0882SAndroid Build Coastguard Worker   // token between messages.
131*993b0882SAndroid Build Coastguard Worker   // If the total token count is greater than the maximum length, tokens at the
132*993b0882SAndroid Build Coastguard Worker   // start are dropped to fit into the limit.
133*993b0882SAndroid Build Coastguard Worker   // If the total token count is smaller than the minimum length, padding tokens
134*993b0882SAndroid Build Coastguard Worker   // are added to the end.
135*993b0882SAndroid Build Coastguard Worker   // Messages are assumed to be ordered by recency - most recent is last.
136*993b0882SAndroid Build Coastguard Worker   bool EmbedAndFlattenTokens(const std::vector<std::vector<Token>>& tokens,
137*993b0882SAndroid Build Coastguard Worker                              std::vector<float>* embeddings,
138*993b0882SAndroid Build Coastguard Worker                              int* total_token_count) const;
139*993b0882SAndroid Build Coastguard Worker 
140*993b0882SAndroid Build Coastguard Worker   const ActionsModel* model_;
141*993b0882SAndroid Build Coastguard Worker 
142*993b0882SAndroid Build Coastguard Worker   // Feature extractor and options.
143*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<const ActionsFeatureProcessor> feature_processor_;
144*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<const EmbeddingExecutor> embedding_executor_;
145*993b0882SAndroid Build Coastguard Worker   std::vector<float> embedded_padding_token_;
146*993b0882SAndroid Build Coastguard Worker   std::vector<float> embedded_start_token_;
147*993b0882SAndroid Build Coastguard Worker   std::vector<float> embedded_end_token_;
148*993b0882SAndroid Build Coastguard Worker   int token_embedding_size_;
149*993b0882SAndroid Build Coastguard Worker 
150*993b0882SAndroid Build Coastguard Worker  private:
151*993b0882SAndroid Build Coastguard Worker   // Checks that model contains all required fields, and initializes internal
152*993b0882SAndroid Build Coastguard Worker   // datastructures.
153*993b0882SAndroid Build Coastguard Worker   bool ValidateAndInitialize();
154*993b0882SAndroid Build Coastguard Worker 
155*993b0882SAndroid Build Coastguard Worker   void SetOrCreateUnilib(const UniLib* unilib);
156*993b0882SAndroid Build Coastguard Worker 
157*993b0882SAndroid Build Coastguard Worker   // Prepare preconditions.
158*993b0882SAndroid Build Coastguard Worker   // Takes values from flag provided data, but falls back to model provided
159*993b0882SAndroid Build Coastguard Worker   // values for parameters that are not explicitly provided.
160*993b0882SAndroid Build Coastguard Worker   bool InitializeTriggeringPreconditions();
161*993b0882SAndroid Build Coastguard Worker 
162*993b0882SAndroid Build Coastguard Worker   // Tokenizes a conversation and produces the tokens per message.
163*993b0882SAndroid Build Coastguard Worker   std::vector<std::vector<Token>> Tokenize(
164*993b0882SAndroid Build Coastguard Worker       const std::vector<std::string>& context) const;
165*993b0882SAndroid Build Coastguard Worker 
166*993b0882SAndroid Build Coastguard Worker   bool AllocateInput(const int conversation_length, const int max_tokens,
167*993b0882SAndroid Build Coastguard Worker                      const int total_token_count,
168*993b0882SAndroid Build Coastguard Worker                      tflite::Interpreter* interpreter) const;
169*993b0882SAndroid Build Coastguard Worker 
170*993b0882SAndroid Build Coastguard Worker   bool SetupModelInput(const std::vector<std::string>& context,
171*993b0882SAndroid Build Coastguard Worker                        const std::vector<int>& user_ids,
172*993b0882SAndroid Build Coastguard Worker                        const std::vector<float>& time_diffs,
173*993b0882SAndroid Build Coastguard Worker                        const int num_suggestions,
174*993b0882SAndroid Build Coastguard Worker                        const ActionSuggestionOptions& options,
175*993b0882SAndroid Build Coastguard Worker                        tflite::Interpreter* interpreter) const;
176*993b0882SAndroid Build Coastguard Worker 
177*993b0882SAndroid Build Coastguard Worker   void FillSuggestionFromSpecWithEntityData(const ActionSuggestionSpec* spec,
178*993b0882SAndroid Build Coastguard Worker                                             ActionSuggestion* suggestion) const;
179*993b0882SAndroid Build Coastguard Worker 
180*993b0882SAndroid Build Coastguard Worker   void PopulateTextReplies(
181*993b0882SAndroid Build Coastguard Worker       const tflite::Interpreter* interpreter, int suggestion_index,
182*993b0882SAndroid Build Coastguard Worker       int score_index, const std::string& type, float priority_score,
183*993b0882SAndroid Build Coastguard Worker       const absl::flat_hash_set<std::string>& blocklist,
184*993b0882SAndroid Build Coastguard Worker       const absl::flat_hash_map<std::string, std::vector<std::string>>&
185*993b0882SAndroid Build Coastguard Worker           concept_mappings,
186*993b0882SAndroid Build Coastguard Worker       ActionsSuggestionsResponse* response) const;
187*993b0882SAndroid Build Coastguard Worker 
188*993b0882SAndroid Build Coastguard Worker   void PopulateIntentTriggering(const tflite::Interpreter* interpreter,
189*993b0882SAndroid Build Coastguard Worker                                 int suggestion_index, int score_index,
190*993b0882SAndroid Build Coastguard Worker                                 const ActionSuggestionSpec* task_spec,
191*993b0882SAndroid Build Coastguard Worker                                 ActionsSuggestionsResponse* response) const;
192*993b0882SAndroid Build Coastguard Worker 
193*993b0882SAndroid Build Coastguard Worker   bool ReadModelOutput(tflite::Interpreter* interpreter,
194*993b0882SAndroid Build Coastguard Worker                        const ActionSuggestionOptions& options,
195*993b0882SAndroid Build Coastguard Worker                        ActionsSuggestionsResponse* response) const;
196*993b0882SAndroid Build Coastguard Worker 
197*993b0882SAndroid Build Coastguard Worker   bool SuggestActionsFromModel(
198*993b0882SAndroid Build Coastguard Worker       const Conversation& conversation, const int num_messages,
199*993b0882SAndroid Build Coastguard Worker       const ActionSuggestionOptions& options,
200*993b0882SAndroid Build Coastguard Worker       ActionsSuggestionsResponse* response,
201*993b0882SAndroid Build Coastguard Worker       std::unique_ptr<tflite::Interpreter>* interpreter) const;
202*993b0882SAndroid Build Coastguard Worker 
203*993b0882SAndroid Build Coastguard Worker   Status SuggestActionsFromConversationIntentDetection(
204*993b0882SAndroid Build Coastguard Worker       const Conversation& conversation, const ActionSuggestionOptions& options,
205*993b0882SAndroid Build Coastguard Worker       std::vector<ActionSuggestion>* actions) const;
206*993b0882SAndroid Build Coastguard Worker 
207*993b0882SAndroid Build Coastguard Worker   // Creates options for annotation of a message.
208*993b0882SAndroid Build Coastguard Worker   AnnotationOptions AnnotationOptionsForMessage(
209*993b0882SAndroid Build Coastguard Worker       const ConversationMessage& message) const;
210*993b0882SAndroid Build Coastguard Worker 
211*993b0882SAndroid Build Coastguard Worker   void SuggestActionsFromAnnotations(
212*993b0882SAndroid Build Coastguard Worker       const Conversation& conversation,
213*993b0882SAndroid Build Coastguard Worker       std::vector<ActionSuggestion>* actions) const;
214*993b0882SAndroid Build Coastguard Worker 
215*993b0882SAndroid Build Coastguard Worker   void SuggestActionsFromAnnotation(
216*993b0882SAndroid Build Coastguard Worker       const int message_index, const ActionSuggestionAnnotation& annotation,
217*993b0882SAndroid Build Coastguard Worker       std::vector<ActionSuggestion>* actions) const;
218*993b0882SAndroid Build Coastguard Worker 
219*993b0882SAndroid Build Coastguard Worker   // Run annotator on the messages of a conversation.
220*993b0882SAndroid Build Coastguard Worker   Conversation AnnotateConversation(const Conversation& conversation,
221*993b0882SAndroid Build Coastguard Worker                                     const Annotator* annotator) const;
222*993b0882SAndroid Build Coastguard Worker 
223*993b0882SAndroid Build Coastguard Worker   // Deduplicates equivalent annotations - annotations that have the same type
224*993b0882SAndroid Build Coastguard Worker   // and same span text.
225*993b0882SAndroid Build Coastguard Worker   // Returns the indices of the deduplicated annotations.
226*993b0882SAndroid Build Coastguard Worker   std::vector<int> DeduplicateAnnotations(
227*993b0882SAndroid Build Coastguard Worker       const std::vector<ActionSuggestionAnnotation>& annotations) const;
228*993b0882SAndroid Build Coastguard Worker 
229*993b0882SAndroid Build Coastguard Worker   bool SuggestActionsFromLua(
230*993b0882SAndroid Build Coastguard Worker       const Conversation& conversation,
231*993b0882SAndroid Build Coastguard Worker       const TfLiteModelExecutor* model_executor,
232*993b0882SAndroid Build Coastguard Worker       const tflite::Interpreter* interpreter,
233*993b0882SAndroid Build Coastguard Worker       const reflection::Schema* annotation_entity_data_schema,
234*993b0882SAndroid Build Coastguard Worker       std::vector<ActionSuggestion>* actions) const;
235*993b0882SAndroid Build Coastguard Worker 
236*993b0882SAndroid Build Coastguard Worker   bool GatherActionsSuggestions(const Conversation& conversation,
237*993b0882SAndroid Build Coastguard Worker                                 const Annotator* annotator,
238*993b0882SAndroid Build Coastguard Worker                                 const ActionSuggestionOptions& options,
239*993b0882SAndroid Build Coastguard Worker                                 ActionsSuggestionsResponse* response) const;
240*993b0882SAndroid Build Coastguard Worker 
241*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;
242*993b0882SAndroid Build Coastguard Worker 
243*993b0882SAndroid Build Coastguard Worker   // Tensorflow Lite models.
244*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<const TfLiteModelExecutor> model_executor_;
245*993b0882SAndroid Build Coastguard Worker 
246*993b0882SAndroid Build Coastguard Worker   // Regex rules model.
247*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<RegexActions> regex_actions_;
248*993b0882SAndroid Build Coastguard Worker 
249*993b0882SAndroid Build Coastguard Worker   // The grammar rules model.
250*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<GrammarActions> grammar_actions_;
251*993b0882SAndroid Build Coastguard Worker 
252*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<UniLib> owned_unilib_;
253*993b0882SAndroid Build Coastguard Worker   const UniLib* unilib_;
254*993b0882SAndroid Build Coastguard Worker 
255*993b0882SAndroid Build Coastguard Worker   // Locales supported by the model.
256*993b0882SAndroid Build Coastguard Worker   std::vector<Locale> locales_;
257*993b0882SAndroid Build Coastguard Worker 
258*993b0882SAndroid Build Coastguard Worker   // Annotation entities used by the model.
259*993b0882SAndroid Build Coastguard Worker   std::unordered_set<std::string> annotation_entity_types_;
260*993b0882SAndroid Build Coastguard Worker 
261*993b0882SAndroid Build Coastguard Worker   // Builder for creating extra data.
262*993b0882SAndroid Build Coastguard Worker   const reflection::Schema* entity_data_schema_;
263*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<MutableFlatbufferBuilder> entity_data_builder_;
264*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<ActionsSuggestionsRanker> ranker_;
265*993b0882SAndroid Build Coastguard Worker 
266*993b0882SAndroid Build Coastguard Worker   std::string lua_bytecode_;
267*993b0882SAndroid Build Coastguard Worker 
268*993b0882SAndroid Build Coastguard Worker   // Triggering preconditions. These parameters can be backed by the model and
269*993b0882SAndroid Build Coastguard Worker   // (partially) be provided by flags.
270*993b0882SAndroid Build Coastguard Worker   TriggeringPreconditionsT preconditions_;
271*993b0882SAndroid Build Coastguard Worker   std::string triggering_preconditions_overlay_buffer_;
272*993b0882SAndroid Build Coastguard Worker   const TriggeringPreconditions* triggering_preconditions_overlay_;
273*993b0882SAndroid Build Coastguard Worker 
274*993b0882SAndroid Build Coastguard Worker   // Low confidence input ngram classifier.
275*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<const SensitiveTopicModelBase> sensitive_model_;
276*993b0882SAndroid Build Coastguard Worker 
277*993b0882SAndroid Build Coastguard Worker   // Conversation intent detection model for additional actions.
278*993b0882SAndroid Build Coastguard Worker   std::unique_ptr<const ConversationIntentDetection>
279*993b0882SAndroid Build Coastguard Worker       conversation_intent_detection_;
280*993b0882SAndroid Build Coastguard Worker 
281*993b0882SAndroid Build Coastguard Worker   // Used for randomly selecting candidates.
282*993b0882SAndroid Build Coastguard Worker   mutable absl::BitGen bit_gen_;
283*993b0882SAndroid Build Coastguard Worker };
284*993b0882SAndroid Build Coastguard Worker 
// Interprets the buffer as a Model flatbuffer and returns it for reading.
//
// `buffer` points at the serialized model and `size` is its length.
// NOTE(review): `size` is presumably in bytes (VisitActionsModel passes the
// mmap's num_bytes()) and the result presumably aliases `buffer` rather than
// copying it, so the buffer must stay alive while the model is used - confirm
// against the definition, including what is returned for an invalid buffer.
const ActionsModel* ViewActionsModel(const void* buffer, int size);
287*993b0882SAndroid Build Coastguard Worker 
288*993b0882SAndroid Build Coastguard Worker // Opens model from given path and runs a function, passing the loaded Model
289*993b0882SAndroid Build Coastguard Worker // flatbuffer as an argument.
290*993b0882SAndroid Build Coastguard Worker //
291*993b0882SAndroid Build Coastguard Worker // This is mainly useful if we don't want to pay the cost for the model
292*993b0882SAndroid Build Coastguard Worker // initialization because we'll be only reading some flatbuffer values from the
293*993b0882SAndroid Build Coastguard Worker // file.
294*993b0882SAndroid Build Coastguard Worker template <typename ReturnType, typename Func>
VisitActionsModel(const std::string & path,Func function)295*993b0882SAndroid Build Coastguard Worker ReturnType VisitActionsModel(const std::string& path, Func function) {
296*993b0882SAndroid Build Coastguard Worker   ScopedMmap mmap(path);
297*993b0882SAndroid Build Coastguard Worker   if (!mmap.handle().ok()) {
298*993b0882SAndroid Build Coastguard Worker     function(/*model=*/nullptr);
299*993b0882SAndroid Build Coastguard Worker   }
300*993b0882SAndroid Build Coastguard Worker   const ActionsModel* model =
301*993b0882SAndroid Build Coastguard Worker       ViewActionsModel(mmap.handle().start(), mmap.handle().num_bytes());
302*993b0882SAndroid Build Coastguard Worker   return function(model);
303*993b0882SAndroid Build Coastguard Worker }
304*993b0882SAndroid Build Coastguard Worker 
// String identifiers of the action types that can be suggested.
//
// Each accessor hands out a reference to a heap-allocated string that is
// created on first use and intentionally never destroyed, so the reference
// remains valid for the rest of the program.
class ActionsSuggestionsTypes {
 public:
  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/ConversationActions.java
  static const std::string& ViewCalendar() {
    static const auto* const kType = new std::string("view_calendar");
    return *kType;
  }

  static const std::string& ViewMap() {
    static const auto* const kType = new std::string("view_map");
    return *kType;
  }

  static const std::string& TrackFlight() {
    static const auto* const kType = new std::string("track_flight");
    return *kType;
  }

  static const std::string& OpenUrl() {
    static const auto* const kType = new std::string("open_url");
    return *kType;
  }

  static const std::string& SendSms() {
    static const auto* const kType = new std::string("send_sms");
    return *kType;
  }

  static const std::string& CallPhone() {
    static const auto* const kType = new std::string("call_phone");
    return *kType;
  }

  static const std::string& SendEmail() {
    static const auto* const kType = new std::string("send_email");
    return *kType;
  }

  static const std::string& ShareLocation() {
    static const auto* const kType = new std::string("share_location");
    return *kType;
  }

  static const std::string& CreateReminder() {
    static const auto* const kType = new std::string("create_reminder");
    return *kType;
  }

  static const std::string& TextReply() {
    static const auto* const kType = new std::string("text_reply");
    return *kType;
  }

  static const std::string& AddContact() {
    static const auto* const kType = new std::string("add_contact");
    return *kType;
  }

  static const std::string& Copy() {
    static const auto* const kType = new std::string("copy");
    return *kType;
  }
};
370*993b0882SAndroid Build Coastguard Worker 
371*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
372*993b0882SAndroid Build Coastguard Worker 
373*993b0882SAndroid Build Coastguard Worker #endif  // LIBTEXTCLASSIFIER_ACTIONS_ACTIONS_SUGGESTIONS_H_
374