// xref: /aosp_15_r20/external/icing/icing/index/main/doc-hit-info-iterator-term-main.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
// Copyright (C) 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TERM_MAIN_H_
16 #define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TERM_MAIN_H_
17 
#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/hit/hit.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/index/main/main-index.h"
#include "icing/index/main/posting-list-hit-accessor.h"
#include "icing/schema/section.h"
32 
33 namespace icing {
34 namespace lib {
35 
36 class DocHitInfoIteratorTermMain : public DocHitInfoLeafIterator {
37  public:
38   struct DocHitInfoAndTermFrequencyArray {
39     DocHitInfo doc_hit_info;
40     std::optional<Hit::TermFrequencyArray> term_frequency_array;
41 
42     explicit DocHitInfoAndTermFrequencyArray() = default;
43 
DocHitInfoAndTermFrequencyArrayDocHitInfoAndTermFrequencyArray44     explicit DocHitInfoAndTermFrequencyArray(
45         DocHitInfo doc_hit_info_in,
46         std::optional<Hit::TermFrequencyArray> term_frequency_array_in)
47         : doc_hit_info(std::move(doc_hit_info_in)),
48           term_frequency_array(std::move(term_frequency_array_in)) {}
49   };
50 
DocHitInfoIteratorTermMain(MainIndex * main_index,const std::string & term,int term_start_index,int unnormalized_term_length,SectionIdMask section_restrict_mask,bool need_hit_term_frequency)51   explicit DocHitInfoIteratorTermMain(MainIndex* main_index,
52                                       const std::string& term,
53                                       int term_start_index,
54                                       int unnormalized_term_length,
55                                       SectionIdMask section_restrict_mask,
56                                       bool need_hit_term_frequency)
57       : term_(term),
58         term_start_index_(term_start_index),
59         unnormalized_term_length_(unnormalized_term_length),
60         posting_list_accessor_(nullptr),
61         main_index_(main_index),
62         cached_doc_hit_infos_idx_(-1),
63         num_advance_calls_(0),
64         num_blocks_inspected_(0),
65         all_pages_consumed_(false),
66         section_restrict_mask_(section_restrict_mask),
67         need_hit_term_frequency_(need_hit_term_frequency) {}
68 
69   libtextclassifier3::Status Advance() override;
70 
71   libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override;
72 
GetCallStats()73   CallStats GetCallStats() const override {
74     return CallStats(
75         /*num_leaf_advance_calls_lite_index_in=*/0,
76         /*num_leaf_advance_calls_main_index_in=*/num_advance_calls_,
77         /*num_leaf_advance_calls_integer_index_in=*/0,
78         /*num_leaf_advance_calls_no_index_in=*/0,
79         /*num_blocks_inspected_in=*/num_blocks_inspected_);
80   }
81 
82   void PopulateMatchedTermsStats(
83       std::vector<TermMatchInfo>* matched_terms_stats,
84       SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override {
85     if (cached_doc_hit_infos_idx_ == -1 ||
86         cached_doc_hit_infos_idx_ >= cached_doc_hit_infos_.size()) {
87       // Current hit isn't valid, return.
88       return;
89     }
90     SectionIdMask section_mask =
91         doc_hit_info_.hit_section_ids_mask() & filtering_section_mask;
92     SectionIdMask section_mask_copy = section_mask;
93     std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies =
94         {Hit::kNoTermFrequency};
95     while (section_mask_copy) {
96       SectionId section_id = __builtin_ctzll(section_mask_copy);
97       if (need_hit_term_frequency_) {
98         section_term_frequencies.at(section_id) =
99             (*cached_doc_hit_infos_.at(cached_doc_hit_infos_idx_)
100                   .term_frequency_array)[section_id];
101       }
102       section_mask_copy &= ~(UINT64_C(1) << section_id);
103     }
104     TermMatchInfo term_stats(term_, section_mask,
105                              std::move(section_term_frequencies));
106 
107     for (const TermMatchInfo& cur_term_stats : *matched_terms_stats) {
108       if (cur_term_stats.term == term_stats.term) {
109         // Same docId and same term, we don't need to add the term and the term
110         // frequency should always be the same
111         return;
112       }
113     }
114     matched_terms_stats->push_back(std::move(term_stats));
115   }
116 
117  protected:
118   // Add DocHitInfos corresponding to term_ to cached_doc_hit_infos_.
119   virtual libtextclassifier3::Status RetrieveMoreHits() = 0;
120 
121   const std::string term_;
122 
123   // The start index of the given term in the search query
124   int term_start_index_;
125   // The length of the given unnormalized term in the search query
126   int unnormalized_term_length_;
127   // The accessor of the posting list chain for the requested term.
128   std::unique_ptr<PostingListHitAccessor> posting_list_accessor_;
129 
130   MainIndex* main_index_;
131   // Stores hits and optional term frequency arrays retrieved from the index.
132   // This may only be a subset of the hits that are present in the index.
133   // Current value pointed to by the Iterator is tracked by
134   // cached_doc_hit_infos_idx_.
135   std::vector<DocHitInfoAndTermFrequencyArray> cached_doc_hit_infos_;
136   int cached_doc_hit_infos_idx_;
137 
138   int num_advance_calls_;
139   int num_blocks_inspected_;
140   bool all_pages_consumed_;
141   // Mask indicating which sections hits should be considered for.
142   // Ex. 0000 0000 0000 0010 means that only hits from section 1 are desired.
143   const SectionIdMask section_restrict_mask_;
144   const bool need_hit_term_frequency_;
145 
146  private:
147   // Remaining number of hits including the current hit.
148   // Returns -1 if cached_doc_hit_infos_idx_ is invalid.
cached_doc_hit_info_count()149   int cached_doc_hit_info_count() const {
150     if (cached_doc_hit_infos_idx_ == -1 ||
151         cached_doc_hit_infos_idx_ >= cached_doc_hit_infos_.size()) {
152       return -1;
153     }
154     return cached_doc_hit_infos_.size() - cached_doc_hit_infos_idx_;
155   }
156 };
157 
158 class DocHitInfoIteratorTermMainExact : public DocHitInfoIteratorTermMain {
159  public:
DocHitInfoIteratorTermMainExact(MainIndex * main_index,const std::string & term,int term_start_index,int unnormalized_term_length,SectionIdMask section_restrict_mask,bool need_hit_term_frequency)160   explicit DocHitInfoIteratorTermMainExact(MainIndex* main_index,
161                                            const std::string& term,
162                                            int term_start_index,
163                                            int unnormalized_term_length,
164                                            SectionIdMask section_restrict_mask,
165                                            bool need_hit_term_frequency)
166       : DocHitInfoIteratorTermMain(
167             main_index, term, term_start_index, unnormalized_term_length,
168             section_restrict_mask, need_hit_term_frequency) {}
169 
170   std::string ToString() const override;
171 
172  protected:
173   libtextclassifier3::Status RetrieveMoreHits() override;
174 };
175 
176 class DocHitInfoIteratorTermMainPrefix : public DocHitInfoIteratorTermMain {
177  public:
DocHitInfoIteratorTermMainPrefix(MainIndex * main_index,const std::string & term,int term_start_index,int unnormalized_term_length,SectionIdMask section_restrict_mask,bool need_hit_term_frequency)178   explicit DocHitInfoIteratorTermMainPrefix(MainIndex* main_index,
179                                             const std::string& term,
180                                             int term_start_index,
181                                             int unnormalized_term_length,
182                                             SectionIdMask section_restrict_mask,
183                                             bool need_hit_term_frequency)
184       : DocHitInfoIteratorTermMain(
185             main_index, term, term_start_index, unnormalized_term_length,
186             section_restrict_mask, need_hit_term_frequency) {}
187 
188   std::string ToString() const override;
189 
190  protected:
191   libtextclassifier3::Status RetrieveMoreHits() override;
192 
193  private:
194   // Whether or not posting_list_accessor_ holds a posting list chain for
195   // 'term' or for a term for which 'term' is a prefix. This is necessary to
196   // determine whether to return hits that are not from a prefix section (hits
197   // not from a prefix section should only be returned if exact_ is true).
198   bool exact_;
199 };
200 
201 }  // namespace lib
202 }  // namespace icing
203 
204 #endif  // ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TERM_MAIN_H_
205