xref: /aosp_15_r20/external/icing/icing/index/embed/embedding-index.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
16 #define ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
17 
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <string_view>
22 #include <utility>
23 #include <vector>
24 
25 #include "icing/text_classifier/lib3/utils/base/status.h"
26 #include "icing/text_classifier/lib3/utils/base/statusor.h"
27 #include "icing/absl_ports/canonical_errors.h"
28 #include "icing/feature-flags.h"
29 #include "icing/file/file-backed-vector.h"
30 #include "icing/file/filesystem.h"
31 #include "icing/file/memory-mapped-file.h"
32 #include "icing/file/persistent-storage.h"
33 #include "icing/file/posting_list/flash-index-storage.h"
34 #include "icing/file/posting_list/posting-list-identifier.h"
35 #include "icing/index/embed/embedding-hit.h"
36 #include "icing/index/embed/embedding-scorer.h"
37 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
38 #include "icing/index/embed/posting-list-embedding-hit-serializer.h"
39 #include "icing/index/embed/quantizer.h"
40 #include "icing/index/hit/hit.h"
41 #include "icing/schema/schema-store.h"
42 #include "icing/store/document-id.h"
43 #include "icing/store/document-store.h"
44 #include "icing/store/key-mapper.h"
45 #include "icing/util/clock.h"
46 #include "icing/util/crc32.h"
47 #include "icing/util/logging.h"
48 
49 namespace icing {
50 namespace lib {
51 
52 class EmbeddingIndex : public PersistentStorage {
53  public:
54   struct Info {
55     static constexpr int32_t kMagic = 0x61e7cbf1;
56 
57     int32_t magic;
58     DocumentId last_added_document_id;
59     bool is_empty;
60 
61     static constexpr int kPaddingSize = 1000;
62     // Padding exists just to reserve space for additional values.
63     uint8_t padding_[kPaddingSize];
64 
GetChecksumInfo65     Crc32 GetChecksum() const {
66       return Crc32(
67           std::string_view(reinterpret_cast<const char*>(this), sizeof(Info)));
68     }
69   };
70   static_assert(sizeof(Info) == 1012, "");
71 
72   // Metadata file layout: <Crcs><Info>
73   static constexpr int32_t kCrcsMetadataBufferOffset = 0;
74   static constexpr int32_t kInfoMetadataBufferOffset =
75       static_cast<int32_t>(sizeof(Crcs));
76   static constexpr int32_t kMetadataFileSize = sizeof(Crcs) + sizeof(Info);
77   static_assert(kMetadataFileSize == 1024, "");
78 
79   static constexpr WorkingPathType kWorkingPathType =
80       WorkingPathType::kDirectory;
81 
82   EmbeddingIndex(const EmbeddingIndex&) = delete;
83   EmbeddingIndex& operator=(const EmbeddingIndex&) = delete;
84 
85   // Creates a new EmbeddingIndex instance to index embeddings.
86   //
87   // Returns:
88   //   - FAILED_PRECONDITION_ERROR if the file checksum doesn't match the stored
89   //                               checksum.
90   //   - INTERNAL_ERROR on I/O errors.
91   //   - Any error from MemoryMappedFile, FlashIndexStorage,
92   //     DynamicTrieKeyMapper, or FileBackedVector.
93   static libtextclassifier3::StatusOr<std::unique_ptr<EmbeddingIndex>> Create(
94       const Filesystem* filesystem, std::string working_path,
95       const Clock* clock, const FeatureFlags* feature_flags);
96 
Discard(const Filesystem & filesystem,const std::string & working_path)97   static libtextclassifier3::Status Discard(const Filesystem& filesystem,
98                                             const std::string& working_path) {
99     return PersistentStorage::Discard(filesystem, working_path,
100                                       kWorkingPathType);
101   }
102 
103   libtextclassifier3::Status Clear();
104 
~EmbeddingIndex()105   ~EmbeddingIndex() override {
106     if (!PersistToDisk().ok()) {
107       ICING_LOG(WARNING)
108           << "Failed to persist embedding index to disk while destructing "
109           << working_path_;
110     }
111   }
112 
113   // Buffer an embedding pending to be added to the index. This is required
114   // since EmbeddingHits added in posting lists must be decreasing, which means
115   // that section ids and location indexes for a single document must be added
116   // decreasingly.
117   //
118   // Returns:
119   //   - OK on success
120   //   - INVALID_ARGUMENT error if the dimension is 0.
121   //   - INTERNAL_ERROR on I/O error
122   libtextclassifier3::Status BufferEmbedding(
123       const BasicHit& basic_hit, const PropertyProto::VectorProto& vector,
124       EmbeddingIndexingConfig::QuantizationType::Code quantization_type);
125 
126   // Commit the embedding hits in the buffer to the index.
127   //
128   // Returns:
129   //   - OK on success
130   //   - INTERNAL_ERROR on I/O error
131   //   - Any error from posting lists
132   libtextclassifier3::Status CommitBufferToIndex();
133 
134   // Returns a PostingListEmbeddingHitAccessor for all embedding hits that match
135   // with the provided dimension and signature.
136   //
137   // Returns:
138   //   - a PostingListEmbeddingHitAccessor instance on success.
139   //   - INVALID_ARGUMENT error if the dimension is 0.
140   //   - NOT_FOUND error if there is no matching embedding hit.
141   //   - Any error from posting lists.
142   libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
143   GetAccessor(uint32_t dimension, std::string_view model_signature) const;
144 
145   // Returns a PostingListEmbeddingHitAccessor for all embedding hits that match
146   // with the provided vector's dimension and signature.
147   //
148   // Returns:
149   //   - a PostingListEmbeddingHitAccessor instance on success.
150   //   - INVALID_ARGUMENT error if the dimension is 0.
151   //   - NOT_FOUND error if there is no matching embedding hit.
152   //   - Any error from posting lists.
153   libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
GetAccessorForVector(const PropertyProto::VectorProto & vector)154   GetAccessorForVector(const PropertyProto::VectorProto& vector) const {
155     return GetAccessor(vector.values_size(), vector.model_signature());
156   }
157 
158   // Reduces internal file sizes by reclaiming space of deleted documents.
159   // new_last_added_document_id will be used to update the last added document
160   // id in the lite index.
161   //
162   // Returns:
163   //   - OK on success
164   //   - INTERNAL_ERROR on IO error, this indicates that the index may be in an
165   //     invalid state and should be cleared.
166   libtextclassifier3::Status Optimize(
167       const DocumentStore* document_store, const SchemaStore* schema_store,
168       const std::vector<DocumentId>& document_id_old_to_new,
169       DocumentId new_last_added_document_id);
170 
171   // Returns a pointer to the embedding vector for the given hit.
172   //
173   // Returns:
174   //   - a pointer to the embedding vector on success.
175   //   - OUT_OF_RANGE error if the referred vector is out of range based on the
176   //     location and dimension.
GetEmbeddingVector(const EmbeddingHit & hit,uint32_t dimension)177   libtextclassifier3::StatusOr<const float*> GetEmbeddingVector(
178       const EmbeddingHit& hit, uint32_t dimension) const {
179     if (static_cast<int64_t>(hit.location()) + dimension >
180         GetTotalVectorSize()) {
181       return absl_ports::OutOfRangeError(
182           "Got an embedding hit that refers to a vector out of range.");
183     }
184     return embedding_vectors_->array() + hit.location();
185   }
GetQuantizedEmbeddingVector(const EmbeddingHit & hit,uint32_t dimension)186   libtextclassifier3::StatusOr<const char*> GetQuantizedEmbeddingVector(
187       const EmbeddingHit& hit, uint32_t dimension) const {
188     // quantized_embedding_vectors_ stores data in char format. Every quantized
189     // embedding vector contains a Quantizer header followed by the actual
190     // vector, and every value in the vector is stored in uint8_t.
191     if (static_cast<int64_t>(hit.location()) + sizeof(Quantizer) +
192             sizeof(uint8_t) * dimension >
193         GetTotalQuantizedVectorSize()) {
194       return absl_ports::OutOfRangeError(
195           "Got an embedding hit that refers to a vector out of range.");
196     }
197     return quantized_embedding_vectors_->array() + hit.location();
198   }
199 
200   // Calculates the score for the given embedding hit with the given query.
201   //
202   // Returns:
203   //   - The score on success.
204   //   - OUT_OF_RANGE_ERROR if the referred vector is out of range based on the
205   //     location and dimension.
206   //   - Any error from schema store when getting the quantization type.
207   libtextclassifier3::StatusOr<float> ScoreEmbeddingHit(
208       const EmbeddingScorer& scorer, const PropertyProto::VectorProto& query,
209       const EmbeddingHit& hit,
210       EmbeddingIndexingConfig::QuantizationType::Code quantization_type) const;
211 
GetRawEmbeddingData()212   libtextclassifier3::StatusOr<const float*> GetRawEmbeddingData() const {
213     if (is_empty()) {
214       return absl_ports::NotFoundError("EmbeddingIndex is empty");
215     }
216     return embedding_vectors_->array();
217   }
218 
GetTotalVectorSize()219   int32_t GetTotalVectorSize() const {
220     if (is_empty()) {
221       return 0;
222     }
223     return embedding_vectors_->num_elements();
224   }
225 
GetTotalQuantizedVectorSize()226   int32_t GetTotalQuantizedVectorSize() const {
227     if (is_empty()) {
228       return 0;
229     }
230     return quantized_embedding_vectors_->num_elements();
231   }
232 
last_added_document_id()233   DocumentId last_added_document_id() const {
234     return info().last_added_document_id;
235   }
236 
set_last_added_document_id(DocumentId document_id)237   void set_last_added_document_id(DocumentId document_id) {
238     Info& info_ref = info();
239     if (info_ref.last_added_document_id == kInvalidDocumentId ||
240         document_id > info_ref.last_added_document_id) {
241       info_ref.last_added_document_id = document_id;
242     }
243   }
244 
is_empty()245   bool is_empty() const { return info().is_empty; }
246 
247  private:
EmbeddingIndex(const Filesystem & filesystem,std::string working_path,const Clock * clock,const FeatureFlags * feature_flags)248   explicit EmbeddingIndex(const Filesystem& filesystem,
249                           std::string working_path, const Clock* clock,
250                           const FeatureFlags* feature_flags)
251       : PersistentStorage(filesystem, std::move(working_path),
252                           kWorkingPathType),
253         clock_(*clock),
254         feature_flags_(feature_flags) {}
255 
256   // Creates the storage data if the index is not empty. This will initialize
257   // flash_index_storage_, embedding_posting_list_mapper_, embedding_vectors_.
258   //
259   // Returns:
260   //   - OK on success
261   //   - Any error from FlashIndexStorage, DynamicTrieKeyMapper, or
262   //     FileBackedVector.
263   libtextclassifier3::Status CreateStorageDataIfNonEmpty();
264 
265   // Marks the index's header to indicate that the index is non-empty.
266   //
267   // If the index is already marked as non-empty, this is a no-op. Otherwise,
268   // CreateStorageDataIfNonEmpty will be called to create the storage data.
269   //
270   // Returns:
271   //   - OK on success
272   //   - Any error when calling CreateStorageDataIfNonEmpty.
273   libtextclassifier3::Status MarkIndexNonEmpty();
274 
275   libtextclassifier3::Status Initialize();
276 
277   // Transfers the embedding vector of the given hit from the current index to
278   // the new index.
279   //
280   // Returns:
281   //   - The location of the transferred vector in the new index on success.
282   //   - Any error when allocating the vector storage in the new index.
283   libtextclassifier3::StatusOr<uint32_t> TransferEmbeddingVector(
284       const EmbeddingHit& old_hit, uint32_t dimension,
285       EmbeddingIndexingConfig::QuantizationType::Code quantization_type,
286       EmbeddingIndex* new_index) const;
287 
288   // Transfers embedding data and hits from the current index to new_index.
289   //
290   // Returns:
291   //   - OK on success
292   //   - FAILED_PRECONDITION_ERROR if the current index is empty.
293   //   - INTERNAL_ERROR on I/O error. This could potentially leave the storages
294   //     in an invalid state and the caller should handle it properly (e.g.
295   //     discard and rebuild)
296   libtextclassifier3::Status TransferIndex(
297       const DocumentStore& document_store, const SchemaStore& schema_store,
298       const std::vector<DocumentId>& document_id_old_to_new,
299       EmbeddingIndex* new_index) const;
300 
301   libtextclassifier3::Status PersistMetadataToDisk() override;
302 
303   libtextclassifier3::Status PersistStoragesToDisk() override;
304 
WriteMetadata()305   libtextclassifier3::Status WriteMetadata() override {
306     // EmbeddingIndex::Header is mmapped. Therefore, writes occur when the
307     // metadata is modified. So just return OK.
308     return libtextclassifier3::Status::OK;
309   }
310 
311   libtextclassifier3::StatusOr<Crc32> UpdateStoragesChecksum() override;
312 
GetInfoChecksum()313   libtextclassifier3::StatusOr<Crc32> GetInfoChecksum() const override {
314     return info().GetChecksum();
315   }
316 
317   libtextclassifier3::StatusOr<Crc32> GetStoragesChecksum() const override;
318 
319   // Appends the given embedding vector to the appropriate vector storage
320   // (embedding_vectors_ or quantized_embedding_vectors_) based on the
321   // quantization type.
322   //
323   // Returns:
324   //   - The location of the appended vector (i.e., the starting index within
325   //     the vector storage).
326   //   - Any error when allocating the vector storage.
327   libtextclassifier3::StatusOr<uint32_t> AppendEmbeddingVector(
328       const PropertyProto::VectorProto& vector,
329       EmbeddingIndexingConfig::QuantizationType::Code quantization_type);
330 
crcs()331   Crcs& crcs() override {
332     return *reinterpret_cast<Crcs*>(metadata_mmapped_file_->mutable_region() +
333                                     kCrcsMetadataBufferOffset);
334   }
335 
crcs()336   const Crcs& crcs() const override {
337     return *reinterpret_cast<const Crcs*>(metadata_mmapped_file_->region() +
338                                           kCrcsMetadataBufferOffset);
339   }
340 
info()341   Info& info() {
342     return *reinterpret_cast<Info*>(metadata_mmapped_file_->mutable_region() +
343                                     kInfoMetadataBufferOffset);
344   }
345 
info()346   const Info& info() const {
347     return *reinterpret_cast<const Info*>(metadata_mmapped_file_->region() +
348                                           kInfoMetadataBufferOffset);
349   }
350 
351   const Clock& clock_;
352   const FeatureFlags* feature_flags_;  // Does not own.
353 
354   // In memory data:
355   // Pending embedding hits with their embedding keys used for
356   // embedding_posting_list_mapper_.
357   std::vector<std::pair<std::string, EmbeddingHit>> pending_embedding_hits_;
358 
359   // Metadata
360   std::unique_ptr<MemoryMappedFile> metadata_mmapped_file_;
361 
362   // Posting list storage
363   std::unique_ptr<PostingListEmbeddingHitSerializer>
364       posting_list_hit_serializer_ =
365           std::make_unique<PostingListEmbeddingHitSerializer>();
366 
367   // null if the index is empty.
368   std::unique_ptr<FlashIndexStorage> flash_index_storage_;
369 
370   // The mapper from embedding keys to the corresponding posting list identifier
371   // that stores all embedding hits with the same key.
372   //
373   // The key for an embedding hit is a one-to-one encoded string of the ordered
374   // pair (dimension, model_signature) corresponding to the embedding.
375   //
376   // null if the index is empty.
377   std::unique_ptr<KeyMapper<PostingListIdentifier>>
378       embedding_posting_list_mapper_;
379 
380   // A single FileBackedVector that holds all embedding vectors.
381   //
382   // null if the index is empty.
383   std::unique_ptr<FileBackedVector<float>> embedding_vectors_;
384   std::unique_ptr<FileBackedVector<char>> quantized_embedding_vectors_;
385 };
386 
387 }  // namespace lib
388 }  // namespace icing
389 
390 #endif  // ICING_INDEX_EMBED_EMBEDDING_INDEX_H_
391