1 // Copyright (C) 2019 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TERM_MAIN_H_ 16 #define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TERM_MAIN_H_ 17 18 #include <cstdint> 19 #include <memory> 20 #include <optional> 21 #include <string> 22 #include <utility> 23 #include <vector> 24 25 #include "icing/text_classifier/lib3/utils/base/status.h" 26 #include "icing/index/hit/doc-hit-info.h" 27 #include "icing/index/hit/hit.h" 28 #include "icing/index/iterator/doc-hit-info-iterator.h" 29 #include "icing/index/main/main-index.h" 30 #include "icing/index/main/posting-list-hit-accessor.h" 31 #include "icing/schema/section.h" 32 33 namespace icing { 34 namespace lib { 35 36 class DocHitInfoIteratorTermMain : public DocHitInfoLeafIterator { 37 public: 38 struct DocHitInfoAndTermFrequencyArray { 39 DocHitInfo doc_hit_info; 40 std::optional<Hit::TermFrequencyArray> term_frequency_array; 41 42 explicit DocHitInfoAndTermFrequencyArray() = default; 43 DocHitInfoAndTermFrequencyArrayDocHitInfoAndTermFrequencyArray44 explicit DocHitInfoAndTermFrequencyArray( 45 DocHitInfo doc_hit_info_in, 46 std::optional<Hit::TermFrequencyArray> term_frequency_array_in) 47 : doc_hit_info(std::move(doc_hit_info_in)), 48 term_frequency_array(std::move(term_frequency_array_in)) {} 49 }; 50 DocHitInfoIteratorTermMain(MainIndex * main_index,const std::string & term,int term_start_index,int unnormalized_term_length,SectionIdMask section_restrict_mask,bool need_hit_term_frequency)51 explicit DocHitInfoIteratorTermMain(MainIndex* main_index, 52 const std::string& term, 53 int term_start_index, 54 int unnormalized_term_length, 55 SectionIdMask section_restrict_mask, 56 bool need_hit_term_frequency) 57 : term_(term), 58 term_start_index_(term_start_index), 59 unnormalized_term_length_(unnormalized_term_length), 60 posting_list_accessor_(nullptr), 61 main_index_(main_index), 62 cached_doc_hit_infos_idx_(-1), 63 num_advance_calls_(0), 64 num_blocks_inspected_(0), 65 all_pages_consumed_(false), 66 section_restrict_mask_(section_restrict_mask), 67 need_hit_term_frequency_(need_hit_term_frequency) {} 68 69 libtextclassifier3::Status Advance() override; 70 71 libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override; 72 GetCallStats()73 CallStats GetCallStats() const override { 74 return CallStats( 75 /*num_leaf_advance_calls_lite_index_in=*/0, 76 /*num_leaf_advance_calls_main_index_in=*/num_advance_calls_, 77 /*num_leaf_advance_calls_integer_index_in=*/0, 78 /*num_leaf_advance_calls_no_index_in=*/0, 79 /*num_blocks_inspected_in=*/num_blocks_inspected_); 80 } 81 82 void PopulateMatchedTermsStats( 83 std::vector<TermMatchInfo>* matched_terms_stats, 84 SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override { 85 if (cached_doc_hit_infos_idx_ == -1 || 86 cached_doc_hit_infos_idx_ >= cached_doc_hit_infos_.size()) { 87 // Current hit isn't valid, return. 88 return; 89 } 90 SectionIdMask section_mask = 91 doc_hit_info_.hit_section_ids_mask() & filtering_section_mask; 92 SectionIdMask section_mask_copy = section_mask; 93 std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies = 94 {Hit::kNoTermFrequency}; 95 while (section_mask_copy) { 96 SectionId section_id = __builtin_ctzll(section_mask_copy); 97 if (need_hit_term_frequency_) { 98 section_term_frequencies.at(section_id) = 99 (*cached_doc_hit_infos_.at(cached_doc_hit_infos_idx_) 100 .term_frequency_array)[section_id]; 101 } 102 section_mask_copy &= ~(UINT64_C(1) << section_id); 103 } 104 TermMatchInfo term_stats(term_, section_mask, 105 std::move(section_term_frequencies)); 106 107 for (const TermMatchInfo& cur_term_stats : *matched_terms_stats) { 108 if (cur_term_stats.term == term_stats.term) { 109 // Same docId and same term, we don't need to add the term and the term 110 // frequency should always be the same 111 return; 112 } 113 } 114 matched_terms_stats->push_back(std::move(term_stats)); 115 } 116 117 protected: 118 // Add DocHitInfos corresponding to term_ to cached_doc_hit_infos_. 119 virtual libtextclassifier3::Status RetrieveMoreHits() = 0; 120 121 const std::string term_; 122 123 // The start index of the given term in the search query 124 int term_start_index_; 125 // The length of the given unnormalized term in the search query 126 int unnormalized_term_length_; 127 // The accessor of the posting list chain for the requested term. 128 std::unique_ptr<PostingListHitAccessor> posting_list_accessor_; 129 130 MainIndex* main_index_; 131 // Stores hits and optional term frequency arrays retrieved from the index. 132 // This may only be a subset of the hits that are present in the index. 133 // Current value pointed to by the Iterator is tracked by 134 // cached_doc_hit_infos_idx_. 135 std::vector<DocHitInfoAndTermFrequencyArray> cached_doc_hit_infos_; 136 int cached_doc_hit_infos_idx_; 137 138 int num_advance_calls_; 139 int num_blocks_inspected_; 140 bool all_pages_consumed_; 141 // Mask indicating which sections hits should be considered for. 142 // Ex. 0000 0000 0000 0010 means that only hits from section 1 are desired. 143 const SectionIdMask section_restrict_mask_; 144 const bool need_hit_term_frequency_; 145 146 private: 147 // Remaining number of hits including the current hit. 148 // Returns -1 if cached_doc_hit_infos_idx_ is invalid. cached_doc_hit_info_count()149 int cached_doc_hit_info_count() const { 150 if (cached_doc_hit_infos_idx_ == -1 || 151 cached_doc_hit_infos_idx_ >= cached_doc_hit_infos_.size()) { 152 return -1; 153 } 154 return cached_doc_hit_infos_.size() - cached_doc_hit_infos_idx_; 155 } 156 }; 157 158 class DocHitInfoIteratorTermMainExact : public DocHitInfoIteratorTermMain { 159 public: DocHitInfoIteratorTermMainExact(MainIndex * main_index,const std::string & term,int term_start_index,int unnormalized_term_length,SectionIdMask section_restrict_mask,bool need_hit_term_frequency)160 explicit DocHitInfoIteratorTermMainExact(MainIndex* main_index, 161 const std::string& term, 162 int term_start_index, 163 int unnormalized_term_length, 164 SectionIdMask section_restrict_mask, 165 bool need_hit_term_frequency) 166 : DocHitInfoIteratorTermMain( 167 main_index, term, term_start_index, unnormalized_term_length, 168 section_restrict_mask, need_hit_term_frequency) {} 169 170 std::string ToString() const override; 171 172 protected: 173 libtextclassifier3::Status RetrieveMoreHits() override; 174 }; 175 176 class DocHitInfoIteratorTermMainPrefix : public DocHitInfoIteratorTermMain { 177 public: DocHitInfoIteratorTermMainPrefix(MainIndex * main_index,const std::string & term,int term_start_index,int unnormalized_term_length,SectionIdMask section_restrict_mask,bool need_hit_term_frequency)178 explicit DocHitInfoIteratorTermMainPrefix(MainIndex* main_index, 179 const std::string& term, 180 int term_start_index, 181 int unnormalized_term_length, 182 SectionIdMask section_restrict_mask, 183 bool need_hit_term_frequency) 184 : DocHitInfoIteratorTermMain( 185 main_index, term, term_start_index, unnormalized_term_length, 186 section_restrict_mask, need_hit_term_frequency) {} 187 188 std::string ToString() const override; 189 190 protected: 191 libtextclassifier3::Status RetrieveMoreHits() override; 192 193 private: 194 // Whether or not posting_list_accessor_ holds a posting list chain for 195 // 'term' or for a term for which 'term' is a prefix. This is necessary to 196 // determine whether to return hits that are not from a prefix section (hits 197 // not from a prefix section should only be returned if exact_ is true). 198 bool exact_; 199 }; 200 201 } // namespace lib 202 } // namespace icing 203 204 #endif // ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TERM_MAIN_H_ 205