1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/testing/embedding-test-utils.h"
16
17 #include <cstdint>
18 #include <memory>
19 #include <string_view>
20 #include <utility>
21 #include <vector>
22
23 #include "icing/text_classifier/lib3/utils/base/statusor.h"
24 #include "icing/absl_ports/canonical_errors.h"
25 #include "icing/index/embed/embedding-hit.h"
26 #include "icing/index/embed/embedding-index.h"
27 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
28 #include "icing/index/embed/quantizer.h"
29 #include "icing/proto/document.pb.h"
30 #include "icing/util/status-macros.h"
31
32 namespace icing {
33 namespace lib {
34
35 libtextclassifier3::StatusOr<std::vector<EmbeddingHit>>
GetEmbeddingHitsFromIndex(const EmbeddingIndex * embedding_index,uint32_t dimension,std::string_view model_signature)36 GetEmbeddingHitsFromIndex(const EmbeddingIndex* embedding_index,
37 uint32_t dimension,
38 std::string_view model_signature) {
39 std::vector<EmbeddingHit> hits;
40
41 libtextclassifier3::StatusOr<std::unique_ptr<PostingListEmbeddingHitAccessor>>
42 pl_accessor_or = embedding_index->GetAccessor(dimension, model_signature);
43 if (absl_ports::IsNotFound(pl_accessor_or.status())) {
44 return hits;
45 }
46 ICING_ASSIGN_OR_RETURN(
47 std::unique_ptr<PostingListEmbeddingHitAccessor> pl_accessor,
48 std::move(pl_accessor_or));
49
50 while (true) {
51 ICING_ASSIGN_OR_RETURN(std::vector<EmbeddingHit> batch,
52 pl_accessor->GetNextHitsBatch());
53 if (batch.empty()) {
54 return hits;
55 }
56 hits.insert(hits.end(), batch.begin(), batch.end());
57 }
58 }
59
GetRawEmbeddingDataFromIndex(const EmbeddingIndex * embedding_index)60 std::vector<float> GetRawEmbeddingDataFromIndex(
61 const EmbeddingIndex* embedding_index) {
62 ICING_ASSIGN_OR_RETURN(const float* data,
63 embedding_index->GetRawEmbeddingData(),
64 std::vector<float>());
65 return std::vector<float>(data, data + embedding_index->GetTotalVectorSize());
66 }
67
68 libtextclassifier3::StatusOr<std::vector<float>>
GetAndRestoreQuantizedEmbeddingVectorFromIndex(const EmbeddingIndex * embedding_index,const EmbeddingHit & hit,uint32_t dimension)69 GetAndRestoreQuantizedEmbeddingVectorFromIndex(
70 const EmbeddingIndex* embedding_index, const EmbeddingHit& hit,
71 uint32_t dimension) {
72 ICING_ASSIGN_OR_RETURN(
73 const char* data,
74 embedding_index->GetQuantizedEmbeddingVector(hit, dimension));
75 Quantizer quantizer(data);
76 const uint8_t* quantized_vector =
77 reinterpret_cast<const uint8_t*>(data + sizeof(Quantizer));
78 std::vector<float> result;
79 result.reserve(dimension);
80 for (int i = 0; i < dimension; ++i) {
81 result.push_back(quantizer.Dequantize(quantized_vector[i]));
82 }
83 return result;
84 }
85
86 } // namespace lib
87 } // namespace icing
88