1 // Copyright (C) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
16 #define ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
17
#include <algorithm>
#include <array>
#include <cinttypes>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

#include "icing/text_classifier/lib3/utils/base/status.h"
#include "icing/text_classifier/lib3/utils/base/statusor.h"
#include "icing/absl_ports/canonical_errors.h"
#include "icing/absl_ports/str_cat.h"
#include "icing/index/hit/doc-hit-info.h"
#include "icing/index/hit/hit.h"
#include "icing/index/iterator/doc-hit-info-iterator.h"
#include "icing/legacy/core/icing-string-util.h"
#include "icing/schema/section.h"
#include "icing/store/document-id.h"
36
37 namespace icing {
38 namespace lib {
39
40 class DocHitInfoTermFrequencyPair {
41 public:
42 DocHitInfoTermFrequencyPair(
43 const DocHitInfo& doc_hit_info,
44 const Hit::TermFrequencyArray& hit_term_frequency = {})
doc_hit_info_(doc_hit_info)45 : doc_hit_info_(doc_hit_info), hit_term_frequency_(hit_term_frequency) {}
46
UpdateSection(SectionId section_id,Hit::TermFrequency hit_term_frequency)47 void UpdateSection(SectionId section_id,
48 Hit::TermFrequency hit_term_frequency) {
49 doc_hit_info_.UpdateSection(section_id);
50 hit_term_frequency_[section_id] = hit_term_frequency;
51 }
52
MergeSectionsFrom(const DocHitInfoTermFrequencyPair & other)53 void MergeSectionsFrom(const DocHitInfoTermFrequencyPair& other) {
54 SectionIdMask other_mask = other.doc_hit_info_.hit_section_ids_mask();
55 doc_hit_info_.MergeSectionsFrom(other_mask);
56 while (other_mask) {
57 SectionId section_id = __builtin_ctzll(other_mask);
58 hit_term_frequency_[section_id] = other.hit_term_frequency_[section_id];
59 other_mask &= ~(UINT64_C(1) << section_id);
60 }
61 }
62
doc_hit_info()63 DocHitInfo doc_hit_info() const { return doc_hit_info_; }
64
hit_term_frequency(SectionId section_id)65 Hit::TermFrequency hit_term_frequency(SectionId section_id) const {
66 return hit_term_frequency_[section_id];
67 }
68
69 bool operator==(const DocHitInfoTermFrequencyPair& other) const {
70 if (!(doc_hit_info() == other.doc_hit_info())) {
71 return false;
72 }
73 return memcmp(&hit_term_frequency_, &other.hit_term_frequency_,
74 kTotalNumSections) == 0;
75 }
76
77 private:
78 DocHitInfo doc_hit_info_;
79 Hit::TermFrequencyArray hit_term_frequency_;
80 };
81
// Dummy class to help with testing. It starts with a kInvalidDocumentId doc
// hit info until Advance is called (like normal DocHitInfoIterators). It then
// returns the doc_hit_infos in order, one per call to Advance. After all
// doc_hit_infos have been returned, Advance returns a ResourceExhausted error
// (also like normal DocHitInfoIterators).
87 class DocHitInfoIteratorDummy : public DocHitInfoLeafIterator {
88 public:
89 DocHitInfoIteratorDummy() = default;
90 explicit DocHitInfoIteratorDummy(
91 std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos,
92 std::string term = "")
doc_hit_infos_(std::move (doc_hit_infos))93 : doc_hit_infos_(std::move(doc_hit_infos)), term_(std::move(term)) {}
94
95 explicit DocHitInfoIteratorDummy(const std::vector<DocHitInfo>& doc_hit_infos,
96 std::string term = "",
97 int term_start_index = 0,
98 int unnormalized_term_length = 0)
term_(std::move (term))99 : term_(std::move(term)),
100 term_start_index_(term_start_index),
101 unnormalized_term_length_(unnormalized_term_length) {
102 for (auto& doc_hit_info : doc_hit_infos) {
103 doc_hit_infos_.push_back(DocHitInfoTermFrequencyPair(doc_hit_info));
104 }
105 }
106
Advance()107 libtextclassifier3::Status Advance() override {
108 ++index_;
109 if (index_ < doc_hit_infos_.size()) {
110 doc_hit_info_ = doc_hit_infos_.at(index_).doc_hit_info();
111 return libtextclassifier3::Status::OK;
112 }
113
114 return absl_ports::ResourceExhaustedError(
115 "No more DocHitInfos in iterator");
116 }
117
TrimRightMostNode()118 libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
119 DocHitInfoIterator::TrimmedNode node = {nullptr, term_, term_start_index_,
120 unnormalized_term_length_};
121 return node;
122 }
123
124 // Imitates behavior of DocHitInfoIteratorTermMain/DocHitInfoIteratorTermLite
125 void PopulateMatchedTermsStats(
126 std::vector<TermMatchInfo>* matched_terms_stats,
127 SectionIdMask filtering_section_mask = kSectionIdMaskAll) const override {
128 if (index_ == -1 || index_ >= doc_hit_infos_.size()) {
129 // Current hit isn't valid, return.
130 return;
131 }
132 SectionIdMask section_mask =
133 doc_hit_info_.hit_section_ids_mask() & filtering_section_mask;
134 SectionIdMask section_mask_copy = section_mask;
135 std::array<Hit::TermFrequency, kTotalNumSections> section_term_frequencies =
136 {Hit::kNoTermFrequency};
137 while (section_mask_copy) {
138 SectionId section_id = __builtin_ctzll(section_mask_copy);
139 section_term_frequencies.at(section_id) =
140 doc_hit_infos_.at(index_).hit_term_frequency(section_id);
141 section_mask_copy &= ~(UINT64_C(1) << section_id);
142 }
143 TermMatchInfo term_stats(term_, section_mask,
144 std::move(section_term_frequencies));
145
146 for (auto& cur_term_stats : *matched_terms_stats) {
147 if (cur_term_stats.term == term_stats.term) {
148 // Same docId and same term, we don't need to add the term and the term
149 // frequency should always be the same
150 return;
151 }
152 }
153 matched_terms_stats->push_back(term_stats);
154 }
155
set_hit_section_ids_mask(SectionIdMask hit_section_ids_mask)156 void set_hit_section_ids_mask(SectionIdMask hit_section_ids_mask) {
157 doc_hit_info_.set_hit_section_ids_mask(hit_section_ids_mask);
158 }
159
GetCallStats()160 CallStats GetCallStats() const override { return call_stats_; }
161
SetCallStats(CallStats call_stats)162 void SetCallStats(CallStats call_stats) {
163 call_stats_ = std::move(call_stats);
164 }
165
ToString()166 std::string ToString() const override {
167 std::string ret = "<";
168 for (auto& doc_hit_info_pair : doc_hit_infos_) {
169 absl_ports::StrAppend(
170 &ret, IcingStringUtil::StringPrintf(
171 "[%d,%" PRIu64 "]",
172 doc_hit_info_pair.doc_hit_info().document_id(),
173 doc_hit_info_pair.doc_hit_info().hit_section_ids_mask()));
174 }
175 absl_ports::StrAppend(&ret, ">");
176 return ret;
177 }
178
179 private:
180 int32_t index_ = -1;
181 CallStats call_stats_;
182 std::vector<DocHitInfoTermFrequencyPair> doc_hit_infos_;
183 std::string term_;
184 int term_start_index_;
185 int unnormalized_term_length_;
186 };
187
GetDocumentIds(DocHitInfoIterator * iterator)188 inline std::vector<DocumentId> GetDocumentIds(DocHitInfoIterator* iterator) {
189 std::vector<DocumentId> ids;
190 while (iterator->Advance().ok()) {
191 ids.push_back(iterator->doc_hit_info().document_id());
192 }
193 return ids;
194 }
195
GetDocHitInfos(DocHitInfoIterator * iterator)196 inline std::vector<DocHitInfo> GetDocHitInfos(DocHitInfoIterator* iterator) {
197 std::vector<DocHitInfo> doc_hit_infos;
198 while (iterator->Advance().ok()) {
199 doc_hit_infos.push_back(iterator->doc_hit_info());
200 }
201 return doc_hit_infos;
202 }
203
204 } // namespace lib
205 } // namespace icing
206
207 #endif // ICING_INDEX_ITERATOR_DOC_HIT_INFO_ITERATOR_TEST_UTIL_H_
208