xref: /aosp_15_r20/external/icing/icing/testing/embedding-test-utils.cc (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/testing/embedding-test-utils.h"
16 
17 #include <cstdint>
18 #include <memory>
19 #include <string_view>
20 #include <utility>
21 #include <vector>
22 
23 #include "icing/text_classifier/lib3/utils/base/statusor.h"
24 #include "icing/absl_ports/canonical_errors.h"
25 #include "icing/index/embed/embedding-hit.h"
26 #include "icing/index/embed/embedding-index.h"
27 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
28 #include "icing/index/embed/quantizer.h"
29 #include "icing/proto/document.pb.h"
30 #include "icing/util/status-macros.h"
31 
32 namespace icing {
33 namespace lib {
34 
35 libtextclassifier3::StatusOr<std::vector<EmbeddingHit>>
GetEmbeddingHitsFromIndex(const EmbeddingIndex * embedding_index,uint32_t dimension,std::string_view model_signature)36 GetEmbeddingHitsFromIndex(const EmbeddingIndex* embedding_index,
37                           uint32_t dimension,
38                           std::string_view model_signature) {
39   std::vector<EmbeddingHit> hits;
40 
41   libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
42       pl_accessor_or = embedding_index->GetAccessor(dimension, model_signature);
43   if (absl_ports::IsNotFound(pl_accessor_or.status())) {
44     return hits;
45   }
46   ICING_ASSIGN_OR_RETURN(
47       std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
48       std::move(pl_accessor_or));
49 
50   while (true) {
51     ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
52                            pl_accessor->GetNextHitsBatch());
53     if (batch.empty()) {
54       return hits;
55     }
56     hits.insert(hits.end(), batch.begin(), batch.end());
57   }
58 }
59 
GetRawEmbeddingDataFromIndex(const EmbeddingIndex * embedding_index)60 std::vector<float> GetRawEmbeddingDataFromIndex(
61     const EmbeddingIndex* embedding_index) {
62   ICING_ASSIGN_OR_RETURN(const float* data,
63                          embedding_index->GetRawEmbeddingData(),
64                          std::vector<float>());
65   return std::vector<float>(data, data + embedding_index->GetTotalVectorSize());
66 }
67 
68 libtextclassifier3::StatusOr<std::vector<float>>
GetAndRestoreQuantizedEmbeddingVectorFromIndex(const EmbeddingIndex * embedding_index,const EmbeddingHit & hit,uint32_t dimension)69 GetAndRestoreQuantizedEmbeddingVectorFromIndex(
70     const EmbeddingIndex* embedding_index, const EmbeddingHit& hit,
71     uint32_t dimension) {
72   ICING_ASSIGN_OR_RETURN(
73       const char* data,
74       embedding_index->GetQuantizedEmbeddingVector(hit, dimension));
75   Quantizer quantizer(data);
76   const uint8_t* quantized_vector =
77       reinterpret_cast<const uint8_t*>(data + sizeof(Quantizer));
78   std::vector<float> result;
79   result.reserve(dimension);
80   for (int i = 0; i < dimension; ++i) {
81     result.push_back(quantizer.Dequantize(quantized_vector[i]));
82   }
83   return result;
84 }
85 
86 }  // namespace lib
87 }  // namespace icing
88