xref: /aosp_15_r20/external/icing/icing/index/iterator/doc-hit-info-iterator-test-util.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
16 #define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
17 
18 #include <array>
19 #include <cinttypes>
20 #include <cstdint>
21 #include <cstring>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "icing/text_classifier/lib3/utils/base/status.h"
27 #include "icing/text_classifier/lib3/utils/base/statusor.h"
28 #include "icing/absl_ports/canonical_errors.h"
29 #include "icing/absl_ports/str_cat.h"
30 #include "icing/index/hit/doc-hit-info.h"
31 #include "icing/index/hit/hit.h"
32 #include "icing/index/iterator/doc-hit-info-iterator.h"
33 #include "icing/legacy/core/icing-string-util.h"
34 #include "icing/schema/section.h"
35 #include "icing/store/document-id.h"
36 
37 namespace icing {
38 namespace lib {
39 
40 class DocHitInfoTermFrequencyPair {
41  public:
42   DocHitInfoTermFrequencyPair(
43       const DocHitInfo& doc_hit_info,
44       const Hit::TermFrequencyArray& hit_term_frequency = {})
doc_hit_info_(doc_hit_info)45       : doc_hit_info_(doc_hit_info), hit_term_frequency_(hit_term_frequency) {}
46 
UpdateSection(SectionId section_id,Hit::TermFrequency hit_term_frequency)47   void UpdateSection(SectionId section_id,
48                      Hit::TermFrequency hit_term_frequency) {
49     doc_hit_info_.UpdateSection(section_id);
50     hit_term_frequency_[section_id] = hit_term_frequency;
51   }
52 
MergeSectionsFrom(const DocHitInfoTermFrequencyPair & other)53   void MergeSectionsFrom(const DocHitInfoTermFrequencyPair& other) {
54     SectionIdMask other_mask = other.doc_hit_info_.hit_section_ids_mask();
55     doc_hit_info_.MergeSectionsFrom(other_mask);
56     while (other_mask) {
57       SectionId section_id = __builtin_ctzll(other_mask);
58       hit_term_frequency_[section_id] = other.hit_term_frequency_[section_id];
59       other_mask &= ~(UINT64_C(1) << section_id);
60     }
61   }
62 
doc_hit_info()63   DocHitInfo doc_hit_info() const { return doc_hit_info_; }
64 
hit_term_frequency(SectionId section_id)65   Hit::TermFrequency hit_term_frequency(SectionId section_id) const {
66     return hit_term_frequency_[section_id];
67   }
68 
69   bool operator==(const DocHitInfoTermFrequencyPair& other) const {
70     if (!(doc_hit_info() == other.doc_hit_info())) {
71       return false;
72     }
73     return memcmp(&hit_term_frequency_, &other.hit_term_frequency_,
74                   kTotalNumSections) == 0;
75   }
76 
77  private:
78   DocHitInfo doc_hit_info_;
79   Hit::TermFrequencyArray hit_term_frequency_;
80 };
81 
82 // Dummy class to help with testing. It starts with an kInvalidDocumentId doc
83 // hit info until an Advance is called (like normal DocHitInfoIterators). It
84 // will then proceed to return the doc_hit_infos in order as Advance's are
85 // called. After all doc_hit_infos are returned, Advance will return a NotFound
86 // error (also like normal DocHitInfoIterators).
87 class DocHitInfoIteratorDummy : public DocHitInfoLeafIterator {
88  public:
89   DocHitInfoIteratorDummy() = default;
90   explicit DocHitInfoIteratorDummy(
91       std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos,
92       std::string term = "")
doc_hit_infos_(std::move (doc_hit_infos))93       : doc_hit_infos_(std::move(doc_hit_infos)), term_(std::move(term)) {}
94 
95   explicit DocHitInfoIteratorDummy(const std::vector<DocHitInfo>& doc_hit_infos,
96                                    std::string term = "",
97                                    int term_start_index = 0,
98                                    int unnormalized_term_length = 0)
term_(std::move (term))99       : term_(std::move(term)),
100         term_start_index_(term_start_index),
101         unnormalized_term_length_(unnormalized_term_length) {
102     for (auto& doc_hit_info : doc_hit_infos) {
103       doc_hit_infos_.push_back(DocHitInfoTermFrequencyPair(doc_hit_info));
104     }
105   }
106 
Advance()107   libtextclassifier3::Status Advance() override {
108     ++index_;
109     if (index_ < doc_hit_infos_.size()) {
110       doc_hit_info_ = doc_hit_infos_.at(index_).doc_hit_info();
111       return libtextclassifier3::Status::OK;
112     }
113 
114     return absl_ports::ResourceExhaustedError(
115         "No more DocHitInfos in iterator");
116   }
117 
TrimRightMostNode()118   libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
119     DocHitInfoIterator::TrimmedNode node = {nullptr, term_, term_start_index_,
120                                             unnormalized_term_length_};
121     return node;
122   }
123 
124   // Imitates behavior of DocHitInfoIteratorTermMain/DocHitInfoIteratorTermLite
125   void PopulateMatchedTermsStats(
126       std::vector<TermMatchInfo>* matched_terms_stats,
127       SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override {
128     if (index_ == -1 || index_ >= doc_hit_infos_.size()) {
129       // Current hit isn't valid, return.
130       return;
131     }
132     SectionIdMask section_mask =
133         doc_hit_info_.hit_section_ids_mask() & filtering_section_mask;
134     SectionIdMask section_mask_copy = section_mask;
135     std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies =
136         {Hit::kNoTermFrequency};
137     while (section_mask_copy) {
138       SectionId section_id = __builtin_ctzll(section_mask_copy);
139       section_term_frequencies.at(section_id) =
140           doc_hit_infos_.at(index_).hit_term_frequency(section_id);
141       section_mask_copy &= ~(UINT64_C(1) << section_id);
142     }
143     TermMatchInfo term_stats(term_, section_mask,
144                              std::move(section_term_frequencies));
145 
146     for (auto& cur_term_stats : *matched_terms_stats) {
147       if (cur_term_stats.term == term_stats.term) {
148         // Same docId and same term, we don't need to add the term and the term
149         // frequency should always be the same
150         return;
151       }
152     }
153     matched_terms_stats->push_back(term_stats);
154   }
155 
set_hit_section_ids_mask(SectionIdMask hit_section_ids_mask)156   void set_hit_section_ids_mask(SectionIdMask hit_section_ids_mask) {
157     doc_hit_info_.set_hit_section_ids_mask(hit_section_ids_mask);
158   }
159 
GetCallStats()160   CallStats GetCallStats() const override { return call_stats_; }
161 
SetCallStats(CallStats call_stats)162   void SetCallStats(CallStats call_stats) {
163     call_stats_ = std::move(call_stats);
164   }
165 
ToString()166   std::string ToString() const override {
167     std::string ret = "<";
168     for (auto& doc_hit_info_pair : doc_hit_infos_) {
169       absl_ports::StrAppend(
170           &ret, IcingStringUtil::StringPrintf(
171                     "[%d,%" PRIu64 "]",
172                     doc_hit_info_pair.doc_hit_info().document_id(),
173                     doc_hit_info_pair.doc_hit_info().hit_section_ids_mask()));
174     }
175     absl_ports::StrAppend(&ret, ">");
176     return ret;
177   }
178 
179  private:
180   int32_t index_ = -1;
181   CallStats call_stats_;
182   std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos_;
183   std::string term_;
184   int term_start_index_;
185   int unnormalized_term_length_;
186 };
187 
GetDocumentIds(DocHitInfoIterator * iterator)188 inline std::vector<DocumentId> GetDocumentIds(DocHitInfoIterator* iterator) {
189   std::vector<DocumentId> ids;
190   while (iterator->Advance().ok()) {
191     ids.push_back(iterator->doc_hit_info().document_id());
192   }
193   return ids;
194 }
195 
GetDocHitInfos(DocHitInfoIterator * iterator)196 inline std::vector<DocHitInfo> GetDocHitInfos(DocHitInfoIterator* iterator) {
197   std::vector<DocHitInfo> doc_hit_infos;
198   while (iterator->Advance().ok()) {
199     doc_hit_infos.push_back(iterator->doc_hit_info());
200   }
201   return doc_hit_infos;
202 }
203 
204 }  // namespace lib
205 }  // namespace icing
206 
207 #endif  // ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
208