/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "actions/actions-suggestions.h"

#include <algorithm>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "utils/base/statusor.h"
#include "absl/container/flat_hash_map.h"
#include "absl/random/random.h"

#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-actions.h"
#endif
#include "actions/ngram-model.h"
#include "actions/tflite-sensitive-model.h"
#include "actions/types.h"
#include "actions/utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "utils/base/logging.h"
#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
#endif
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/strings/split.h"
#include "utils/strings/stringpiece.h"
#include "utils/strings/utf8.h"
#include "utils/utf8/unicodetext.h"
#include "absl/container/flat_hash_set.h"
#include "absl/random/distributions.h"
#include "tensorflow/lite/string_util.h"

namespace libtextclassifier3 {

constexpr float kDefaultFloat = 0.0;
constexpr bool kDefaultBool = false;
constexpr int kDefaultInt = 1;

namespace {

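// Verifies that the given memory region contains a valid ActionsModel
// flatbuffer and returns a typed view into it, or nullptr if verification
// fails.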
const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) {
  flatbuffers::Verifier verifier(addr, size);
  if (VerifyActionsModelBuffer(verifier)) {
    return GetActionsModel(addr);
  } else {
    return nullptr;
  }
}

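// Reads a scalar field from a flatbuffer table by field offset, falling back
// to `default_value` when the table itself is absent.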
template <typename T>
T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset,
                 const T default_value) {
  if (values == nullptr) {
    return default_value;
  }
  return values->GetField<T>(field_offset, default_value);
}

// Returns the number of trailing messages of a conversation to consider.
int NumMessagesToConsider(const Conversation& conversation,
                          const int max_conversation_history_length) {
  return ((max_conversation_history_length < 0 ||
           conversation.messages.size() < max_conversation_history_length)
              ? conversation.messages.size()
              : max_conversation_history_length);
}

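// Pads `inputs` with `pad_value` up to `max_length`, or truncates it to the
// first `max_length` elements if it is longer.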
template <typename T>
std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs,
                                           const int max_length,
                                           const T pad_value) {
  if (inputs.size() >= max_length) {
    return std::vector<T>(inputs.begin(), inputs.begin() + max_length);
  } else {
    std::vector<T> result;
    result.reserve(max_length);
    result.insert(result.begin(), inputs.begin(), inputs.end());
    result.insert(result.end(), max_length - inputs.size(), pad_value);
    return result;
  }
}

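// Feeds an additional model parameter to the interpreter, accepting either a
// vector-valued or a scalar-valued variant of type T.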
template <typename T>
void SetVectorOrScalarAsModelInput(
    const int param_index, const Variant& param_value,
    tflite::Interpreter* interpreter,
    const std::unique_ptr<const TfLiteModelExecutor>& model_executor) {
  if (param_value.Has<std::vector<T>>()) {
    model_executor->SetInput<T>(
        param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter);
  } else if (param_value.Has<T>()) {
    model_executor->SetInput<T>(param_index, param_value.Value<T>(),
                                interpreter);
  } else {
    TC3_LOG(ERROR) << "Variant type error!";
  }
}
}  // namespace

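// The factory methods below construct an ActionsSuggestions instance from a
// model provided as an unowned buffer, a (possibly memory-mapped) file
// descriptor, or a file path. As an illustrative sketch (not part of this
// file), a client might load a model and request suggestions roughly like
// this, where `model_path` is a placeholder for a valid actions model file:
//
//   std::unique_ptr<ActionsSuggestions> actions =
//       ActionsSuggestions::FromPath(model_path, /*unilib=*/nullptr,
//                                    /*triggering_preconditions_overlay=*/"");
//   if (actions != nullptr) {
//     Conversation conversation;
//     ConversationMessage message;
//     message.user_id = 1;
//     message.text = "Where are you?";
//     conversation.messages.push_back(message);
//     const ActionsSuggestionsResponse response =
//         actions->SuggestActions(conversation, ActionSuggestionOptions());
//   }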
std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer(
    const uint8_t* buffer, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  const ActionsModel* model = LoadAndVerifyModel(buffer, size);
  if (model == nullptr) {
    return nullptr;
  }
  actions->model_ = model;
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->SetOrCreateUnilib(unilib);
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap(
    std::unique_ptr<libtextclassifier3::ScopedMmap> mmap,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  if (!mmap->handle().ok()) {
    TC3_VLOG(1) << "Mmap failed.";
    return nullptr;
  }
  const ActionsModel* model = LoadAndVerifyModel(
      reinterpret_cast<const uint8_t*>(mmap->handle().start()),
      mmap->handle().num_bytes());
  if (!model) {
    TC3_LOG(ERROR) << "Model verification failed.";
    return nullptr;
  }
  auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions());
  actions->model_ = model;
  actions->mmap_ = std::move(mmap);
  actions->owned_unilib_ = std::move(unilib);
  actions->unilib_ = actions->owned_unilib_.get();
  actions->triggering_preconditions_overlay_buffer_ =
      triggering_preconditions_overlay;
  if (!actions->ValidateAndInitialize()) {
    return nullptr;
  }
  return actions;
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const int offset, const int size,
    std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap;
  if (offset >= 0 && size >= 0) {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size));
  } else {
    mmap.reset(new libtextclassifier3::ScopedMmap(fd));
  }
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor(
    const int fd, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(fd));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, const UniLib* unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), unilib,
                        triggering_preconditions_overlay);
}

std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath(
    const std::string& path, std::unique_ptr<UniLib> unilib,
    const std::string& triggering_preconditions_overlay) {
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
      new libtextclassifier3::ScopedMmap(path));
  return FromScopedMmap(std::move(mmap), std::move(unilib),
                        triggering_preconditions_overlay);
}

void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) {
  if (unilib != nullptr) {
    unilib_ = unilib;
  } else {
    owned_unilib_.reset(new UniLib);
    unilib_ = owned_unilib_.get();
  }
}

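// Validates the loaded model and initializes all components configured in it:
// triggering preconditions, supported locales, the TFLite model executor, the
// entity data schema and builder, regex and grammar rules, precompiled Lua
// actions, the ranker, the feature processor with cached token embeddings,
// and the optional sensitive-topic model.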
bool ActionsSuggestions::ValidateAndInitialize() {
  if (model_ == nullptr) {
    TC3_LOG(ERROR) << "No model specified.";
    return false;
  }

  if (model_->smart_reply_action_type() == nullptr) {
    TC3_LOG(ERROR) << "No smart reply action type specified.";
    return false;
  }

  if (!InitializeTriggeringPreconditions()) {
    TC3_LOG(ERROR) << "Could not initialize preconditions.";
    return false;
  }

  if (model_->locales() &&
      !ParseLocales(model_->locales()->c_str(), &locales_)) {
    TC3_LOG(ERROR) << "Could not parse model supported locales.";
    return false;
  }

  if (model_->tflite_model_spec() != nullptr) {
    model_executor_ = TfLiteModelExecutor::FromBuffer(
        model_->tflite_model_spec()->tflite_model());
    if (!model_executor_) {
      TC3_LOG(ERROR) << "Could not initialize model executor.";
      return false;
    }
  }

  // Gather annotation entities for the rules.
  if (model_->annotation_actions_spec() != nullptr &&
      model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
    for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
         *model_->annotation_actions_spec()->annotation_mapping()) {
      annotation_entity_types_.insert(mapping->annotation_collection()->str());
    }
  }

  if (model_->actions_entity_data_schema() != nullptr) {
    entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>(
        model_->actions_entity_data_schema()->Data(),
        model_->actions_entity_data_schema()->size());
    if (entity_data_schema_ == nullptr) {
      TC3_LOG(ERROR) << "Could not load entity data schema data.";
      return false;
    }

    entity_data_builder_.reset(
        new MutableFlatbufferBuilder(entity_data_schema_));
  } else {
    entity_data_schema_ = nullptr;
  }

  // Initialize the regular expressions model.
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  regex_actions_.reset(
      new RegexActions(unilib_, model_->smart_reply_action_type()->str()));
  if (!regex_actions_->InitializeRules(
          model_->rules(), model_->low_confidence_rules(),
          triggering_preconditions_overlay_, decompressor.get())) {
    TC3_LOG(ERROR) << "Could not initialize regex rules.";
    return false;
  }

  // Set up the grammar model.
  if (model_->rules() != nullptr &&
      model_->rules()->grammar_rules() != nullptr) {
    grammar_actions_.reset(new GrammarActions(
        unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
        model_->smart_reply_action_type()->str()));

    // Gather annotation entities for the grammars.
    if (auto annotation_nt = model_->rules()
                                 ->grammar_rules()
                                 ->rules()
                                 ->nonterminals()
                                 ->annotation_nt()) {
      for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
           *annotation_nt) {
        annotation_entity_types_.insert(entry->key()->str());
      }
    }
  }

#if !defined(TC3_DISABLE_LUA)
  std::string actions_script;
  if (GetUncompressedString(model_->lua_actions_script(),
                            model_->compressed_lua_actions_script(),
                            decompressor.get(), &actions_script) &&
      !actions_script.empty()) {
    if (!Compile(actions_script, &lua_bytecode_)) {
      TC3_LOG(ERROR) << "Could not precompile lua actions snippet.";
      return false;
    }
  }
#endif  // !defined(TC3_DISABLE_LUA)

  if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
            model_->ranking_options(), decompressor.get(),
            model_->smart_reply_action_type()->str()))) {
    TC3_LOG(ERROR) << "Could not create an action suggestions ranker.";
    return false;
  }

  // Create the feature processor if specified.
  const ActionsTokenFeatureProcessorOptions* options =
      model_->feature_processor_options();
  if (options != nullptr) {
    if (options->tokenizer_options() == nullptr) {
      TC3_LOG(ERROR) << "No tokenizer options specified.";
      return false;
    }

    feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_));
    embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer(
        options->embedding_model(), options->embedding_size(),
        options->embedding_quantization_bits());

    if (embedding_executor_ == nullptr) {
      TC3_LOG(ERROR) << "Could not initialize embedding executor.";
      return false;
    }

    // Cache the embeddings of the padding, start and end tokens.
    if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) ||
        !EmbedTokenId(options->start_token_id(), &embedded_start_token_) ||
        !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) {
      TC3_LOG(ERROR) << "Could not precompute token embeddings.";
      return false;
    }
    token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
  }

  // Create the low confidence model if specified.
  if (model_->low_confidence_ngram_model() != nullptr) {
    sensitive_model_ = NGramSensitiveModel::Create(
        unilib_, model_->low_confidence_ngram_model(),
        feature_processor_ == nullptr ? nullptr
                                      : feature_processor_->tokenizer());
    if (sensitive_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create ngram linear regression model.";
      return false;
    }
  }
  if (model_->low_confidence_tflite_model() != nullptr) {
    sensitive_model_ =
        TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model());
    if (sensitive_model_ == nullptr) {
      TC3_LOG(ERROR) << "Could not create TFLite sensitive model.";
      return false;
    }
  }

  return true;
}

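// Merges the triggering preconditions: values from the overlay buffer, when
// present, override the defaults stored in the model.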
bool ActionsSuggestions::InitializeTriggeringPreconditions() {
  triggering_preconditions_overlay_ =
      LoadAndVerifyFlatbuffer<TriggeringPreconditions>(
          triggering_preconditions_overlay_buffer_);

  if (triggering_preconditions_overlay_ == nullptr &&
      !triggering_preconditions_overlay_buffer_.empty()) {
    TC3_LOG(ERROR) << "Could not load triggering preconditions overlay.";
    return false;
  }
  const flatbuffers::Table* overlay =
      reinterpret_cast<const flatbuffers::Table*>(
          triggering_preconditions_overlay_);
  const TriggeringPreconditions* defaults = model_->preconditions();
  if (defaults == nullptr) {
    TC3_LOG(ERROR) << "No triggering conditions specified.";
    return false;
  }

  preconditions_.min_smart_reply_triggering_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE,
      defaults->min_smart_reply_triggering_score());
  preconditions_.max_sensitive_topic_score = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE,
      defaults->max_sensitive_topic_score());
  preconditions_.suppress_on_sensitive_topic = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC,
      defaults->suppress_on_sensitive_topic());
  preconditions_.min_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH,
                     defaults->min_input_length());
  preconditions_.max_input_length =
      ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH,
                     defaults->max_input_length());
  preconditions_.min_locale_match_fraction = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION,
      defaults->min_locale_match_fraction());
  preconditions_.handle_missing_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED,
      defaults->handle_missing_locale_as_supported());
  preconditions_.handle_unknown_locale_as_supported = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED,
      defaults->handle_unknown_locale_as_supported());
  preconditions_.suppress_on_low_confidence_input = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT,
      defaults->suppress_on_low_confidence_input());
  preconditions_.min_reply_score_threshold = ValueOrDefault(
      overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD,
      defaults->min_reply_score_threshold());

  return true;
}

bool ActionsSuggestions::EmbedTokenId(const int32 token_id,
                                      std::vector<float>* embedding) const {
  return feature_processor_->AppendFeatures(
      {token_id},
      /*dense_features=*/{}, embedding_executor_.get(), embedding);
}

std::vector<std::vector<Token>> ActionsSuggestions::Tokenize(
    const std::vector<std::string>& context) const {
  std::vector<std::vector<Token>> tokens;
  tokens.reserve(context.size());
  for (const std::string& message : context) {
    tokens.push_back(feature_processor_->tokenizer()->Tokenize(message));
  }
  return tokens;
}

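// Embeds the tokens of each message and pads every message to the maximum
// number of tokens of any message in the conversation (subject to the
// configured minimum and maximum), appending the result to `embeddings`.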
bool ActionsSuggestions::EmbedTokensPerMessage(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* max_num_tokens_per_message) const {
  const int num_messages = tokens.size();
  *max_num_tokens_per_message = 0;
  for (int i = 0; i < num_messages; i++) {
    const int num_message_tokens = tokens[i].size();
    if (num_message_tokens > *max_num_tokens_per_message) {
      *max_num_tokens_per_message = num_message_tokens;
    }
  }

  if (model_->feature_processor_options()->min_num_tokens_per_message() >
      *max_num_tokens_per_message) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->min_num_tokens_per_message();
  }
  if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 &&
      *max_num_tokens_per_message >
          model_->feature_processor_options()->max_num_tokens_per_message()) {
    *max_num_tokens_per_message =
        model_->feature_processor_options()->max_num_tokens_per_message();
  }

  // Embed all tokens and pad each message to the maximum number of tokens of
  // any message in the conversation.
  // If a maximum number of tokens is specified in the model config, tokens at
  // the beginning of a message are dropped if they don't fit within the limit.
  for (int i = 0; i < num_messages; i++) {
    const int start =
        std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0);
    for (int pos = start; pos < tokens[i].size(); pos++) {
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }
    // Add padding.
    for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) {
      embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                         embedded_padding_token_.end());
    }
  }

  return true;
}

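// Embeds the tokens of all messages into a single flat sequence, wrapping
// each message in start/end tokens. If a total token budget is configured,
// trimming happens at the start of the conversation; if a minimum is
// configured, the sequence is padded at the end.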
bool ActionsSuggestions::EmbedAndFlattenTokens(
    const std::vector<std::vector<Token>>& tokens,
    std::vector<float>* embeddings, int* total_token_count) const {
  const int num_messages = tokens.size();
  int start_message = 0;
  int message_token_offset = 0;

  // If a maximum model input length is specified, we need to check how
  // much we need to trim at the start.
  const int max_num_total_tokens =
      model_->feature_processor_options()->max_num_total_tokens();
  if (max_num_total_tokens > 0) {
    int total_tokens = 0;
    start_message = num_messages - 1;
    for (; start_message >= 0; start_message--) {
      // Tokens of the message + start and end token.
      const int num_message_tokens = tokens[start_message].size() + 2;
      total_tokens += num_message_tokens;

      // Check whether we exhausted the budget.
      if (total_tokens >= max_num_total_tokens) {
        message_token_offset = total_tokens - max_num_total_tokens;
        break;
      }
    }
  }

  // Add embeddings.
  *total_token_count = 0;
  for (int i = start_message; i < num_messages; i++) {
    if (message_token_offset == 0) {
      ++(*total_token_count);
      // Add `start message` token.
      embeddings->insert(embeddings->end(), embedded_start_token_.begin(),
                         embedded_start_token_.end());
    }

    for (int pos = std::max(0, message_token_offset - 1);
         pos < tokens[i].size(); pos++) {
      ++(*total_token_count);
      if (!feature_processor_->AppendTokenFeatures(
              tokens[i][pos], embedding_executor_.get(), embeddings)) {
        TC3_LOG(ERROR) << "Could not run token feature extractor.";
        return false;
      }
    }

    // Add `end message` token.
    ++(*total_token_count);
    embeddings->insert(embeddings->end(), embedded_end_token_.begin(),
                       embedded_end_token_.end());

    // Reset for the subsequent messages.
    message_token_offset = 0;
  }

  // Add optional padding.
  const int min_num_total_tokens =
      model_->feature_processor_options()->min_num_total_tokens();
  for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) {
    embeddings->insert(embeddings->end(), embedded_padding_token_.begin(),
                       embedded_padding_token_.end());
  }

  return true;
}

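// If the model spec requests it, resizes the interpreter's input tensors to
// match the conversation dimensions and token counts, then allocates all
// tensors.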
bool ActionsSuggestions::AllocateInput(const int conversation_length,
                                       const int max_tokens,
                                       const int total_token_count,
                                       tflite::Interpreter* interpreter) const {
  if (model_->tflite_model_spec()->resize_inputs()) {
    if (model_->tflite_model_spec()->input_context() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_context()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_user_id() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()->input_user_id()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_time_diffs()],
          {1, conversation_length});
    }
    if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_num_tokens()],
          {conversation_length, 1});
    }
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter
              ->inputs()[model_->tflite_model_spec()->input_token_embeddings()],
          {conversation_length, max_tokens, token_embedding_size_});
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      interpreter->ResizeInputTensor(
          interpreter->inputs()[model_->tflite_model_spec()
                                    ->input_flattened_token_embeddings()],
          {1, total_token_count});
    }
  }

  return interpreter->AllocateTensors() == kTfLiteOk;
}

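// Computes token embeddings as required by the model spec, resizes and
// allocates the input tensors, and feeds the conversation (context, user ids,
// time diffs, token features) plus any additional model parameters from
// `options` into the interpreter.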
bool ActionsSuggestions::SetupModelInput(
    const std::vector<std::string>& context, const std::vector<int>& user_ids,
    const std::vector<float>& time_diffs, const int num_suggestions,
    const ActionSuggestionOptions& options,
    tflite::Interpreter* interpreter) const {
  // Compute token embeddings.
  std::vector<std::vector<Token>> tokens;
  std::vector<float> token_embeddings;
  std::vector<float> flattened_token_embeddings;
  int max_tokens = 0;
  int total_token_count = 0;
  if (model_->tflite_model_spec()->input_num_tokens() >= 0 ||
      model_->tflite_model_spec()->input_token_embeddings() >= 0 ||
      model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    if (feature_processor_ == nullptr) {
      TC3_LOG(ERROR) << "No feature processor specified.";
      return false;
    }

    // Tokenize the messages in the conversation.
    tokens = Tokenize(context);
    if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
      if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
    if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
      if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings,
                                 &total_token_count)) {
        TC3_LOG(ERROR) << "Could not extract token features.";
        return false;
      }
    }
  }

  if (!AllocateInput(context.size(), max_tokens, total_token_count,
                     interpreter)) {
    TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed.";
    return false;
  }
  if (model_->tflite_model_spec()->input_context() >= 0) {
    if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
      model_executor_->SetInput<std::string>(
          model_->tflite_model_spec()->input_context(),
          PadOrTruncateToTargetLength(
              context, model_->tflite_model_spec()->input_length_to_pad(),
              std::string("")),
          interpreter);
    } else {
      model_executor_->SetInput<std::string>(
          model_->tflite_model_spec()->input_context(), context, interpreter);
    }
  }
  if (model_->tflite_model_spec()->input_context_length() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_context_length(), context.size(),
        interpreter);
  }
  if (model_->tflite_model_spec()->input_user_id() >= 0) {
    if (model_->tflite_model_spec()->input_length_to_pad() > 0) {
      model_executor_->SetInput<int>(
          model_->tflite_model_spec()->input_user_id(),
          PadOrTruncateToTargetLength(
              user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0),
          interpreter);
    } else {
      model_executor_->SetInput<int>(
          model_->tflite_model_spec()->input_user_id(), user_ids, interpreter);
    }
  }
  if (model_->tflite_model_spec()->input_num_suggestions() >= 0) {
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_suggestions(), num_suggestions,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_time_diffs() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_time_diffs(), time_diffs,
        interpreter);
  }
  if (model_->tflite_model_spec()->input_num_tokens() >= 0) {
    std::vector<int> num_tokens_per_message(tokens.size());
    for (int i = 0; i < tokens.size(); i++) {
      num_tokens_per_message[i] = tokens[i].size();
    }
    model_executor_->SetInput<int>(
        model_->tflite_model_spec()->input_num_tokens(),
        num_tokens_per_message, interpreter);
  }
  if (model_->tflite_model_spec()->input_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_token_embeddings(),
        token_embeddings, interpreter);
  }
  if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) {
    model_executor_->SetInput<float>(
        model_->tflite_model_spec()->input_flattened_token_embeddings(),
        flattened_token_embeddings, interpreter);
  }
  // Set up additional input parameters.
  if (const auto* input_name_index =
          model_->tflite_model_spec()->input_name_index()) {
    const std::unordered_map<std::string, Variant>& model_parameters =
        options.model_parameters;
    for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry :
         *input_name_index) {
      const std::string param_name = entry->key()->str();
      const int param_index = entry->value();
      const TfLiteType param_type =
          interpreter->tensor(interpreter->inputs()[param_index])->type;
      const auto param_value_it = model_parameters.find(param_name);
      const bool has_value = param_value_it != model_parameters.end();
      switch (param_type) {
        case kTfLiteFloat32:
          if (has_value) {
            SetVectorOrScalarAsModelInput<float>(param_index,
                                                 param_value_it->second,
                                                 interpreter, model_executor_);
          } else {
            model_executor_->SetInput<float>(param_index, kDefaultFloat,
                                             interpreter);
          }
          break;
        case kTfLiteInt32:
          if (has_value) {
            SetVectorOrScalarAsModelInput<int32_t>(
                param_index, param_value_it->second, interpreter,
                model_executor_);
          } else {
            model_executor_->SetInput<int32_t>(param_index, kDefaultInt,
                                               interpreter);
          }
          break;
        case kTfLiteInt64:
          model_executor_->SetInput<int64_t>(
              param_index,
              has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteUInt8:
          model_executor_->SetInput<uint8_t>(
              param_index,
              has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteInt8:
          model_executor_->SetInput<int8_t>(
              param_index,
              has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
              interpreter);
          break;
        case kTfLiteBool:
          model_executor_->SetInput<bool>(
              param_index,
              has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
              interpreter);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported type of additional input parameter: "
                         << param_name;
      }
    }
  }
  return true;
}

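// Reads the text replies and their scores from the interpreter output,
// dropping empty or low-scoring replies and replies on the blocklist, and
// replacing concept names with a randomly chosen candidate from the concept
// mapping.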
void ActionsSuggestions::PopulateTextReplies(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const std::string& type, float priority_score,
    const absl::flat_hash_set<std::string>& blocklist,
    const absl::flat_hash_map<std::string, std::vector<std::string>>&
        concept_mappings,
    ActionsSuggestionsResponse* response) const {
  const std::vector<tflite::StringRef> replies =
      model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter);
  const TensorView<float> scores =
      model_executor_->OutputView<float>(score_index, interpreter);

  for (int i = 0; i < replies.size(); i++) {
    if (replies[i].len == 0) {
      continue;
    }
    const float score = scores.data()[i];
    if (score < preconditions_.min_reply_score_threshold) {
      continue;
    }
    std::string response_text(replies[i].str, replies[i].len);
    if (blocklist.contains(response_text)) {
      continue;
    }
    if (concept_mappings.contains(response_text)) {
      const int candidates_size = concept_mappings.at(response_text).size();
      // Pick a candidate uniformly from [0, candidates_size).
      const int candidate_index = absl::Uniform<int>(
          absl::IntervalClosedOpen, bit_gen_, 0, candidates_size);
      response_text = concept_mappings.at(response_text)[candidate_index];
    }

    response->actions.push_back({response_text, type, score, priority_score});
  }
}

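// Fills a suggestion from the given spec, backed by a fresh entity data
// flatbuffer when an entity data builder is available.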
void ActionsSuggestions::FillSuggestionFromSpecWithEntityData(
    const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const {
  std::unique_ptr<MutableFlatbuffer> entity_data =
      entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                      : nullptr;
  FillSuggestionFromSpec(spec, entity_data.get(), suggestion);
}

void ActionsSuggestions::PopulateIntentTriggering(
    const tflite::Interpreter* interpreter, int suggestion_index,
    int score_index, const ActionSuggestionSpec* task_spec,
    ActionsSuggestionsResponse* response) const {
  if (!task_spec || task_spec->type()->size() == 0) {
    TC3_LOG(ERROR)
        << "Task type for intent (action) triggering cannot be empty!";
    return;
  }
  const TensorView<bool> intent_prediction =
      model_executor_->OutputView<bool>(suggestion_index, interpreter);
  const TensorView<float> intent_scores =
      model_executor_->OutputView<float>(score_index, interpreter);
  // Two results corresponding to the binary triggering case.
  TC3_CHECK_EQ(intent_prediction.size(), 2);
  TC3_CHECK_EQ(intent_scores.size(), 2);
  // We rely on in-graph thresholding logic, so at this point the results
  // have been ranked properly according to the threshold.
  const bool triggering = intent_prediction.data()[0];
  const float trigger_score = intent_scores.data()[0];

  if (triggering) {
    ActionSuggestion suggestion;
    FillSuggestionFromSpecWithEntityData(task_spec, &suggestion);
    suggestion.score = trigger_score;
    response->actions.push_back(std::move(suggestion));
  }
}

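// Reads all configured outputs from the interpreter: triggering and
// sensitivity scores, smart reply predictions, per-class action scores, and
// any multi-task prediction heads described in the model's metadata.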
bool ActionsSuggestions::ReadModelOutput(
    tflite::Interpreter* interpreter, const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  // Read sensitivity and triggering score predictions.
  if (model_->tflite_model_spec()->output_triggering_score() >= 0) {
    const TensorView<float> triggering_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_triggering_score(),
            interpreter);
    if (!triggering_score.is_valid() || triggering_score.size() == 0) {
      TC3_LOG(ERROR) << "Could not compute triggering score.";
      return false;
    }
    response->triggering_score = triggering_score.data()[0];
    response->output_filtered_min_triggering_score =
        (response->triggering_score <
         preconditions_.min_smart_reply_triggering_score);
  }
  if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) {
    const TensorView<float> sensitive_topic_score =
        model_executor_->OutputView<float>(
            model_->tflite_model_spec()->output_sensitive_topic_score(),
            interpreter);
    if (!sensitive_topic_score.is_valid() ||
        sensitive_topic_score.dim(0) != 1) {
      TC3_LOG(ERROR) << "Could not compute sensitive topic score.";
      return false;
    }
    response->sensitivity_score = sensitive_topic_score.data()[0];
    response->is_sensitive = (response->sensitivity_score >
                              preconditions_.max_sensitive_topic_score);
  }

  // Suppress all remaining model outputs if the conversation is sensitive.
  if (response->is_sensitive) {
    return true;
  }

  // Read smart reply predictions.
  if (!response->output_filtered_min_triggering_score &&
      model_->tflite_model_spec()->output_replies() >= 0) {
    absl::flat_hash_set<std::string> empty_blocklist;
    PopulateTextReplies(
        interpreter, model_->tflite_model_spec()->output_replies(),
        model_->tflite_model_spec()->output_replies_scores(),
        model_->smart_reply_action_type()->str(),
        /* priority_score */ 0.0, empty_blocklist, {}, response);
  }

  // Read actions suggestions.
  if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
    const TensorView<float> actions_scores = model_executor_->OutputView<float>(
        model_->tflite_model_spec()->output_actions_scores(), interpreter);
    for (int i = 0; i < model_->action_type()->size(); i++) {
      const ActionTypeOptions* action_type = model_->action_type()->Get(i);
      // Skip disabled action classes, such as the default `other` category.
      if (!action_type->enabled()) {
        continue;
      }
      const float score = actions_scores.data()[i];
      if (score < action_type->min_triggering_score()) {
        continue;
      }

      // Create an action from the model output.
      ActionSuggestion suggestion;
      suggestion.type = action_type->name()->str();
      FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion);
      suggestion.score = score;
      response->actions.push_back(std::move(suggestion));
    }
  }

  // Read multi-task predictions and construct the results accordingly.
  if (const auto* prediction_metadata =
          model_->tflite_model_spec()->prediction_metadata()) {
    for (const PredictionMetadata* metadata : *prediction_metadata) {
      const ActionSuggestionSpec* task_spec = metadata->task_spec();
      const int suggestions_index = metadata->output_suggestions();
      const int suggestions_scores_index =
          metadata->output_suggestions_scores();
      absl::flat_hash_set<std::string> response_text_blocklist;
      absl::flat_hash_map<std::string, std::vector<std::string>>
          concept_mappings;
      switch (metadata->prediction_type()) {
        case PredictionType_NEXT_MESSAGE_PREDICTION:
          if (!task_spec || task_spec->type()->size() == 0) {
            TC3_LOG(WARNING) << "Task type not provided, using the default "
                                "smart_reply_action_type!";
          }
          if (task_spec) {
            if (task_spec->response_text_blocklist()) {
              for (const auto& val : *task_spec->response_text_blocklist()) {
                response_text_blocklist.insert(val->str());
              }
            }
            if (task_spec->concept_mappings()) {
              for (const auto& concept : *task_spec->concept_mappings()) {
                std::vector<std::string> candidates;
                for (const auto& candidate : *concept->candidates()) {
                  candidates.push_back(candidate->str());
                }
                concept_mappings[concept->concept_name()->str()] = candidates;
              }
            }
          }
          PopulateTextReplies(
              interpreter, suggestions_index, suggestions_scores_index,
              task_spec ? task_spec->type()->str()
                        : model_->smart_reply_action_type()->str(),
              task_spec ? task_spec->priority_score() : 0.0,
              response_text_blocklist, concept_mappings, response);
          break;
        case PredictionType_INTENT_TRIGGERING:
          PopulateIntentTriggering(interpreter, suggestions_index,
                                   suggestions_scores_index, task_spec,
                                   response);
          break;
        default:
          TC3_LOG(ERROR) << "Unsupported prediction type!";
          return false;
      }
    }
  }

  return true;
}

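// Runs the sensitive-topic check and, unless the conversation is sensitive
// or no model executor is configured, invokes the TFLite model on the last
// `num_messages` messages and reads back its predictions.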
bool ActionsSuggestions::SuggestActionsFromModel(
    const Conversation& conversation, const int num_messages,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response,
    std::unique_ptr<tflite::Interpreter>* interpreter) const {
  TC3_CHECK_LE(num_messages, conversation.messages.size());

  if (sensitive_model_ != nullptr &&
      sensitive_model_->EvalConversation(conversation, num_messages).first) {
    response->is_sensitive = true;
    return true;
  }

  if (!model_executor_) {
    return true;
  }
  *interpreter = model_executor_->CreateInterpreter();

  if (!*interpreter) {
    TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the "
                      "actions suggestions model.";
    return false;
  }

  std::vector<std::string> context;
  std::vector<int> user_ids;
  std::vector<float> time_diffs;
  context.reserve(num_messages);
  user_ids.reserve(num_messages);
  time_diffs.reserve(num_messages);

  // Gather the last `num_messages` messages from the conversation.
  int64 last_message_reference_time_ms_utc = 0;
  const float second_in_ms = 1000;
  for (int i = conversation.messages.size() - num_messages;
       i < conversation.messages.size(); i++) {
    const ConversationMessage& message = conversation.messages[i];
    context.push_back(message.text);
    user_ids.push_back(message.user_id);

    float time_diff_secs = 0;
    if (message.reference_time_ms_utc != 0 &&
        last_message_reference_time_ms_utc != 0) {
      time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc -
                                       last_message_reference_time_ms_utc) /
                                          second_in_ms);
    }
    if (message.reference_time_ms_utc != 0) {
      last_message_reference_time_ms_utc = message.reference_time_ms_utc;
    }
    time_diffs.push_back(time_diff_secs);
  }

  if (!SetupModelInput(context, user_ids, time_diffs,
                       /*num_suggestions=*/model_->num_smart_replies(), options,
                       interpreter->get())) {
    TC3_LOG(ERROR) << "Failed to set up input for TensorFlow Lite model.";
    return false;
  }

  if ((*interpreter)->Invoke() != kTfLiteOk) {
    TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter.";
    return false;
  }

  return ReadModelOutput(interpreter->get(), options, response);
}

Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection(
    const Conversation& conversation, const ActionSuggestionOptions& options,
    std::vector<ActionSuggestion>* actions) const {
  TC3_ASSIGN_OR_RETURN(
      std::vector<ActionSuggestion> new_actions,
      conversation_intent_detection_->SuggestActions(conversation, options));
  for (auto& action : new_actions) {
    actions->push_back(std::move(action));
  }
  return Status::OK;
}

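// Builds the annotation options for a message from the message metadata and
// the model's annotation actions spec.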
AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage(
    const ConversationMessage& message) const {
  AnnotationOptions options;
  options.detected_text_language_tags = message.detected_text_language_tags;
  options.reference_time_ms_utc = message.reference_time_ms_utc;
  options.reference_timezone = message.reference_timezone;
  options.annotation_usecase =
      model_->annotation_actions_spec()->annotation_usecase();
  options.is_serialized_entity_data_enabled =
      model_->annotation_actions_spec()->is_serialized_entity_data_enabled();
  options.entity_types = annotation_entity_types_;
  return options;
}

// Run annotator on the messages of a conversation.
Conversation ActionsSuggestions::AnnotateConversation(
    const Conversation& conversation, const Annotator* annotator) const {
  if (annotator == nullptr) {
    return conversation;
  }
  const int num_messages_grammar =
      ((model_->rules() && model_->rules()->grammar_rules() &&
        model_->rules()
            ->grammar_rules()
            ->rules()
            ->nonterminals()
            ->annotation_nt())
           ? 1
           : 0);
  const int num_messages_mapping =
      (model_->annotation_actions_spec()
           ? std::max(model_->annotation_actions_spec()
                          ->max_history_from_any_person(),
                      model_->annotation_actions_spec()
                          ->max_history_from_last_person())
           : 0);
  const int num_messages = std::max(num_messages_grammar, num_messages_mapping);
  if (num_messages == 0) {
    // No annotations are used.
    return conversation;
  }
  Conversation annotated_conversation = conversation;
  for (int i = 0, message_index = annotated_conversation.messages.size() - 1;
       i < num_messages && message_index >= 0; i++, message_index--) {
    ConversationMessage* message =
        &annotated_conversation.messages[message_index];
    if (message->annotations.empty()) {
      message->annotations = annotator->Annotate(
          message->text, AnnotationOptionsForMessage(*message));
      ConvertDatetimeToTime(&message->annotations);
    }
  }
  return annotated_conversation;
}

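// Creates actions from the annotations of the last messages, walking the
// conversation backwards while the per-person history limits permit.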
void ActionsSuggestions::SuggestActionsFromAnnotations(
    const Conversation& conversation,
    std::vector<ActionSuggestion>* actions) const {
  if (model_->annotation_actions_spec() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping() == nullptr ||
      model_->annotation_actions_spec()->annotation_mapping()->size() == 0) {
    return;
  }

  // Create actions based on the annotations.
  const int max_from_any_person =
      model_->annotation_actions_spec()->max_history_from_any_person();
  const int max_from_last_person =
      model_->annotation_actions_spec()->max_history_from_last_person();
  const int last_person = conversation.messages.back().user_id;

  int num_messages_last_person = 0;
  int num_messages_any_person = 0;
  bool all_from_last_person = true;
  for (int message_index = conversation.messages.size() - 1; message_index >= 0;
       message_index--) {
    const ConversationMessage& message = conversation.messages[message_index];
    std::vector<AnnotatedSpan> annotations = message.annotations;

    // Update how many messages we have processed from the last person in the
    // conversation and from any person in the conversation.
    num_messages_any_person++;
    if (all_from_last_person && message.user_id == last_person) {
      num_messages_last_person++;
    } else {
      all_from_last_person = false;
    }

    if (num_messages_any_person > max_from_any_person &&
        (!all_from_last_person ||
         num_messages_last_person > max_from_last_person)) {
      break;
    }

    if (message.user_id == kLocalUserId) {
      if (model_->annotation_actions_spec()->only_until_last_sent()) {
        break;
      }
      if (!model_->annotation_actions_spec()->include_local_user_messages()) {
        continue;
      }
    }

    std::vector<ActionSuggestionAnnotation> action_annotations;
    action_annotations.reserve(annotations.size());
    for (const AnnotatedSpan& annotation : annotations) {
      if (annotation.classification.empty()) {
        continue;
      }

      const ClassificationResult& classification_result =
          annotation.classification[0];

      ActionSuggestionAnnotation action_annotation;
      action_annotation.span = {
          message_index, annotation.span,
          UTF8ToUnicodeText(message.text, /*do_copy=*/false)
              .UTF8Substring(annotation.span.first, annotation.span.second)};
      action_annotation.entity = classification_result;
      action_annotation.name = classification_result.collection;
      action_annotations.push_back(std::move(action_annotation));
    }

    if (model_->annotation_actions_spec()->deduplicate_annotations()) {
      // Create actions only for deduplicated annotations.
      for (const int annotation_id :
           DeduplicateAnnotations(action_annotations)) {
        SuggestActionsFromAnnotation(
            message_index, action_annotations[annotation_id], actions);
      }
    } else {
      // Create actions for all annotations.
      for (const ActionSuggestionAnnotation& annotation : action_annotations) {
        SuggestActionsFromAnnotation(message_index, annotation, actions);
      }
    }
  }
}

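// Maps a single annotation to actions according to the model's annotation
// mapping, optionally normalizing the annotated text into the entity data.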
void ActionsSuggestions::SuggestActionsFromAnnotation(
    const int message_index, const ActionSuggestionAnnotation& annotation,
    std::vector<ActionSuggestion>* actions) const {
  for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
       *model_->annotation_actions_spec()->annotation_mapping()) {
    if (annotation.entity.collection ==
        mapping->annotation_collection()->str()) {
      if (annotation.entity.score < mapping->min_annotation_score()) {
        continue;
      }

      std::unique_ptr<MutableFlatbuffer> entity_data =
          entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot()
                                          : nullptr;

      // Set annotation text as (additional) entity data field.
      if (mapping->entity_field() != nullptr) {
        TC3_CHECK_NE(entity_data, nullptr);

        UnicodeText normalized_annotation_text =
            UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);

        // Apply normalization if specified.
        if (mapping->normalization_options() != nullptr) {
          normalized_annotation_text =
              NormalizeText(*unilib_, mapping->normalization_options(),
                            normalized_annotation_text);
        }

        entity_data->ParseAndSet(mapping->entity_field(),
                                 normalized_annotation_text.ToUTF8String());
      }

      ActionSuggestion suggestion;
      FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion);
      if (mapping->use_annotation_score()) {
        suggestion.score = annotation.entity.score;
      }
      suggestion.annotations = {annotation};
      actions->push_back(std::move(suggestion));
    }
  }
}

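// Returns the indices of the annotations to keep, retaining only the
// highest-scoring annotation for each (name, span text) pair.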
std::vector<int> ActionsSuggestions::DeduplicateAnnotations(
    const std::vector<ActionSuggestionAnnotation>& annotations) const {
  std::map<std::pair<std::string, std::string>, int> deduplicated_annotations;

  for (int i = 0; i < annotations.size(); i++) {
    const std::pair<std::string, std::string> key = {annotations[i].name,
                                                     annotations[i].span.text};
    auto entry = deduplicated_annotations.find(key);
    if (entry != deduplicated_annotations.end()) {
      // Keep the annotation with the higher score.
      if (annotations[entry->second].entity.score <
          annotations[i].entity.score) {
        entry->second = i;
      }
      continue;
    }
    deduplicated_annotations.insert(entry, {key, i});
  }

  std::vector<int> result;
  result.reserve(deduplicated_annotations.size());
  for (const auto& key_and_annotation : deduplicated_annotations) {
    result.push_back(key_and_annotation.second);
  }
  return result;
}

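// Runs the precompiled Lua actions snippet, if any, over the conversation and
// the model output (no-op when Lua support is compiled out).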
#if !defined(TC3_DISABLE_LUA)
bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  if (lua_bytecode_.empty()) {
    return true;
  }

  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(),
      interpreter, entity_data_schema_, annotation_entity_data_schema);
  if (lua_actions == nullptr) {
    TC3_LOG(ERROR) << "Could not create lua actions.";
    return false;
  }
  return lua_actions->SuggestActions(actions);
}
#else
bool ActionsSuggestions::SuggestActionsFromLua(
    const Conversation& conversation, const TfLiteModelExecutor* model_executor,
    const tflite::Interpreter* interpreter,
    const reflection::Schema* annotation_entity_data_schema,
    std::vector<ActionSuggestion>* actions) const {
  return true;
}
#endif

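// Gathers suggestions from all sources in order: annotation-based actions,
// grammar rules, then (after the length, locale and low-confidence
// preconditions pass) the TFLite model, conversation intent detection, Lua
// actions, regex rules, and finally the low-confidence post-checks.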
bool ActionsSuggestions::GatherActionsSuggestions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options,
    ActionsSuggestionsResponse* response) const {
  if (conversation.messages.empty()) {
    return true;
  }

  // Run the annotator against the messages.
  const Conversation annotated_conversation =
      AnnotateConversation(conversation, annotator);

  const int num_messages = NumMessagesToConsider(
      annotated_conversation, model_->max_conversation_history_length());

  if (num_messages <= 0) {
    TC3_LOG(INFO) << "No messages provided for actions suggestions.";
    return false;
  }

  SuggestActionsFromAnnotations(annotated_conversation, &response->actions);

  if (grammar_actions_ != nullptr &&
      !grammar_actions_->SuggestActions(annotated_conversation,
                                        &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
    return false;
  }

  int input_text_length = 0;
  int num_matching_locales = 0;
  for (int i = annotated_conversation.messages.size() - num_messages;
       i < annotated_conversation.messages.size(); i++) {
    input_text_length += annotated_conversation.messages[i].text.length();
    std::vector<Locale> message_languages;
    if (!ParseLocales(
            annotated_conversation.messages[i].detected_text_language_tags,
            &message_languages)) {
      continue;
    }
    if (Locale::IsAnyLocaleSupported(
            message_languages, locales_,
            preconditions_.handle_unknown_locale_as_supported)) {
      ++num_matching_locales;
    }
  }

  // Bail out if we are provided with too little or too much input.
  if (input_text_length < preconditions_.min_input_length ||
      (preconditions_.max_input_length >= 0 &&
       input_text_length > preconditions_.max_input_length)) {
    TC3_LOG(INFO) << "Too much or not enough input for inference.";
    return true;
  }

  // Bail out if the text does not look like it can be handled by the model.
  const float matching_fraction =
      static_cast<float>(num_matching_locales) / num_messages;
  if (matching_fraction < preconditions_.min_locale_match_fraction) {
    TC3_LOG(INFO) << "Not enough locale matches.";
    response->output_filtered_locale_mismatch = true;
    return true;
  }

  std::vector<const UniLib::RegexPattern*> post_check_rules;
  if (preconditions_.suppress_on_low_confidence_input) {
    if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
                                             num_messages, &post_check_rules)) {
      response->output_filtered_low_confidence = true;
      return true;
    }
  }

  std::unique_ptr<tflite::Interpreter> interpreter;
  if (!SuggestActionsFromModel(annotated_conversation, num_messages, options,
                               response, &interpreter)) {
    TC3_LOG(ERROR) << "Could not run model.";
    return false;
  }

  // SuggestActionsFromModel also detects whether the conversation is
  // sensitive, either via the old ngram model or the new model.
  // Suppress all predictions if the conversation was deemed sensitive.
  if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
    return true;
  }

  if (conversation_intent_detection_) {
    // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works.
    auto actions = SuggestActionsFromConversationIntentDetection(
        annotated_conversation, options, &response->actions);
    if (!actions.ok()) {
      TC3_LOG(ERROR) << "Could not run conversation intent detection: "
                     << actions.error_message();
      return false;
    }
  }

  if (!SuggestActionsFromLua(
          annotated_conversation, model_executor_.get(), interpreter.get(),
          annotator != nullptr ? annotator->entity_data_schema() : nullptr,
          &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from script.";
    return false;
  }

  if (!regex_actions_->SuggestActions(annotated_conversation,
                                      entity_data_builder_.get(),
                                      &response->actions)) {
    TC3_LOG(ERROR) << "Could not suggest actions from regex rules.";
    return false;
  }

  if (preconditions_.suppress_on_low_confidence_input &&
      !regex_actions_->FilterConfidenceOutput(post_check_rules,
                                              &response->actions)) {
    TC3_LOG(ERROR) << "Could not post-check actions.";
    return false;
  }

  return true;
}

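// Validates the conversation (sorted timestamps, sane message sizes, valid
// UTF-8), gathers suggestions from all sources, and ranks the result.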
ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation, const Annotator* annotator,
    const ActionSuggestionOptions& options) const {
  ActionsSuggestionsResponse response;

  // Assert that the messages are sorted correctly.
  for (int i = 1; i < conversation.messages.size(); i++) {
    if (conversation.messages[i].reference_time_ms_utc <
        conversation.messages[i - 1].reference_time_ms_utc) {
      TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
      return response;
    }
  }

  // Check that the messages are valid UTF-8.
  for (const ConversationMessage& message : conversation.messages) {
    if (message.text.size() > std::numeric_limits<int>::max()) {
      TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
      return {};
    }

    if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
            message.text.data(), message.text.size(), /*do_copy=*/false))) {
      TC3_LOG(ERROR) << "Not valid utf8 provided.";
      return response;
    }
  }

  if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
    TC3_LOG(ERROR) << "Could not gather actions suggestions.";
    response.actions.clear();
  } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_,
                                   annotator != nullptr
                                       ? annotator->entity_data_schema()
                                       : nullptr)) {
    TC3_LOG(ERROR) << "Could not rank actions.";
    response.actions.clear();
  }
  return response;
}

ActionsSuggestionsResponse ActionsSuggestions::SuggestActions(
    const Conversation& conversation,
    const ActionSuggestionOptions& options) const {
  return SuggestActions(conversation, /*annotator=*/nullptr, options);
}

const ActionsModel* ActionsSuggestions::model() const { return model_; }
const reflection::Schema* ActionsSuggestions::entity_data_schema() const {
  return entity_data_schema_;
}

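// Returns a read-only view of an actions model buffer, or nullptr if the
// buffer does not verify as a valid ActionsModel flatbuffer.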
const ActionsModel* ViewActionsModel(const void* buffer, int size) {
  if (buffer == nullptr) {
    return nullptr;
  }
  return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size);
}

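// Enables conversation intent detection by initializing it from the given
// serialized config; returns false if initialization fails.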
bool ActionsSuggestions::InitializeConversationIntentDetection(
    const std::string& serialized_config) {
  auto conversation_intent_detection =
      std::make_unique<ConversationIntentDetection>();
  if (!conversation_intent_detection->Initialize(serialized_config).ok()) {
    TC3_LOG(ERROR) << "Failed to initialize conversation intent detection.";
    return false;
  }
  conversation_intent_detection_ = std::move(conversation_intent_detection);
  return true;
}

}  // namespace libtextclassifier3