/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <memory>
#include <string>
#include <vector>

#include "utils/base/statusor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/random/random.h"

#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-actions.h"
#endif
#include "actions/ngram-model.h"
#include "actions/tflite-sensitive-model.h"
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
#endif
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/distributions.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;

namespace {
const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
  flatbuffers::Verifier verifier(addr, size);
  if (VerifyActionsModelBuffer(verifier)) {
    return GetActionsModel(addr);
  } else {
    return nullptr;
  }
}

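// Reads a scalar field from an (optional) overlay table by its vtable field
// offset, falling back to `default_value` when the overlay is absent or does
// not set the field. This lets the triggering-preconditions overlay below
// selectively override individual model defaults.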
template <typename T>
T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
                 const T default_value) {
  if (values == nullptr) {
    return default_value;
  }
  return values->GetField<T>(field_offset, default_value);
}

// Returns number of (tail) messages of a conversation to consider.
int NumMessagesToConsider(const Conversation& conversation,
                          const int max_conversation_history_length) {
  return ((max_conversation_history_length < 0 ||
           conversation.messages.size() < max_conversation_history_length)
              ? conversation.messages.size()
              : max_conversation_history_length);
}

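// Truncates `inputs` to `max_length` elements, or pads it with `pad_value` up
// to that length, so that a model input always has a fixed size. For example,
// {1, 2, 3} with max_length = 5 and pad_value = 0 becomes {1, 2, 3, 0, 0}.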
template <typename T>
std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
                                           const int max_length,
                                           const T pad_value) {
  if (inputs.size() >= max_length) {
    return std::vector<T>(inputs.begin(), inputs.begin() + max_length);
  } else {
    std::vector<T> result;
    result.reserve(max_length);
    result.insert(result.begin(), inputs.begin(), inputs.end());
    result.insert(result.end(), max_length - inputs.size(), pad_value);
    return result;
  }
}

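// Feeds an additional model parameter held in a Variant to the interpreter,
// accepting either a std::vector<T> or a scalar T payload.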
template <typename T>
void SetVectorOrScalarAsModelInput(
    const int param_index, const Variant& param_value,
    tflite::Interpreter* interpreter,
    const std::unique_ptr<const TfLiteModelExecutor>& model_executor) {
  if (param_value.Has<std::vector<T>>()) {
    model_executor->SetInput<T>(
        param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter);
  } else if (param_value.Has<T>()) {
    model_executor->SetInput<T>(param_index, param_value.Value<T>(),
                                interpreter);
  } else {
    TC3_LOG(ERROR) << "Variant type error!";
  }
}
}  // namespace

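// The FromUnownedBuffer / FromScopedMmap / FromFileDescriptor / FromPath
// factories below differ only in how the model bytes are obtained and owned;
// all of them funnel into ValidateAndInitialize(). A minimal usage sketch
// (the model path is hypothetical):
//
//   std::unique_ptr<ActionsSuggestions> actions =
//       ActionsSuggestions::FromPath(
//           "/path/to/actions_model.fb", /*unilib=*/nullptr,
//           /*triggering_preconditions_overlay=*/"");
//   // actions == nullptr if loading or validation failed.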
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
    const uint8_t* buffer, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }
  actions->model_ = model;
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->owned_unilib_ = std::move(unilib);
  actions->unilib_ = actions->owned_unilib_.get();
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
  if (unilib != nullptr) {
    unilib_ = unilib;
  } else {
    owned_unilib_.reset(new UniLib);
    unilib_ = owned_unilib_.get();
  }
}

bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  // Gather annotation entities for the rules.
  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new MutableFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  // Initialize regular expressions model.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  regex_actions_.reset(
      new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
  if (!regex_actions_->InitializeRules(
          model_->rules(), model_->low_confidence_rules(),
          triggering_preconditions_overlay_, decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize regex rules.";
    return false;
  }

  // Setup grammar model.
  if (model_->rules() != nullptr &&
      model_->rules()->grammar_rules() != nullptr) {
    grammar_actions_.reset(new GrammarActions(
        unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
        model_->smart_reply_action_type()->str()));

    // Gather annotation entities for the grammars.
    if (auto annotation_nt = model_->rules()
                                 ->grammar_rules()
                                 ->rules()
                                 ->nonterminals()
                                 ->annotation_nt()) {
      for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
           *annotation_nt) {
        annotation_entity_types_.insert(entry->key()->str());
      }
    }
  }

#if !defined(TC3_DISABLE_LUA)
  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }
#endif  // TC3_DISABLE_LUA

  if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
            model_->ranking_options(), decompressor.get(),
            model_->smart_reply_action_type()->str()))) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache embedding of padding, start and end token.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create low confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    sensitive_model_ = NGramSensitiveModel::Create(
        unilib_, model_->low_confidence_ngram_model(),
        feature_processor_ == nullptr ? nullptr
                                      : feature_processor_->tokenizer());
    if (sensitive_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }
  if (model_->low_confidence_tflite_model() != nullptr) {
    sensitive_model_ =
        TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model());
    if (sensitive_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create TFLite sensitive model.";
      return false;
    }
  }

  return true;
}

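// Builds the effective triggering preconditions by overlaying the optional
// serialized TriggeringPreconditions buffer on top of the defaults from
// model_->preconditions(): each field is read from the overlay when present
// there and falls back to the model default otherwise.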
bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites.";
    return false;
  }
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}

bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}

std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
    const std::vector<std::string>& context) const {
  std::vector<std::vector<Token>> tokens;
  tokens.reserve(context.size());
  for (const std::string& message : context) {
    tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
  }
  return tokens;
}

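// Embeds the tokens of every message into the flat `embeddings` buffer, laid
// out as [num_messages x max_num_tokens_per_message x token embedding size]:
// over-long messages lose tokens from their beginning, short ones are padded
// at the end with the cached padding-token embedding.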
bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }
  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens, padding each message's tokens at the end so that every
  // message has the maximum number of tokens seen in the conversation.
  // If a maximum number of tokens per message is specified in the model
  // config, tokens at the beginning of a message are dropped if the message
  // exceeds that limit.
  for (int i = 0; i < num_messages; i++) {
    const int start =
        std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < tokens[i].size(); pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}

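// Variant of the above for models that consume one flat token sequence for
// the whole conversation: each message is wrapped in the cached start/end
// tokens, whole old messages (and, if needed, the front of the oldest
// surviving message) are trimmed to stay within max_num_total_tokens, and the
// sequence is padded up to min_num_total_tokens.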
bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    for (int pos = std::max(0, message_token_offset - 1);
         pos < tokens[i].size(); pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}

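// Resizes the interpreter's input tensors to the actual conversation
// dimensions (when the model spec requests dynamic resizing) and then
// allocates all tensors.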
bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  if (model_->tflite_model_spec()->resize_inputs()) {
    if (model_->tflite_model_spec()->input_context() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_user_id() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
          {conversation_length, 1});
    }
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
          {conversation_length, max_tokens, token_embedding_size_});
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()
                                    ->input_flattened_token_embeddings()],
          {1, total_token_count});
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}

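// Copies the conversation-derived features (context strings, user ids, time
// diffs, token embeddings) plus any caller-supplied model parameters from
// `options.model_parameters` into the interpreter's input tensors.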
bool ActionsSuggestions::SetupModelInput(
    const std::vector<std::string>& context, const std::vector<int>& user_ids,
    const std::vector<float>& time_diffs, const int num_suggestions,
    const ActionSuggestionOptions& options,
    tflite::Interpreter* interpreter) const {
  // Compute token embeddings.
  std::vector<std::vector<Token>> tokens;
  std::vector<float> token_embeddings;
  std::vector<float> flattened_token_embeddings;
  int max_tokens = 0;
  int total_token_count = 0;
  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    if (feature_processor_ == nullptr) {
      TC3_LOG(ERROR) << "No feature processor specified.";
      return false;
    }

    // Tokenize the messages in the conversation.
    tokens = Tokenize(context);
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
                                 &total_token_count)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
  }

  if (!AllocateInput(context.size(), max_tokens, total_token_count,
                     interpreter)) {
    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
    return false;
  }
  if (model_->tflite_model_spec()->input_context() >= 0) {
    if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
      model_executor_->SetInput<std::string>(
          model_->tflite_model_spec()->input_context(),
          PadOrTruncateToTargetLength(
              context, model_->tflite_model_spec()->input_length_to_pad(),
              std::string("")),
          interpreter);
    } else {
      model_executor_->SetInput<std::string>(
          model_->tflite_model_spec()->input_context(), context, interpreter);
    }
  }
  if (model_->tflite_model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_context_length(), context.size(),
        interpreter);
  }
  if (model_->tflite_model_spec()->input_user_id() >= 0) {
    if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
      model_executor_->SetInput<int>(
          model_->tflite_model_spec()->input_user_id(),
          PadOrTruncateToTargetLength(
              user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
          interpreter);
    } else {
      model_executor_->SetInput<int>(
          model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
    }
  }
  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
    std::vector<int> num_tokens_per_message(tokens.size());
    for (int i = 0; i < tokens.size(); i++) {
      num_tokens_per_message[i] = tokens[i].size();
    }
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_token_embeddings(), token_embeddings,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_flattened_token_embeddings(),
        flattened_token_embeddings, interpreter);
  }
  // Set up additional input parameters.
  if (const auto* input_name_index =
          model_->tflite_model_spec()->input_name_index()) {
    const std::unordered_map<std::string, Variant>& model_parameters =
        options.model_parameters;
    for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
         *input_name_index) {
      const std::string param_name = entry->key()->str();
      const int param_index = entry->value();
      const TfLiteType param_type =
          interpreter->tensor(interpreter->inputs()[param_index])->type;
      const auto param_value_it = model_parameters.find(param_name);
      const bool has_value = param_value_it != model_parameters.end();
      switch (param_type) {
        case kTfLiteFloat32:
          if (has_value) {
            SetVectorOrScalarAsModelInput<float>(param_index,
                                                 param_value_it->second,
                                                 interpreter, model_executor_);
          } else {
            model_executor_->SetInput<float>(param_index, kDefaultFloat,
                                             interpreter);
          }
          break;
        case kTfLiteInt32:
          if (has_value) {
            SetVectorOrScalarAsModelInput<int32_t>(
                param_index, param_value_it->second, interpreter,
                model_executor_);
          } else {
            model_executor_->SetInput<int32_t>(param_index, kDefaultInt,
                                               interpreter);
          }
          break;
        case kTfLiteInt64:
          model_executor_->SetInput<int64_t>(
              param_index,
              has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteUInt8:
          model_executor_->SetInput<uint8_t>(
              param_index,
              has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteInt8:
          model_executor_->SetInput<int8_t>(
              param_index,
              has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteBool:
          model_executor_->SetInput<bool>(
              param_index,
              has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
              interpreter);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
                         << param_name;
      }
    }
  }
  return true;
}

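// Converts raw text-reply model outputs into actions: empty replies, replies
// below min_reply_score_threshold, and blocklisted texts are dropped; a reply
// that appears in `concept_mappings` is replaced by one of its candidates.
// Note that absl::IntervalOpenOpen samples from the open interval
// (0, candidates_size), so candidate index 0 is never selected.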
void ActionsSuggestions::PopulateTextReplies(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const std::string& type, float priority_score,
    const absl::flat_hash_set<std::string>& blocklist,
    const absl::flat_hash_map<std::string, std::vector<std::string>>&
        concept_mappings,
    ActionsSuggestionsResponse* response) const {
  const std::vector<tflite::StringRef> replies =
      model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
  const TensorView<float> scores =
      model_executor_->OutputView<float>(score_index, interpreter);

  for (int i = 0; i < replies.size(); i++) {
    if (replies[i].len == 0) {
      continue;
    }
    const float score = scores.data()[i];
    if (score < preconditions_.min_reply_score_threshold) {
      continue;
    }
    std::string response_text(replies[i].str, replies[i].len);
    if (blocklist.contains(response_text)) {
      continue;
    }
    if (concept_mappings.contains(response_text)) {
      const int candidates_size = concept_mappings.at(response_text).size();
      const int candidate_index = absl::Uniform<int>(
          absl::IntervalOpenOpen, bit_gen_, 0, candidates_size);
      response_text = concept_mappings.at(response_text)[candidate_index];
    }

    response->actions.push_back({response_text, type, score, priority_score});
  }
}

void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
    const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                      : nullptr;
  FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
}

void ActionsSuggestions::PopulateIntentTriggering(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const ActionSuggestionSpec* task_spec,
    ActionsSuggestionsResponse* response) const {
  if (!task_spec || task_spec->type()->size() == 0) {
    TC3_LOG(ERROR)
        << "Task type for intent (action) triggering cannot be empty!";
    return;
  }
  const TensorView<bool> intent_prediction =
      model_executor_->OutputView<bool>(suggestion_index, interpreter);
  const TensorView<float> intent_scores =
      model_executor_->OutputView<float>(score_index, interpreter);
  // Two results, corresponding to the binary triggering case.
  TC3_CHECK_EQ(intent_prediction.size(), 2);
  TC3_CHECK_EQ(intent_scores.size(), 2);
  // We rely on the in-graph thresholding logic, so at this point the results
  // have already been ranked according to the threshold.
  const bool triggering = intent_prediction.data()[0];
  const float trigger_score = intent_scores.data()[0];

  if (triggering) {
    ActionSuggestion suggestion;
    std::unique_ptr<MutableFlatbuffer> entity_data =
        entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                        : nullptr;
    FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
    suggestion.score = trigger_score;
    response->actions.push_back(std::move(suggestion));
  }
}

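// Reads all model output heads: the triggering and sensitive-topic scores
// first (either of which can suppress further output), then smart replies,
// per-class action scores, and any multi-task prediction heads described by
// the prediction metadata.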
bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float> triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float> sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->is_sensitive = (response->sensitivity_score >
                              preconditions_.max_sensitive_topic_score);
  }

  // Suppress model outputs.
  if (response->is_sensitive) {
    return true;
  }

  // Read smart reply predictions.
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    absl::flat_hash_set<std::string> empty_blocklist;
    PopulateTextReplies(
        interpreter, model_->tflite_model_spec()->output_replies(),
        model_->tflite_model_spec()->output_replies_scores(),
        model_->smart_reply_action_type()->str(),
        /* priority_score */ 0.0, empty_blocklist, {}, response);
  }

  // Read actions suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->size(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default other category.
      if (!action_type->enabled()) {
        continue;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }

      // Create action from model output.
      ActionSuggestion suggestion;
      suggestion.type = action_type->name()->str();
      std::unique_ptr<MutableFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;
      FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
      suggestion.score = score;
      response->actions.push_back(std::move(suggestion));
    }
  }

  // Read multi-task predictions and construct the result properly.
  if (const auto* prediction_metadata =
          model_->tflite_model_spec()->prediction_metadata()) {
    for (const PredictionMetadata* metadata : *prediction_metadata) {
      const ActionSuggestionSpec* task_spec = metadata->task_spec();
      const int suggestions_index = metadata->output_suggestions();
      const int suggestions_scores_index =
          metadata->output_suggestions_scores();
      absl::flat_hash_set<std::string> response_text_blocklist;
      absl::flat_hash_map<std::string, std::vector<std::string>>
          concept_mappings;
      switch (metadata->prediction_type()) {
        case PredictionType_NEXT_MESSAGE_PREDICTION:
          if (!task_spec || task_spec->type()->size() == 0) {
            TC3_LOG(WARNING) << "Task type not provided, using default "
                                "smart_reply_action_type!";
          }
          if (task_spec) {
            if (task_spec->response_text_blocklist()) {
              for (const auto& val : *task_spec->response_text_blocklist()) {
                response_text_blocklist.insert(val->str());
              }
            }
            if (task_spec->concept_mappings()) {
              for (const auto& concept : *task_spec->concept_mappings()) {
                std::vector<std::string> candidates;
                for (const auto& candidate : *concept->candidates()) {
                  candidates.push_back(candidate->str());
                }
                concept_mappings[concept->concept_name()->str()] = candidates;
              }
            }
          }
          PopulateTextReplies(
              interpreter, suggestions_index, suggestions_scores_index,
              task_spec ? task_spec->type()->str()
                        : model_->smart_reply_action_type()->str(),
              task_spec ? task_spec->priority_score() : 0.0,
              response_text_blocklist, concept_mappings, response);
          break;
        case PredictionType_INTENT_TRIGGERING:
          PopulateIntentTriggering(interpreter, suggestions_index,
                                   suggestions_scores_index, task_spec,
                                   response);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported prediction type!";
          return false;
      }
    }
  }

  return true;
}

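// Runs the TFLite model over the last `num_messages` messages of the
// conversation: the sensitive model can short-circuit the whole call, then
// context strings, user ids and per-message time deltas (in seconds, clamped
// to >= 0) are fed to the interpreter and its output is parsed.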
bool ActionsSuggestions::SuggestActionsFromModel(
    const Conversation& conversation, const int num_messages,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response,
    std::unique_ptr<tflite::Interpreter>* interpreter) const {
  TC3_CHECK_LE(num_messages, conversation.messages.size());

  if (sensitive_model_ != nullptr &&
      sensitive_model_->EvalConversation(conversation, num_messages).first) {
    response->is_sensitive = true;
    return true;
  }

  if (!model_executor_) {
    return true;
  }
  *interpreter = model_executor_->CreateInterpreter();

  if (!*interpreter) {
    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
                      "actions suggestions model.";
    return false;
  }

  std::vector<std::string> context;
  std::vector<int> user_ids;
  std::vector<float> time_diffs;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);
  time_diffs.reserve(num_messages);

  // Gather last `num_messages` messages from the conversation.
  int64 last_message_reference_time_ms_utc = 0;
  const float second_in_ms = 1000;
  for (int i = conversation.messages.size() - num_messages;
       i < conversation.messages.size(); i++) {
    const ConversationMessage& message = conversation.messages[i];
    context.push_back(message.text);
    user_ids.push_back(message.user_id);

    float time_diff_secs = 0;
    if (message.reference_time_ms_utc != 0 &&
        last_message_reference_time_ms_utc != 0) {
      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
                                       last_message_reference_time_ms_utc) /
                                          second_in_ms);
    }
    if (message.reference_time_ms_utc != 0) {
      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
    }
    time_diffs.push_back(time_diff_secs);
  }

  if (!SetupModelInput(context, user_ids, time_diffs,
                       /*num_suggestions=*/model_->num_smart_replies(), options,
                       interpreter->get())) {
    TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model.";
    return false;
  }

  if ((*interpreter)->Invoke() != kTfLiteOk) {
    TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
    return false;
  }

  return ReadModelOutput(interpreter->get(), options, response);
}

Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
    const Conversation& conversation, const ActionSuggestionOptions& options,
    std::vector<ActionSuggestion>* actions) const {
  TC3_ASSIGN_OR_RETURN(
      std::vector<ActionSuggestion> new_actions,
      conversation_intent_detection_->SuggestActions(conversation, options));
  for (auto& action : new_actions) {
    actions->push_back(std::move(action));
  }
  return Status::OK;
}

AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
    const ConversationMessage& message) const {
  AnnotationOptions options;
  options.detected_text_language_tags = message.detected_text_language_tags;
  options.reference_time_ms_utc = message.reference_time_ms_utc;
  options.reference_timezone = message.reference_timezone;
  options.annotation_usecase =
      model_->annotation_actions_spec()->annotation_usecase();
  options.is_serialized_entity_data_enabled =
      model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
  options.entity_types = annotation_entity_types_;
  return options;
}

// Run annotator on the messages of a conversation.
Conversation ActionsSuggestions::AnnotateConversation(
    const Conversation& conversation, const Annotator* annotator) const {
  if (annotator == nullptr) {
    return conversation;
  }
  const int num_messages_grammar =
      ((model_->rules() && model_->rules()->grammar_rules() &&
        model_->rules()
            ->grammar_rules()
            ->rules()
            ->nonterminals()
            ->annotation_nt())
           ? 1
           : 0);
  const int num_messages_mapping =
      (model_->annotation_actions_spec()
           ? std::max(model_->annotation_actions_spec()
                          ->max_history_from_any_person(),
                      model_->annotation_actions_spec()
                          ->max_history_from_last_person())
           : 0);
  const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
  if (num_messages == 0) {
    // No annotations are used.
    return conversation;
  }
  Conversation annotated_conversation = conversation;
  for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
       i < num_messages && message_index >= 0; i++, message_index--) {
    ConversationMessage* message =
        &annotated_conversation.messages[message_index];
    if (message->annotations.empty()) {
      message->annotations = annotator->Annotate(
          message->text, AnnotationOptionsForMessage(*message));
      ConvertDatetimeToTime(&message->annotations);
    }
  }
  return annotated_conversation;
}

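// Creates actions from pre-computed message annotations, walking the
// conversation backwards and honoring the max_history_from_any_person /
// max_history_from_last_person limits from the annotation actions spec.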
SuggestActionsFromAnnotations(const Conversation & conversation,std::vector<ActionSuggestion> * actions) const1156 void ActionsSuggestions::SuggestActionsFromAnnotations(
1157     const Conversation& conversation,
1158     std::vector<ActionSuggestion>* actions) const {
1159   if (model_->annotation_actions_spec() == nullptr ||
1160       model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
1161       model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
1162     return;
1163   }
1164 
1165   // Create actions based on the annotations.
1166   const int max_from_any_person =
1167       model_->annotation_actions_spec()->max_history_from_any_person();
1168   const int max_from_last_person =
1169       model_->annotation_actions_spec()->max_history_from_last_person();
1170   const int last_person = conversation.messages.back().user_id;
1171 
1172   int num_messages_last_person = 0;
1173   int num_messages_any_person = 0;
1174   bool all_from_last_person = true;
1175   for (int message_index = conversation.messages.size() - 1; message_index >= 0;
1176        message_index--) {
1177     const ConversationMessage& message = conversation.messages[message_index];
1178     std::vector<AnnotatedSpan> annotations = message.annotations;
1179 
1180     // Update how many messages we have processed from the last person in the
1181     // conversation and from any person in the conversation.
1182     num_messages_any_person++;
1183     if (all_from_last_person && message.user_id == last_person) {
1184       num_messages_last_person++;
1185     } else {
1186       all_from_last_person = false;
1187     }
1188 
1189     if (num_messages_any_person > max_from_any_person &&
1190         (!all_from_last_person ||
1191          num_messages_last_person > max_from_last_person)) {
1192       break;
1193     }
1194 
1195     if (message.user_id == kLocalUserId) {
1196       if (model_->annotation_actions_spec()->only_until_last_sent()) {
1197         break;
1198       }
1199       if (!model_->annotation_actions_spec()->include_local_user_messages()) {
1200         continue;
1201       }
1202     }
1203 
1204     std::vector<ActionSuggestionAnnotation> action_annotations;
1205     action_annotations.reserve(annotations.size());
1206     for (const AnnotatedSpan& annotation : annotations) {
1207       if (annotation.classification.empty()) {
1208         continue;
1209       }
1210 
1211       const ClassificationResult& classification_result =
1212           annotation.classification[0];
1213 
1214       ActionSuggestionAnnotation action_annotation;
1215       action_annotation.span = {
1216           message_index, annotation.span,
1217           UTF8ToUnicodeText(message.text, /*do_copy=*/false)
1218               .UTF8Substring(annotation.span.first, annotation.span.second)};
1219       action_annotation.entity = classification_result;
1220       action_annotation.name = classification_result.collection;
1221       action_annotations.push_back(std::move(action_annotation));
1222     }
1223 
1224     if (model_->annotation_actions_spec()->deduplicate_annotations()) {
1225       // Create actions only for deduplicated annotations.
1226       for (const int annotation_id :
1227            DeduplicateAnnotations(action_annotations)) {
1228         SuggestActionsFromAnnotation(
1229             message_index, action_annotations[annotation_id], actions);
1230       }
1231     } else {
1232       // Create actions for all annotations.
1233       for (const ActionSuggestionAnnotation& annotation : action_annotations) {
1234         SuggestActionsFromAnnotation(message_index, annotation, actions);
1235       }
1236     }
1237   }
1238 }
1239 
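// Expands a single annotation into action suggestions: every
// annotation_mapping entry whose collection matches the annotation (and whose
// min_annotation_score is met) contributes one suggestion.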
1240 void ActionsSuggestions::SuggestActionsFromAnnotation(
1241     const int message_index, const ActionSuggestionAnnotation& annotation,
1242     std::vector<ActionSuggestion>* actions) const {
1243   for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
1244        *model_->annotation_actions_spec()->annotation_mapping()) {
1245     if (annotation.entity.collection ==
1246         mapping->annotation_collection()->str()) {
1247       if (annotation.entity.score < mapping->min_annotation_score()) {
1248         continue;
1249       }
1250 
1251       std::unique_ptr<MutableFlatbuffer> entity_data =
1252           entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
1253                                           : nullptr;
1254 
1255       // Set annotation text as (additional) entity data field.
1256       if (mapping->entity_field() != nullptr) {
1257         TC3_CHECK_NE(entity_data, nullptr);
1258 
1259         UnicodeText normalized_annotation_text =
1260             UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);
1261 
1262         // Apply normalization if specified.
1263         if (mapping->normalization_options() != nullptr) {
1264           normalized_annotation_text =
1265               NormalizeText(*unilib_, mapping->normalization_options(),
1266                             normalized_annotation_text);
1267         }
1268 
1269         entity_data->ParseAndSet(mapping->entity_field(),
1270                                  normalized_annotation_text.ToUTF8String());
1271       }
1272 
1273       ActionSuggestion suggestion;
1274       FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
1275       if (mapping->use_annotation_score()) {
1276         suggestion.score = annotation.entity.score;
1277       }
1278       suggestion.annotations = {annotation};
1279       actions->push_back(std::move(suggestion));
1280     }
1281   }
1282 }
1283 
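// Returns, for each distinct (annotation name, span text) pair, the index of
// the highest-scoring annotation; on ties the earliest one wins. The result
// follows the map's key order, not the input order.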
1284 std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
1285     const std::vector<ActionSuggestionAnnotation>& annotations) const {
1286   std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;
1287 
1288   for (int i = 0; i < annotations.size(); i++) {
1289     const std::pair<std::string, std::string> key = {annotations[i].name,
1290                                                      annotations[i].span.text};
1291     auto entry = deduplicated_annotations.find(key);
1292     if (entry != deduplicated_annotations.end()) {
1293       // Keep the annotation with the higher score.
1294       if (annotations[entry->second].entity.score <
1295           annotations[i].entity.score) {
1296         entry->second = i;
1297       }
1298       continue;
1299     }
1300     deduplicated_annotations.insert(entry, {key, i});
1301   }
1302 
1303   std::vector<int> result;
1304   result.reserve(deduplicated_annotations.size());
1305   for (const auto& key_and_annotation : deduplicated_annotations) {
1306     result.push_back(key_and_annotation.second);
1307   }
1308   return result;
1309 }
1310 
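// Runs the optional Lua actions script bundled with the model. Missing
// bytecode is not an error: the function suggests nothing and returns true.
// When TC3_DISABLE_LUA is defined, a no-op stub is compiled instead.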
1311 #if !defined(TC3_DISABLE_LUA)
1312 bool ActionsSuggestions::SuggestActionsFromLua(
1313     const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1314     const tflite::Interpreter* interpreter,
1315     const reflection::Schema* annotation_entity_data_schema,
1316     std::vector<ActionSuggestion>* actions) const {
1317   if (lua_bytecode_.empty()) {
1318     return true;
1319   }
1320 
1321   auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
1322       lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
1323       interpreter, entity_data_schema_, annotation_entity_data_schema);
1324   if (lua_actions == nullptr) {
1325     TC3_LOG(ERROR) << "Could not create lua actions.";
1326     return false;
1327   }
1328   return lua_actions->SuggestActions(actions);
1329 }
1330 #else
1331 bool ActionsSuggestions::SuggestActionsFromLua(
1332     const Conversation& conversation, const TfLiteModelExecutor* model_executor,
1333     const tflite::Interpreter* interpreter,
1334     const reflection::Schema* annotation_entity_data_schema,
1335     std::vector<ActionSuggestion>* actions) const {
1336   return true;
1337 }
1338 #endif
1339 
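// Orchestrates the suggestion pipeline. In order: annotate the conversation,
// suggest from annotations, grammar rules, input-length and locale
// preconditions, the low-confidence input check, the TfLite model (which also
// flags sensitive conversations), conversation intent detection (if
// initialized), the Lua script, regex rules, and finally low-confidence
// post-filtering of the collected actions.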
1340 bool ActionsSuggestions::GatherActionsSuggestions(
1341     const Conversation& conversation, const Annotator* annotator,
1342     const ActionSuggestionOptions& options,
1343     ActionsSuggestionsResponse* response) const {
1344   if (conversation.messages.empty()) {
1345     return true;
1346   }
1347 
1348   // Run annotator against messages.
1349   const Conversation annotated_conversation =
1350       AnnotateConversation(conversation, annotator);
1351 
1352   const int num_messages = NumMessagesToConsider(
1353       annotated_conversation, model_->max_conversation_history_length());
1354 
1355   if (num_messages <= 0) {
1356     TC3_LOG(INFO) << "No messages provided for actions suggestions.";
1357     return false;
1358   }
1359 
1360   SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
1361 
1362   if (grammar_actions_ != nullptr &&
1363       !grammar_actions_->SuggestActions(annotated_conversation,
1364                                         &response->actions)) {
1365     TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
1366     return false;
1367   }
1368 
1369   int input_text_length = 0;
1370   int num_matching_locales = 0;
1371   for (int i = annotated_conversation.messages.size() - num_messages;
1372        i < annotated_conversation.messages.size(); i++) {
1373     input_text_length += annotated_conversation.messages[i].text.length();
1374     std::vector<Locale> message_languages;
1375     if (!ParseLocales(
1376             annotated_conversation.messages[i].detected_text_language_tags,
1377             &message_languages)) {
1378       continue;
1379     }
1380     if (Locale::IsAnyLocaleSupported(
1381             message_languages, locales_,
1382             preconditions_.handle_unknown_locale_as_supported)) {
1383       ++num_matching_locales;
1384     }
1385   }
1386 
1387   // Bail out if we are provided with too few or too much input.
1388   if (input_text_length < preconditions_.min_input_length ||
1389       (preconditions_.max_input_length >= 0 &&
1390        input_text_length > preconditions_.max_input_length)) {
1391     TC3_LOG(INFO) << "Too much or not enough input for inference.";
1392     return response;
1393   }
1394 
1395   // Bail out if the text does not look like it can be handled by the model.
1396   const float matching_fraction =
1397       static_cast<float>(num_matching_locales) / num_messages;
1398   if (matching_fraction < preconditions_.min_locale_match_fraction) {
1399     TC3_LOG(INFO) << "Not enough locale matches.";
1400     response->output_filtered_locale_mismatch = true;
1401     return true;
1402   }
1403 
1404   std::vector<const UniLib::RegexPattern*> post_check_rules;
1405   if (preconditions_.suppress_on_low_confidence_input) {
1406     if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
1407                                              num_messages, &post_check_rules)) {
1408       response->output_filtered_low_confidence = true;
1409       return true;
1410     }
1411   }
1412 
1413   std::unique_ptr<tflite::Interpreter> interpreter;
1414   if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
1415                                response, &interpreter)) {
1416     TC3_LOG(ERROR) << "Could not run model.";
1417     return false;
1418   }
1419 
1420   // SuggestActionsFromModel also detects whether the conversation is
1421   // sensitive, using either the ngram model or the TfLite sensitive model.
1422   // Suppress all predictions if the conversation was deemed sensitive.
1423   if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
1424     return true;
1425   }
1426 
1427   if (conversation_intent_detection_) {
1428     // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
1429     auto actions = SuggestActionsFromConversationIntentDetection(
1430         annotated_conversation, options, &response->actions);
1431     if (!actions.ok()) {
1432       TC3_LOG(ERROR) << "Could not run conversation intent detection: "
1433                      << actions.error_message();
1434       return false;
1435     }
1436   }
1437 
1438   if (!SuggestActionsFromLua(
1439           annotated_conversation, model_executor_.get(), interpreter.get(),
1440           annotator != nullptr ? annotator->entity_data_schema() : nullptr,
1441           &response->actions)) {
1442     TC3_LOG(ERROR) << "Could not suggest actions from script.";
1443     return false;
1444   }
1445 
1446   if (!regex_actions_->SuggestActions(annotated_conversation,
1447                                       entity_data_builder_.get(),
1448                                       &response->actions)) {
1449     TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
1450     return false;
1451   }
1452 
1453   if (preconditions_.suppress_on_low_confidence_input &&
1454       !regex_actions_->FilterConfidenceOutput(post_check_rules,
1455                                               &response->actions)) {
1456     TC3_LOG(ERROR) << "Could not post-check actions.";
1457     return false;
1458   }
1459 
1460   return true;
1461 }
1462 
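// Entry point: validates the conversation (chronological order, message
// length, UTF-8 well-formedness), gathers suggestions from all sources, and
// ranks them. On failure, an empty list of actions is returned.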
1463 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1464     const Conversation& conversation, const Annotator* annotator,
1465     const ActionSuggestionOptions& options) const {
1466   ActionsSuggestionsResponse response;
1467 
1468   // Check that messages are sorted chronologically, most recent last.
1469   for (int i = 1; i < conversation.messages.size(); i++) {
1470     if (conversation.messages[i].reference_time_ms_utc <
1471         conversation.messages[i - 1].reference_time_ms_utc) {
1472       TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
1473       return response;
1474     }
1475   }
1476 
1477   // Check that messages are valid utf8.
1478   for (const ConversationMessage& message : conversation.messages) {
1479     if (message.text.size() > std::numeric_limits<int>::max()) {
1480       TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
1481       return {};
1482     }
1483 
1484     if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
1485             message.text.data(), message.text.size(), /*do_copy=*/false))) {
1486       TC3_LOG(ERROR) << "Not valid utf8 provided.";
1487       return response;
1488     }
1489   }
1490 
1491   if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
1492     TC3_LOG(ERROR) << "Could not gather actions suggestions.";
1493     response.actions.clear();
1494   } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
1495                                    annotator != nullptr
1496                                        ? annotator->entity_data_schema()
1497                                        : nullptr)) {
1498     TC3_LOG(ERROR) << "Could not rank actions.";
1499     response.actions.clear();
1500   }
1501   return response;
1502 }
1503 
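// Convenience overload without an annotator; annotation-based suggestions
// then depend on annotations already attached to the messages.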
1504 ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
1505     const Conversation& conversation,
1506     const ActionSuggestionOptions& options) const {
1507   return SuggestActions(conversation, /*annotator=*/nullptr, options);
1508 }
1509 
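// Illustrative usage; a minimal sketch assuming the FromPath() factory
// declared in actions-suggestions.h:
//
//   auto suggestions = ActionsSuggestions::FromPath("/path/to/actions_model");
//   Conversation conversation;
//   conversation.messages.push_back(
//       {/*user_id=*/1, /*text=*/"Do you want to grab lunch?"});
//   const ActionsSuggestionsResponse response =
//       suggestions->SuggestActions(conversation);
//   for (const ActionSuggestion& action : response.actions) {
//     TC3_LOG(INFO) << action.type << ": " << action.response_text;
//   }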
1510 const ActionsModel* ActionsSuggestions::model() const { return model_; }
1511 const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
1512   return entity_data_schema_;
1513 }
1514 
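// Interprets buffer as a serialized ActionsModel flatbuffer without copying;
// returns nullptr if the buffer is null or fails verification.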
1515 const ActionsModel* ViewActionsModel(const void* buffer, int size) {
1516   if (buffer == nullptr) {
1517     return nullptr;
1518   }
1519   return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
1520 }
1521 
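// Enables the optional conversation intent detection stage from a serialized
// config; on failure the detector stays unset and suggestion generation
// proceeds without it.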
1522 bool ActionsSuggestions::InitializeConversationIntentDetection(
1523     const std::string& serialized_config) {
1524   auto conversation_intent_detection =
1525       std::make_unique<ConversationIntentDetection>();
1526   if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
1527     TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
1528     return false;
1529   }
1530   conversation_intent_detection_ = std::move(conversation_intent_detection);
1531   return true;
1532 }
1533 
1534 }  // namespace libtextclassifier3
1535