1 /*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "actions/ranker.h"
18
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
22
23 #include "actions/actions_model_generated.h"
24
25 #if !defined(TC3_DISABLE_LUA)
26 #include "actions/lua-ranker.h"
27 #endif
28 #include "actions/zlib-utils.h"
29 #include "annotator/types.h"
30 #include "utils/base/logging.h"
31 #if !defined(TC3_DISABLE_LUA)
32 #include "utils/lua-utils.h"
33 #endif
34
35 namespace libtextclassifier3 {
36 namespace {
37
SortByScoreAndType(std::vector<ActionSuggestion> * actions)38 void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
39 std::stable_sort(actions->begin(), actions->end(),
40 [](const ActionSuggestion& a, const ActionSuggestion& b) {
41 return a.score > b.score ||
42 (a.score >= b.score && a.type < b.type);
43 });
44 }
45
SortByPriorityAndScoreAndType(std::vector<ActionSuggestion> * actions)46 void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) {
47 std::stable_sort(
48 actions->begin(), actions->end(),
49 [](const ActionSuggestion& a, const ActionSuggestion& b) {
50 return a.priority_score > b.priority_score ||
51 (a.priority_score >= b.priority_score && a.score > b.score) ||
52 (a.priority_score >= b.priority_score && a.score >= b.score &&
53 a.type < b.type);
54 });
55 }
56
// Generic three-way comparison built from operator< and operator>.
// Returns -1 when `left` orders before `right`, 1 when it orders after, and
// 0 when neither comparison holds.
template <typename T>
int Compare(const T& left, const T& right) {
  return left < right ? -1 : (left > right ? 1 : 0);
}
67
68 template <>
Compare(const std::string & left,const std::string & right)69 int Compare(const std::string& left, const std::string& right) {
70 return left.compare(right);
71 }
72
73 template <>
Compare(const MessageTextSpan & span,const MessageTextSpan & other)74 int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
75 if (const int value = Compare(span.message_index, other.message_index)) {
76 return value;
77 }
78 if (const int value = Compare(span.span.first, other.span.first)) {
79 return value;
80 }
81 if (const int value = Compare(span.span.second, other.span.second)) {
82 return value;
83 }
84 return 0;
85 }
86
IsSameSpan(const MessageTextSpan & span,const MessageTextSpan & other)87 bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
88 return Compare(span, other) == 0;
89 }
90
TextSpansIntersect(const MessageTextSpan & span,const MessageTextSpan & other)91 bool TextSpansIntersect(const MessageTextSpan& span,
92 const MessageTextSpan& other) {
93 return span.message_index == other.message_index &&
94 SpansOverlap(span.span, other.span);
95 }
96
97 template <>
Compare(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)98 int Compare(const ActionSuggestionAnnotation& annotation,
99 const ActionSuggestionAnnotation& other) {
100 if (const int value = Compare(annotation.span, other.span)) {
101 return value;
102 }
103 if (const int value = Compare(annotation.name, other.name)) {
104 return value;
105 }
106 if (const int value =
107 Compare(annotation.entity.collection, other.entity.collection)) {
108 return value;
109 }
110 return 0;
111 }
112
113 // Checks whether two annotations can be considered equivalent.
IsEquivalentActionAnnotation(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)114 bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
115 const ActionSuggestionAnnotation& other) {
116 return Compare(annotation, other) == 0;
117 }
118
119 // Compares actions based on annotations.
CompareAnnotationsOnly(const ActionSuggestion & action,const ActionSuggestion & other)120 int CompareAnnotationsOnly(const ActionSuggestion& action,
121 const ActionSuggestion& other) {
122 if (const int value =
123 Compare(action.annotations.size(), other.annotations.size())) {
124 return value;
125 }
126 for (int i = 0; i < action.annotations.size(); i++) {
127 if (const int value =
128 Compare(action.annotations[i], other.annotations[i])) {
129 return value;
130 }
131 }
132 return 0;
133 }
134
135 // Checks whether two actions have the same annotations.
HaveEquivalentAnnotations(const ActionSuggestion & action,const ActionSuggestion & other)136 bool HaveEquivalentAnnotations(const ActionSuggestion& action,
137 const ActionSuggestion& other) {
138 return CompareAnnotationsOnly(action, other) == 0;
139 }
140
141 template <>
Compare(const ActionSuggestion & action,const ActionSuggestion & other)142 int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
143 if (const int value = Compare(action.type, other.type)) {
144 return value;
145 }
146 if (const int value = Compare(action.response_text, other.response_text)) {
147 return value;
148 }
149 if (const int value = Compare(action.serialized_entity_data,
150 other.serialized_entity_data)) {
151 return value;
152 }
153 return CompareAnnotationsOnly(action, other);
154 }
155
156 // Checks whether two action suggestions can be considered equivalent.
IsEquivalentActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)157 bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
158 const ActionSuggestion& other) {
159 return Compare(action, other) == 0;
160 }
161
162 // Checks whether any action is equivalent to the given one.
IsAnyActionEquivalent(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)163 bool IsAnyActionEquivalent(const ActionSuggestion& action,
164 const std::vector<ActionSuggestion>& actions) {
165 for (const ActionSuggestion& other : actions) {
166 if (IsEquivalentActionSuggestion(action, other)) {
167 return true;
168 }
169 }
170 return false;
171 }
172
IsConflicting(const ActionSuggestionAnnotation & annotation,const ActionSuggestionAnnotation & other)173 bool IsConflicting(const ActionSuggestionAnnotation& annotation,
174 const ActionSuggestionAnnotation& other) {
175 // Two annotations are conflicting if they are different but refer to
176 // overlapping spans in the conversation.
177 return (!IsEquivalentActionAnnotation(annotation, other) &&
178 TextSpansIntersect(annotation.span, other.span));
179 }
180
181 // Checks whether two action suggestions can be considered conflicting.
IsConflictingActionSuggestion(const ActionSuggestion & action,const ActionSuggestion & other)182 bool IsConflictingActionSuggestion(const ActionSuggestion& action,
183 const ActionSuggestion& other) {
184 // Actions are considered conflicting, iff they refer to the same text span,
185 // but were not generated from the same annotation.
186 if (action.annotations.empty() || other.annotations.empty()) {
187 return false;
188 }
189 for (const ActionSuggestionAnnotation& annotation : action.annotations) {
190 for (const ActionSuggestionAnnotation& other_annotation :
191 other.annotations) {
192 if (IsConflicting(annotation, other_annotation)) {
193 return true;
194 }
195 }
196 }
197 return false;
198 }
199
200 // Checks whether any action is considered conflicting with the given one.
IsAnyActionConflicting(const ActionSuggestion & action,const std::vector<ActionSuggestion> & actions)201 bool IsAnyActionConflicting(const ActionSuggestion& action,
202 const std::vector<ActionSuggestion>& actions) {
203 for (const ActionSuggestion& other : actions) {
204 if (IsConflictingActionSuggestion(action, other)) {
205 return true;
206 }
207 }
208 return false;
209 }
210
211 } // namespace
212
213 std::unique_ptr<ActionsSuggestionsRanker>
CreateActionsSuggestionsRanker(const RankingOptions * options,ZlibDecompressor * decompressor,const std::string & smart_reply_action_type)214 ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
215 const RankingOptions* options, ZlibDecompressor* decompressor,
216 const std::string& smart_reply_action_type) {
217 auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
218 new ActionsSuggestionsRanker(options, smart_reply_action_type));
219
220 if (!ranker->InitializeAndValidate(decompressor)) {
221 TC3_LOG(ERROR) << "Could not initialize action ranker.";
222 return nullptr;
223 }
224
225 return ranker;
226 }
227
InitializeAndValidate(ZlibDecompressor * decompressor)228 bool ActionsSuggestionsRanker::InitializeAndValidate(
229 ZlibDecompressor* decompressor) {
230 if (options_ == nullptr) {
231 TC3_LOG(ERROR) << "No ranking options specified.";
232 return false;
233 }
234
235 #if !defined(TC3_DISABLE_LUA)
236 std::string lua_ranking_script;
237 if (GetUncompressedString(options_->lua_ranking_script(),
238 options_->compressed_lua_ranking_script(),
239 decompressor, &lua_ranking_script) &&
240 !lua_ranking_script.empty()) {
241 if (!Compile(lua_ranking_script, &lua_bytecode_)) {
242 TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
243 return false;
244 }
245 }
246 #endif
247
248 return true;
249 }
250
RankActions(const Conversation & conversation,ActionsSuggestionsResponse * response,const reflection::Schema * entity_data_schema,const reflection::Schema * annotations_entity_data_schema) const251 bool ActionsSuggestionsRanker::RankActions(
252 const Conversation& conversation, ActionsSuggestionsResponse* response,
253 const reflection::Schema* entity_data_schema,
254 const reflection::Schema* annotations_entity_data_schema) const {
255 if (options_->deduplicate_suggestions() ||
256 options_->deduplicate_suggestions_by_span()) {
257 // Order suggestions by [priority score -> score] for deduplication
258 SortByPriorityAndScoreAndType(&response->actions);
259
260 // Deduplicate, keeping the higher score actions.
261 if (options_->deduplicate_suggestions()) {
262 std::vector<ActionSuggestion> deduplicated_actions;
263 for (const ActionSuggestion& candidate : response->actions) {
264 // Check whether we already have an equivalent action.
265 if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) {
266 deduplicated_actions.push_back(std::move(candidate));
267 }
268 }
269 response->actions = std::move(deduplicated_actions);
270 }
271
272 // Resolve conflicts between conflicting actions referring to the same
273 // text span.
274 if (options_->deduplicate_suggestions_by_span()) {
275 std::vector<ActionSuggestion> deduplicated_actions;
276 for (const ActionSuggestion& candidate : response->actions) {
277 // Check whether we already have a conflicting action.
278 if (!IsAnyActionConflicting(candidate, deduplicated_actions)) {
279 deduplicated_actions.push_back(std::move(candidate));
280 }
281 }
282 response->actions = std::move(deduplicated_actions);
283 }
284 }
285
286 bool sort_by_priority =
287 options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
288 // Suppress smart replies if actions are present.
289 if (options_->suppress_smart_replies_with_actions()) {
290 std::vector<ActionSuggestion> non_smart_reply_actions;
291 for (const ActionSuggestion& action : response->actions) {
292 if (action.type != smart_reply_action_type_) {
293 non_smart_reply_actions.push_back(std::move(action));
294 }
295 }
296 response->actions = std::move(non_smart_reply_actions);
297 }
298
299 // Group by annotation if specified.
300 if (options_->group_by_annotations()) {
301 auto group_id = std::map<
302 ActionSuggestion, int,
303 std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
304 [](const ActionSuggestion& action, const ActionSuggestion& other) {
305 return (CompareAnnotationsOnly(action, other) < 0);
306 }};
307 typedef std::vector<ActionSuggestion> ActionSuggestionGroup;
308 std::vector<ActionSuggestionGroup> groups;
309
310 // Group actions by the annotation set they are based of.
311 for (const ActionSuggestion& action : response->actions) {
312 // Treat actions with no annotations idependently.
313 if (action.annotations.empty()) {
314 groups.emplace_back(1, action);
315 continue;
316 }
317
318 auto it = group_id.find(action);
319 if (it != group_id.end()) {
320 groups[it->second].push_back(action);
321 } else {
322 group_id[action] = groups.size();
323 groups.emplace_back(1, action);
324 }
325 }
326
327 // Sort within each group by score.
328 for (std::vector<ActionSuggestion>& group : groups) {
329 if (sort_by_priority) {
330 SortByPriorityAndScoreAndType(&group);
331 } else {
332 SortByScoreAndType(&group);
333 }
334 }
335
336 // Sort groups by maximum score or priority score.
337 if (sort_by_priority) {
338 std::stable_sort(
339 groups.begin(), groups.end(),
340 [](const std::vector<ActionSuggestion>& a,
341 const std::vector<ActionSuggestion>& b) {
342 return (a.begin()->priority_score > b.begin()->priority_score) ||
343 (a.begin()->priority_score >= b.begin()->priority_score &&
344 a.begin()->score > b.begin()->score) ||
345 (a.begin()->priority_score >= b.begin()->priority_score &&
346 a.begin()->score >= b.begin()->score &&
347 a.begin()->type < b.begin()->type);
348 });
349 } else {
350 std::stable_sort(groups.begin(), groups.end(),
351 [](const std::vector<ActionSuggestion>& a,
352 const std::vector<ActionSuggestion>& b) {
353 return a.begin()->score > b.begin()->score ||
354 (a.begin()->score >= b.begin()->score &&
355 a.begin()->type < b.begin()->type);
356 });
357 }
358
359 // Flatten result.
360 const size_t num_actions = response->actions.size();
361 response->actions.clear();
362 response->actions.reserve(num_actions);
363 for (const std::vector<ActionSuggestion>& actions : groups) {
364 response->actions.insert(response->actions.end(), actions.begin(),
365 actions.end());
366 }
367 } else if (sort_by_priority) {
368 SortByPriorityAndScoreAndType(&response->actions);
369 } else {
370 SortByScoreAndType(&response->actions);
371 }
372
373 #if !defined(TC3_DISABLE_LUA)
374 // Run lua ranking snippet, if provided.
375 if (!lua_bytecode_.empty()) {
376 auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
377 conversation, lua_bytecode_, entity_data_schema,
378 annotations_entity_data_schema, response);
379 if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
380 TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
381 return false;
382 }
383 }
384 #endif
385
386 return true;
387 }
388
389 } // namespace libtextclassifier3
390