xref: /aosp_15_r20/external/icing/icing/result/result-retriever-v2.cc (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/result/result-retriever-v2.h"
16 
17 #include <cstddef>
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "icing/text_classifier/lib3/utils/base/statusor.h"
26 #include "icing/absl_ports/mutex.h"
27 #include "icing/proto/document.pb.h"
28 #include "icing/proto/search.pb.h"
29 #include "icing/result/page-result.h"
30 #include "icing/result/projection-tree.h"
31 #include "icing/result/projector.h"
32 #include "icing/result/result-adjustment-info.h"
33 #include "icing/result/result-state-v2.h"
34 #include "icing/result/snippet-context.h"
35 #include "icing/result/snippet-retriever.h"
36 #include "icing/schema/schema-store.h"
37 #include "icing/schema/section.h"
38 #include "icing/scoring/scored-document-hit.h"
39 #include "icing/store/document-filter-data.h"
40 #include "icing/store/document-store.h"
41 #include "icing/store/namespace-id.h"
42 #include "icing/tokenization/language-segmenter.h"
43 #include "icing/transform/normalizer.h"
44 #include "icing/util/logging.h"
45 #include "icing/util/status-macros.h"
46 
47 namespace icing {
48 namespace lib {
49 
50 namespace {
51 
ApplyProjection(const ResultAdjustmentInfo * adjustment_info,DocumentProto * document)52 void ApplyProjection(const ResultAdjustmentInfo* adjustment_info,
53                      DocumentProto* document) {
54   if (adjustment_info == nullptr) {
55     return;
56   }
57 
58   auto itr = adjustment_info->projection_tree_map.find(document->schema());
59   if (itr != adjustment_info->projection_tree_map.end()) {
60     projector::Project(itr->second.root().children, document);
61   } else {
62     auto wildcard_projection_tree_itr =
63         adjustment_info->projection_tree_map.find(
64             std::string(SchemaStore::kSchemaTypeWildcard));
65     if (wildcard_projection_tree_itr !=
66         adjustment_info->projection_tree_map.end()) {
67       projector::Project(wildcard_projection_tree_itr->second.root().children,
68                          document);
69     }
70   }
71 }
72 
ApplySnippet(ResultAdjustmentInfo * adjustment_info,const SnippetRetriever & snippet_retriever,const DocumentProto & document,SectionIdMask section_id_mask,SearchResultProto::ResultProto * result)73 bool ApplySnippet(ResultAdjustmentInfo* adjustment_info,
74                   const SnippetRetriever& snippet_retriever,
75                   const DocumentProto& document, SectionIdMask section_id_mask,
76                   SearchResultProto::ResultProto* result) {
77   if (adjustment_info == nullptr) {
78     return false;
79   }
80 
81   const SnippetContext& snippet_context = adjustment_info->snippet_context;
82   int& remaining_num_to_snippet = adjustment_info->remaining_num_to_snippet;
83 
84   if (snippet_context.snippet_spec.num_matches_per_property() > 0 &&
85       remaining_num_to_snippet > 0) {
86     SnippetProto snippet_proto = snippet_retriever.RetrieveSnippet(
87         snippet_context.query_terms, snippet_context.match_type,
88         snippet_context.snippet_spec, document, section_id_mask);
89     *result->mutable_snippet() = std::move(snippet_proto);
90     --remaining_num_to_snippet;
91     return true;
92   }
93 
94   return false;
95 }
96 
97 }  // namespace
98 
ShouldBeRemoved(const ScoredDocumentHit & scored_document_hit,const std::unordered_map<int32_t,int> & entry_id_group_id_map,const DocumentStore & document_store,std::vector<int> & group_result_limits,ResultSpecProto::ResultGroupingType result_group_type,int64_t current_time_ms) const99 bool GroupResultLimiterV2::ShouldBeRemoved(
100     const ScoredDocumentHit& scored_document_hit,
101     const std::unordered_map<int32_t, int>& entry_id_group_id_map,
102     const DocumentStore& document_store, std::vector<int>& group_result_limits,
103     ResultSpecProto::ResultGroupingType result_group_type,
104     int64_t current_time_ms) const {
105   auto document_filter_data_optional =
106       document_store.GetAliveDocumentFilterData(
107           scored_document_hit.document_id(), current_time_ms);
108   if (!document_filter_data_optional) {
109     // The document doesn't exist.
110     return true;
111   }
112   NamespaceId namespace_id =
113       document_filter_data_optional.value().namespace_id();
114   SchemaTypeId schema_type_id =
115       document_filter_data_optional.value().schema_type_id();
116   auto entry_id_or = document_store.GetResultGroupingEntryId(
117       result_group_type, namespace_id, schema_type_id);
118   if (!entry_id_or.ok()) {
119     return false;
120   }
121   int32_t entry_id = entry_id_or.ValueOrDie();
122   auto iter = entry_id_group_id_map.find(entry_id);
123   if (iter == entry_id_group_id_map.end()) {
124     // If a ResultGrouping Entry Id isn't found in entry_id_group_id_map, then
125     // there are no limits placed on results from this entry id.
126     return false;
127   }
128   int& count = group_result_limits.at(iter->second);
129   if (count <= 0) {
130     return true;
131   }
132   --count;
133   return false;
134 }
135 
136 libtextclassifier3::StatusOr<std::unique_ptr<ResultRetrieverV2>>
Create(const DocumentStore * doc_store,const SchemaStore * schema_store,const LanguageSegmenter * language_segmenter,const Normalizer * normalizer,std::unique_ptr<const GroupResultLimiterV2> group_result_limiter)137 ResultRetrieverV2::Create(
138     const DocumentStore* doc_store, const SchemaStore* schema_store,
139     const LanguageSegmenter* language_segmenter, const Normalizer* normalizer,
140     std::unique_ptr<const GroupResultLimiterV2> group_result_limiter) {
141   ICING_RETURN_ERROR_IF_NULL(doc_store);
142   ICING_RETURN_ERROR_IF_NULL(schema_store);
143   ICING_RETURN_ERROR_IF_NULL(language_segmenter);
144   ICING_RETURN_ERROR_IF_NULL(normalizer);
145   ICING_RETURN_ERROR_IF_NULL(group_result_limiter);
146 
147   ICING_ASSIGN_OR_RETURN(
148       std::unique_ptr<SnippetRetriever> snippet_retriever,
149       SnippetRetriever::Create(schema_store, language_segmenter, normalizer));
150 
151   return std::unique_ptr<ResultRetrieverV2>(
152       new ResultRetrieverV2(doc_store, std::move(snippet_retriever),
153                             std::move(group_result_limiter)));
154 }
155 
RetrieveNextPage(ResultStateV2 & result_state,int64_t current_time_ms) const156 std::pair<PageResult, bool> ResultRetrieverV2::RetrieveNextPage(
157     ResultStateV2& result_state, int64_t current_time_ms) const {
158   absl_ports::unique_lock l(&result_state.mutex);
159 
160   // For calculating page
161   int original_scored_document_hits_ranker_size =
162       result_state.scored_document_hits_ranker->size();
163   int num_results_with_snippets = 0;
164 
165   // Retrieve info
166   std::vector<SearchResultProto::ResultProto> results;
167   int32_t num_total_bytes = 0;
168   while (results.size() < result_state.num_per_page() &&
169          !result_state.scored_document_hits_ranker->empty()) {
170     JoinedScoredDocumentHit next_best_document_hit =
171         result_state.scored_document_hits_ranker->PopNext();
172     if (group_result_limiter_->ShouldBeRemoved(
173             next_best_document_hit.parent_scored_document_hit(),
174             result_state.entry_id_group_id_map(), doc_store_,
175             result_state.group_result_limits, result_state.result_group_type(),
176             current_time_ms)) {
177       continue;
178     }
179 
180     libtextclassifier3::StatusOr<DocumentProto> document_or = doc_store_.Get(
181         next_best_document_hit.parent_scored_document_hit().document_id());
182     if (!document_or.ok()) {
183       // Skip the document if getting errors.
184       ICING_LOG(WARNING) << "Fail to fetch document from document store: "
185                          << document_or.status().error_message();
186       continue;
187     }
188 
189     DocumentProto document = std::move(document_or).ValueOrDie();
190     // Apply parent projection
191     ApplyProjection(result_state.parent_adjustment_info(), &document);
192 
193     SearchResultProto::ResultProto result;
194     // Add parent snippet if requested.
195     if (ApplySnippet(result_state.parent_adjustment_info(), *snippet_retriever_,
196                      document,
197                      next_best_document_hit.parent_scored_document_hit()
198                          .hit_section_id_mask(),
199                      &result)) {
200       ++num_results_with_snippets;
201     }
202 
203     // Add the document, itself.
204     *result.mutable_document() = std::move(document);
205     result.set_score(next_best_document_hit.final_score());
206     const auto* parent_additional_scores =
207         next_best_document_hit.parent_scored_document_hit().additional_scores();
208     if (parent_additional_scores != nullptr) {
209       result.mutable_additional_scores()->Add(parent_additional_scores->begin(),
210                                               parent_additional_scores->end());
211     }
212 
213     // Retrieve child documents
214     for (const ScoredDocumentHit& child_scored_document_hit :
215          next_best_document_hit.child_scored_document_hits()) {
216       if (result.joined_results_size() >=
217           result_state.max_joined_children_per_parent_to_return()) {
218         break;
219       }
220 
221       libtextclassifier3::StatusOr<DocumentProto> child_document_or =
222           doc_store_.Get(child_scored_document_hit.document_id());
223       if (!child_document_or.ok()) {
224         // Skip the document if getting errors.
225         ICING_LOG(WARNING)
226             << "Fail to fetch child document from document store: "
227             << child_document_or.status().error_message();
228         continue;
229       }
230 
231       DocumentProto child_document = std::move(child_document_or).ValueOrDie();
232       ApplyProjection(result_state.child_adjustment_info(), &child_document);
233 
234       SearchResultProto::ResultProto* child_result =
235           result.add_joined_results();
236       // Add child snippet if requested.
237       ApplySnippet(result_state.child_adjustment_info(), *snippet_retriever_,
238                    child_document,
239                    child_scored_document_hit.hit_section_id_mask(),
240                    child_result);
241 
242       *child_result->mutable_document() = std::move(child_document);
243       child_result->set_score(child_scored_document_hit.score());
244       if (child_scored_document_hit.additional_scores() != nullptr) {
245         child_result->mutable_additional_scores()->Add(
246             child_scored_document_hit.additional_scores()->begin(),
247             child_scored_document_hit.additional_scores()->end());
248       }
249     }
250 
251     size_t result_bytes = result.ByteSizeLong();
252     results.push_back(std::move(result));
253 
254     // Check if num_total_bytes + result_bytes reaches or exceeds
255     // num_total_bytes_per_page_threshold. Use subtraction to avoid integer
256     // overflow.
257     if (result_bytes >=
258         result_state.num_total_bytes_per_page_threshold() - num_total_bytes) {
259       break;
260     }
261     num_total_bytes += result_bytes;
262   }
263 
264   // Update numbers in ResultState
265   result_state.num_returned += results.size();
266   result_state.IncrementNumTotalHits(
267       result_state.scored_document_hits_ranker->size() -
268       original_scored_document_hits_ranker_size);
269 
270   bool has_more_results = !result_state.scored_document_hits_ranker->empty();
271 
272   return std::make_pair(
273       PageResult(std::move(results), num_results_with_snippets,
274                  result_state.num_per_page()),
275       has_more_results);
276 }
277 
278 }  // namespace lib
279 }  // namespace icing
280