1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #ifndef ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
16 #define ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
17
18 #include <cstdint>
19 #include <vector>
20
21 #include "icing/text_classifier/lib3/utils/base/status.h"
22 #include "icing/text_classifier/lib3/utils/base/statusor.h"
23 #include "icing/file/posting_list/posting-list-common.h"
24 #include "icing/file/posting_list/posting-list-used.h"
25 #include "icing/index/embed/embedding-hit.h"
26 #include "icing/util/status-macros.h"
27
28 namespace icing {
29 namespace lib {
30
31 // A serializer class to serialize hits to PostingListUsed. Layout described in
32 // comments in posting-list-embedding-hit-serializer.cc.
33 class PostingListEmbeddingHitSerializer : public PostingListSerializer {
34 public:
35 static constexpr uint32_t kSpecialHitsSize =
36 kNumSpecialData * sizeof(EmbeddingHit);
37
GetDataTypeBytes()38 uint32_t GetDataTypeBytes() const override { return sizeof(EmbeddingHit); }
39
GetMinPostingListSize()40 uint32_t GetMinPostingListSize() const override {
41 static constexpr uint32_t kMinPostingListSize = kSpecialHitsSize;
42 static_assert(sizeof(PostingListIndex) <= kMinPostingListSize,
43 "PostingListIndex must be small enough to fit in a "
44 "minimum-sized Posting List.");
45
46 return kMinPostingListSize;
47 }
48
49 uint32_t GetMinPostingListSizeToFit(
50 const PostingListUsed* posting_list_used) const override;
51
52 uint32_t GetBytesUsed(
53 const PostingListUsed* posting_list_used) const override;
54
55 void Clear(PostingListUsed* posting_list_used) const override;
56
57 libtextclassifier3::Status MoveFrom(PostingListUsed* dst,
58 PostingListUsed* src) const override;
59
60 // Prepend a hit to the posting list.
61 //
62 // RETURNS:
63 // - INVALID_ARGUMENT if !hit.is_valid() or if hit is not less than the
64 // previously added hit.
65 // - RESOURCE_EXHAUSTED if there is no more room to add hit to the posting
66 // list.
67 libtextclassifier3::Status PrependHit(PostingListUsed* posting_list_used,
68 const EmbeddingHit& hit) const;
69
70 // Prepend hits to the posting list. Hits should be sorted in descending order
71 // (as defined by the less than operator for Hit)
72 //
73 // Returns the number of hits that could be prepended to the posting list. If
74 // keep_prepended is true, whatever could be prepended is kept, otherwise the
75 // posting list is left in its original state.
76 template <class T, EmbeddingHit (*GetHit)(const T&)>
77 libtextclassifier3::StatusOr<uint32_t> PrependHitArray(
78 PostingListUsed* posting_list_used, const T* array, uint32_t num_hits,
79 bool keep_prepended) const;
80
81 // Retrieves the hits stored in the posting list.
82 //
83 // RETURNS:
84 // - On success, a vector of hits sorted by the reverse order of prepending.
85 // - INTERNAL_ERROR if the posting list has been corrupted somehow.
86 libtextclassifier3::StatusOr<std::vector<EmbeddingHit>> GetHits(
87 const PostingListUsed* posting_list_used) const;
88
89 // Same as GetHits but appends hits to hits_out.
90 //
91 // RETURNS:
92 // - On success, a vector of hits sorted by the reverse order of prepending.
93 // - INTERNAL_ERROR if the posting list has been corrupted somehow.
94 libtextclassifier3::Status GetHits(const PostingListUsed* posting_list_used,
95 std::vector<EmbeddingHit>* hits_out) const;
96
97 // Undo the last num_hits hits prepended. If num_hits > number of
98 // hits we clear all hits.
99 //
100 // RETURNS:
101 // - OK on success
102 // - INTERNAL_ERROR if the posting list has been corrupted somehow.
103 libtextclassifier3::Status PopFrontHits(PostingListUsed* posting_list_used,
104 uint32_t num_hits) const;
105
106 private:
107 // Posting list layout formats:
108 //
109 // not_full
110 //
111 // +-----------------+----------------+-------+-----------------+
112 // |hits-start-offset|Hit::kInvalidVal|xxxxxxx|(compressed) hits|
113 // +-----------------+----------------+-------+-----------------+
114 //
115 // almost_full
116 //
117 // +-----------------+----------------+-------+-----------------+
118 // |Hit::kInvalidVal |1st hit |(pad) |(compressed) hits|
119 // +-----------------+----------------+-------+-----------------+
120 //
121 // full()
122 //
123 // +-----------------+----------------+-------+-----------------+
124 // |1st hit |2nd hit |(pad) |(compressed) hits|
125 // +-----------------+----------------+-------+-----------------+
126 //
127 // The first two uncompressed hits also implicitly encode information about
128 // the size of the compressed hits region.
129 //
130 // 1. If the posting list is NOT_FULL, then
131 // posting_list_buffer_[0] contains the byte offset of the start of the
132 // compressed hits - and, thus, the size of the compressed hits region is
133 // size_in_bytes - posting_list_buffer_[0].
134 //
135 // 2. If posting list is ALMOST_FULL or FULL, then the compressed hits region
136 // starts somewhere between [kSpecialHitsSize, kSpecialHitsSize +
137 // sizeof(EmbeddingHit) - 1] and ends at size_in_bytes - 1.
138
139 // Helpers to determine what state the posting list is in.
IsFull(const PostingListUsed * posting_list_used)140 bool IsFull(const PostingListUsed* posting_list_used) const {
141 return GetSpecialHit(posting_list_used, /*index=*/0).is_valid() &&
142 GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
143 }
144
IsAlmostFull(const PostingListUsed * posting_list_used)145 bool IsAlmostFull(const PostingListUsed* posting_list_used) const {
146 return !GetSpecialHit(posting_list_used, /*index=*/0).is_valid() &&
147 GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
148 }
149
IsEmpty(const PostingListUsed * posting_list_used)150 bool IsEmpty(const PostingListUsed* posting_list_used) const {
151 return GetSpecialHit(posting_list_used, /*index=*/0).value() ==
152 posting_list_used->size_in_bytes() &&
153 !GetSpecialHit(posting_list_used, /*index=*/1).is_valid();
154 }
155
156 // Returns false if both special hits are invalid or if the offset value
157 // stored in the special hit is less than kSpecialHitsSize or greater than
158 // posting_list_used->size_in_bytes(). Returns true, otherwise.
159 bool IsPostingListValid(const PostingListUsed* posting_list_used) const;
160
161 // Prepend hit to a posting list that is in the ALMOST_FULL state.
162 // RETURNS:
163 // - OK, if successful
164 // - INVALID_ARGUMENT if hit is not less than the previously added hit.
165 libtextclassifier3::Status PrependHitToAlmostFull(
166 PostingListUsed* posting_list_used, const EmbeddingHit& hit) const;
167
168 // Prepend hit to a posting list that is in the EMPTY state. This will always
169 // succeed because there are no pre-existing hits and no validly constructed
170 // posting list could fail to fit one hit.
171 void PrependHitToEmpty(PostingListUsed* posting_list_used,
172 const EmbeddingHit& hit) const;
173
174 // Prepend hit to a posting list that is in the NOT_FULL state.
175 // RETURNS:
176 // - OK, if successful
177 // - INVALID_ARGUMENT if hit is not less than the previously added hit.
178 libtextclassifier3::Status PrependHitToNotFull(
179 PostingListUsed* posting_list_used, const EmbeddingHit& hit,
180 uint32_t offset) const;
181
182 // Returns either 0 (full state), sizeof(EmbeddingHit) (almost_full state) or
183 // a byte offset between kSpecialHitsSize and
184 // posting_list_used->size_in_bytes() (inclusive) (not_full state).
185 uint32_t GetStartByteOffset(const PostingListUsed* posting_list_used) const;
186
187 // Sets the special hits to properly reflect what offset is (see layout
188 // comment for further details).
189 //
190 // Returns false if offset > posting_list_used->size_in_bytes() or offset is
191 // (kSpecialHitsSize, sizeof(EmbeddingHit)) or offset is
192 // (sizeof(EmbeddingHit), 0). True, otherwise.
193 bool SetStartByteOffset(PostingListUsed* posting_list_used,
194 uint32_t offset) const;
195
196 // Manipulate padded areas. We never store the same hit value twice
197 // so a delta of 0 is a pad byte.
198
199 // Returns offset of first non-pad byte.
200 uint32_t GetPadEnd(const PostingListUsed* posting_list_used,
201 uint32_t offset) const;
202
203 // Fill padding between offset start and offset end with 0s.
204 // Returns false if end > posting_list_used->size_in_bytes(). True,
205 // otherwise.
206 bool PadToEnd(PostingListUsed* posting_list_used, uint32_t start,
207 uint32_t end) const;
208
209 // Helper for AppendHits/PopFrontHits. Adds limit number of hits to out or all
210 // hits in the posting list if the posting list contains less than limit
211 // number of hits. out can be NULL.
212 //
213 // NOTE: If called with limit=1, pop=true on a posting list that transitioned
214 // from NOT_FULL directly to FULL, GetHitsInternal will not return the posting
215 // list to NOT_FULL. Instead it will leave it in a valid state, but it will be
216 // ALMOST_FULL.
217 //
218 // RETURNS:
219 // - OK on success
220 // - INTERNAL_ERROR if the posting list has been corrupted somehow.
221 libtextclassifier3::Status GetHitsInternal(
222 const PostingListUsed* posting_list_used, uint32_t limit, bool pop,
223 std::vector<EmbeddingHit>* out) const;
224
225 // Retrieves the value stored in the index-th special hit.
226 //
227 // REQUIRES:
228 // 0 <= index < kNumSpecialData.
229 //
230 // RETURNS:
231 // - A valid SpecialData<EmbeddingHit>.
232 EmbeddingHit GetSpecialHit(const PostingListUsed* posting_list_used,
233 uint32_t index) const;
234
235 // Sets the value stored in the index-th special hit to val.
236 //
237 // REQUIRES:
238 // 0 <= index < kNumSpecialData.
239 void SetSpecialHit(PostingListUsed* posting_list_used, uint32_t index,
240 const EmbeddingHit& val) const;
241
242 // Prepends hit to the memory region [offset - sizeof(EmbeddingHit), offset]
243 // and returns the new beginning of the padded region.
244 //
245 // RETURNS:
246 // - The new beginning of the padded region, if successful.
247 // - INVALID_ARGUMENT if hit will not fit (uncompressed) between offset and
248 // kSpecialHitsSize
249 libtextclassifier3::StatusOr<uint32_t> PrependHitUncompressed(
250 PostingListUsed* posting_list_used, const EmbeddingHit& hit,
251 uint32_t offset) const;
252 };
253
254 // Inlined functions. Implementation details below. Avert eyes!
255 template <class T, EmbeddingHit (*GetHit)(const T&)>
256 libtextclassifier3::StatusOr<uint32_t>
PrependHitArray(PostingListUsed * posting_list_used,const T * array,uint32_t num_hits,bool keep_prepended)257 PostingListEmbeddingHitSerializer::PrependHitArray(
258 PostingListUsed* posting_list_used, const T* array, uint32_t num_hits,
259 bool keep_prepended) const {
260 if (!IsPostingListValid(posting_list_used)) {
261 return 0;
262 }
263
264 // Prepend hits working backwards from array[num_hits - 1].
265 uint32_t i;
266 for (i = 0; i < num_hits; ++i) {
267 if (!PrependHit(posting_list_used, GetHit(array[num_hits - i - 1])).ok()) {
268 break;
269 }
270 }
271 if (i != num_hits && !keep_prepended) {
272 // Didn't fit. Undo everything and check that we have the same offset as
273 // before. PopFrontHits guarantees that it will remove all 'i' hits so long
274 // as there are at least 'i' hits in the posting list, which we know there
275 // are.
276 ICING_RETURN_IF_ERROR(PopFrontHits(posting_list_used, /*num_hits=*/i));
277 }
278 return i;
279 }
280
281 } // namespace lib
282 } // namespace icing
283
284 #endif // ICING_INDEX_EMBED_POSTING_LIST_EMBEDDING_HIT_SERIALIZER_H_
285