xref: /aosp_15_r20/external/libtextclassifier/native/actions/ranker.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "actions/ranker.h"

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "actions/actions_model_generated.h"

#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
#endif
#include "actions/zlib-utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
#endif
34 
35 namespace libtextclassifier3 {
36 namespace {
37 
SortByScoreAndType(std::vector<ActionSuggestion> * actions)38 void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
39   std::stable_sort(actions->begin(), actions->end(),
40                    [](const ActionSuggestion& a, const ActionSuggestion& b) {
41                      return a.score > b.score ||
42                             (a.score >= b.score && a.type < b.type);
43                    });
44 }
45 
SortByPriorityAndScoreAndType(std::vector<ActionSuggestion> * actions)46 void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) {
47   std::stable_sort(
48       actions->begin(), actions->end(),
49       [](const ActionSuggestion& a, const ActionSuggestion& b) {
50         return a.priority_score > b.priority_score ||
51                (a.priority_score >= b.priority_score && a.score > b.score) ||
52                (a.priority_score >= b.priority_score && a.score >= b.score &&
53                 a.type < b.type);
54       });
55 }
56 
// Generic three-way comparison built on operator< and operator>.
// Returns -1 if `left` is less than `right`, 1 if it is greater, 0 otherwise.
template <typename T>
int Compare(const T& left, const T& right) {
  return (left < right) ? -1 : ((left > right) ? 1 : 0);
}
67 
68 template <>
Compare(const std::string & left,const std::string & right)69 int Compare(const std::string& left, const std::string& right) {
70   return left.compare(right);
71 }
72 
73 template <>
Compare(const MessageTextSpan & span,const MessageTextSpan & other)74 int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
75   if (const int value = Compare(span.message_index, other.message_index)) {
76     return value;
77   }
78   if (const int value = Compare(span.span.first, other.span.first)) {
79     return value;
80   }
81   if (const int value = Compare(span.span.second, other.span.second)) {
82     return value;
83   }
84   return 0;
85 }
86 
IsSameSpan(const MessageTextSpan & span,const MessageTextSpan & other)87 bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
88   return Compare(span, other) == 0;
89 }
90 
TextSpansIntersect(const MessageTextSpan & span,const MessageTextSpan & other)91 bool TextSpansIntersect(const MessageTextSpan& span,
92                         const MessageTextSpan& other) {
93   return span.message_index == other.message_index &&
94          SpansOverlap(span.span, other.span);
95 }
96 
97 template <>
Compare(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)98 int Compare(const ActionSuggestionAnnotation& annotation,
99             const ActionSuggestionAnnotation& other) {
100   if (const int value = Compare(annotation.span, other.span)) {
101     return value;
102   }
103   if (const int value = Compare(annotation.name, other.name)) {
104     return value;
105   }
106   if (const int value =
107           Compare(annotation.entity.collection, other.entity.collection)) {
108     return value;
109   }
110   return 0;
111 }
112 
113 // Checks whether two annotations can be considered equivalent.
IsEquivalentActionAnnotation(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)114 bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
115                                   const ActionSuggestionAnnotation& other) {
116   return Compare(annotation, other) == 0;
117 }
118 
119 // Compares actions based on annotations.
CompareAnnotationsOnly(const ActionSuggestion & action,const ActionSuggestion & other)120 int CompareAnnotationsOnly(const ActionSuggestion& action,
121                            const ActionSuggestion& other) {
122   if (const int value =
123           Compare(action.annotations.size(), other.annotations.size())) {
124     return value;
125   }
126   for (int i = 0; i < action.annotations.size(); i++) {
127     if (const int value =
128             Compare(action.annotations[i], other.annotations[i])) {
129       return value;
130     }
131   }
132   return 0;
133 }
134 
135 // Checks whether two actions have the same annotations.
HaveEquivalentAnnotations(const ActionSuggestion & action,const ActionSuggestion & other)136 bool HaveEquivalentAnnotations(const ActionSuggestion& action,
137                                const ActionSuggestion& other) {
138   return CompareAnnotationsOnly(action, other) == 0;
139 }
140 
141 template <>
Compare(const ActionSuggestion & action,const ActionSuggestion & other)142 int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
143   if (const int value = Compare(action.type, other.type)) {
144     return value;
145   }
146   if (const int value = Compare(action.response_text, other.response_text)) {
147     return value;
148   }
149   if (const int value = Compare(action.serialized_entity_data,
150                                 other.serialized_entity_data)) {
151     return value;
152   }
153   return CompareAnnotationsOnly(action, other);
154 }
155 
156 // Checks whether two action suggestions can be considered equivalent.
IsEquivalentActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)157 bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
158                                   const ActionSuggestion& other) {
159   return Compare(action, other) == 0;
160 }
161 
162 // Checks whether any action is equivalent to the given one.
IsAnyActionEquivalent(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)163 bool IsAnyActionEquivalent(const ActionSuggestion& action,
164                            const std::vector<ActionSuggestion>& actions) {
165   for (const ActionSuggestion& other : actions) {
166     if (IsEquivalentActionSuggestion(action, other)) {
167       return true;
168     }
169   }
170   return false;
171 }
172 
IsConflicting(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)173 bool IsConflicting(const ActionSuggestionAnnotation& annotation,
174                    const ActionSuggestionAnnotation& other) {
175   // Two annotations are conflicting if they are different but refer to
176   // overlapping spans in the conversation.
177   return (!IsEquivalentActionAnnotation(annotation, other) &&
178           TextSpansIntersect(annotation.span, other.span));
179 }
180 
181 // Checks whether two action suggestions can be considered conflicting.
IsConflictingActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)182 bool IsConflictingActionSuggestion(const ActionSuggestion& action,
183                                    const ActionSuggestion& other) {
184   // Actions are considered conflicting, iff they refer to the same text span,
185   // but were not generated from the same annotation.
186   if (action.annotations.empty() || other.annotations.empty()) {
187     return false;
188   }
189   for (const ActionSuggestionAnnotation& annotation : action.annotations) {
190     for (const ActionSuggestionAnnotation& other_annotation :
191          other.annotations) {
192       if (IsConflicting(annotation, other_annotation)) {
193         return true;
194       }
195     }
196   }
197   return false;
198 }
199 
200 // Checks whether any action is considered conflicting with the given one.
IsAnyActionConflicting(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)201 bool IsAnyActionConflicting(const ActionSuggestion& action,
202                             const std::vector<ActionSuggestion>& actions) {
203   for (const ActionSuggestion& other : actions) {
204     if (IsConflictingActionSuggestion(action, other)) {
205       return true;
206     }
207   }
208   return false;
209 }
210 
211 }  // namespace
212 
213 std::unique_ptr<ActionsSuggestionsRanker>
CreateActionsSuggestionsRanker(const RankingOptions * options,ZlibDecompressor * decompressor,const std::string & smart_reply_action_type)214 ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
215     const RankingOptions* options, ZlibDecompressor* decompressor,
216     const std::string& smart_reply_action_type) {
217   auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
218       new ActionsSuggestionsRanker(options, smart_reply_action_type));
219 
220   if (!ranker->InitializeAndValidate(decompressor)) {
221     TC3_LOG(ERROR) << "Could not initialize action ranker.";
222     return nullptr;
223   }
224 
225   return ranker;
226 }
227 
InitializeAndValidate(ZlibDecompressor * decompressor)228 bool ActionsSuggestionsRanker::InitializeAndValidate(
229     ZlibDecompressor* decompressor) {
230   if (options_ == nullptr) {
231     TC3_LOG(ERROR) << "No ranking options specified.";
232     return false;
233   }
234 
235 #if !defined(TC3_DISABLE_LUA)
236   std::string lua_ranking_script;
237   if (GetUncompressedString(options_->lua_ranking_script(),
238                             options_->compressed_lua_ranking_script(),
239                             decompressor, &lua_ranking_script) &&
240       !lua_ranking_script.empty()) {
241     if (!Compile(lua_ranking_script, &lua_bytecode_)) {
242       TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
243       return false;
244     }
245   }
246 #endif
247 
248   return true;
249 }
250 
RankActions(const Conversation & conversation,ActionsSuggestionsResponse * response,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const251 bool ActionsSuggestionsRanker::RankActions(
252     const Conversation& conversation, ActionsSuggestionsResponse* response,
253     const reflection::Schema* entity_data_schema,
254     const reflection::Schema* annotations_entity_data_schema) const {
255   if (options_->deduplicate_suggestions() ||
256       options_->deduplicate_suggestions_by_span()) {
257     // Order suggestions by [priority score -> score] for deduplication
258     SortByPriorityAndScoreAndType(&response->actions);
259 
260     // Deduplicate, keeping the higher score actions.
261     if (options_->deduplicate_suggestions()) {
262       std::vector<ActionSuggestion> deduplicated_actions;
263       for (const ActionSuggestion& candidate : response->actions) {
264         // Check whether we already have an equivalent action.
265         if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
266           deduplicated_actions.push_back(std::move(candidate));
267         }
268       }
269       response->actions = std::move(deduplicated_actions);
270     }
271 
272     // Resolve conflicts between conflicting actions referring to the same
273     // text span.
274     if (options_->deduplicate_suggestions_by_span()) {
275       std::vector<ActionSuggestion> deduplicated_actions;
276       for (const ActionSuggestion& candidate : response->actions) {
277         // Check whether we already have a conflicting action.
278         if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
279           deduplicated_actions.push_back(std::move(candidate));
280         }
281       }
282       response->actions = std::move(deduplicated_actions);
283     }
284   }
285 
286   bool sort_by_priority =
287       options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
288   // Suppress smart replies if actions are present.
289   if (options_->suppress_smart_replies_with_actions()) {
290     std::vector<ActionSuggestion> non_smart_reply_actions;
291     for (const ActionSuggestion& action : response->actions) {
292       if (action.type != smart_reply_action_type_) {
293         non_smart_reply_actions.push_back(std::move(action));
294       }
295     }
296     response->actions = std::move(non_smart_reply_actions);
297   }
298 
299   // Group by annotation if specified.
300   if (options_->group_by_annotations()) {
301     auto group_id = std::map<
302         ActionSuggestion, int,
303         std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
304         [](const ActionSuggestion& action, const ActionSuggestion& other) {
305           return (CompareAnnotationsOnly(action, other) < 0);
306         }};
307     typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
308     std::vector<ActionSuggestionGroup> groups;
309 
310     // Group actions by the annotation set they are based of.
311     for (const ActionSuggestion& action : response->actions) {
312       // Treat actions with no annotations idependently.
313       if (action.annotations.empty()) {
314         groups.emplace_back(1, action);
315         continue;
316       }
317 
318       auto it = group_id.find(action);
319       if (it != group_id.end()) {
320         groups[it->second].push_back(action);
321       } else {
322         group_id[action] = groups.size();
323         groups.emplace_back(1, action);
324       }
325     }
326 
327     // Sort within each group by score.
328     for (std::vector<ActionSuggestion>& group : groups) {
329       if (sort_by_priority) {
330         SortByPriorityAndScoreAndType(&group);
331       } else {
332         SortByScoreAndType(&group);
333       }
334     }
335 
336     // Sort groups by maximum score or priority score.
337     if (sort_by_priority) {
338       std::stable_sort(
339           groups.begin(), groups.end(),
340           [](const std::vector<ActionSuggestion>& a,
341              const std::vector<ActionSuggestion>& b) {
342             return (a.begin()->priority_score > b.begin()->priority_score) ||
343                    (a.begin()->priority_score >= b.begin()->priority_score &&
344                     a.begin()->score > b.begin()->score) ||
345                    (a.begin()->priority_score >= b.begin()->priority_score &&
346                     a.begin()->score >= b.begin()->score &&
347                     a.begin()->type < b.begin()->type);
348           });
349     } else {
350       std::stable_sort(groups.begin(), groups.end(),
351                        [](const std::vector<ActionSuggestion>& a,
352                           const std::vector<ActionSuggestion>& b) {
353                          return a.begin()->score > b.begin()->score ||
354                                 (a.begin()->score >= b.begin()->score &&
355                                  a.begin()->type < b.begin()->type);
356                        });
357     }
358 
359     // Flatten result.
360     const size_t num_actions = response->actions.size();
361     response->actions.clear();
362     response->actions.reserve(num_actions);
363     for (const std::vector<ActionSuggestion>& actions : groups) {
364       response->actions.insert(response->actions.end(), actions.begin(),
365                                actions.end());
366     }
367   } else if (sort_by_priority) {
368     SortByPriorityAndScoreAndType(&response->actions);
369   } else {
370     SortByScoreAndType(&response->actions);
371   }
372 
373 #if !defined(TC3_DISABLE_LUA)
374   // Run lua ranking snippet, if provided.
375   if (!lua_bytecode_.empty()) {
376     auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
377         conversation, lua_bytecode_, entity_data_schema,
378         annotations_entity_data_schema, response);
379     if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
380       TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
381       return false;
382     }
383   }
384 #endif
385 
386   return true;
387 }
388 
389 }  // namespace libtextclassifier3
390