1 // Copyright (C) 2024 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ 16 #define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ 17 18 #include <cstdint> 19 #include <memory> 20 #include <string> 21 #include <string_view> 22 #include <utility> 23 #include <vector> 24 25 #include "icing/text_classifier/lib3/utils/base/status.h" 26 #include "icing/text_classifier/lib3/utils/base/statusor.h" 27 #include "icing/absl_ports/canonical_errors.h" 28 #include "icing/index/embed/embedding-hit.h" 29 #include "icing/index/embed/embedding-index.h" 30 #include "icing/index/embed/embedding-query-results.h" 31 #include "icing/index/embed/embedding-scorer.h" 32 #include "icing/index/embed/posting-list-embedding-hit-accessor.h" 33 #include "icing/index/iterator/doc-hit-info-iterator.h" 34 #include "icing/index/iterator/section-restrict-data.h" 35 #include "icing/proto/search.pb.h" 36 #include "icing/schema/schema-store.h" 37 #include "icing/schema/section.h" 38 #include "icing/store/document-filter-data.h" 39 #include "icing/store/document-store.h" 40 41 namespace icing { 42 namespace lib { 43 44 class DocHitInfoIteratorEmbedding 45 : public DocHitInfoIteratorHandlingSectionRestrict { 46 public: 47 // Create a DocHitInfoIterator for iterating through all docs which have an 48 // embedding matched with the provided query with a score in the range of 49 // [score_low, score_high], using the provided metric_type. 50 // 51 // The iterator will store the matched embedding scores in score_map to 52 // prepare for scoring. 53 // 54 // The iterator will handle the section restriction logic internally with the 55 // help of DocHitInfoIteratorHandlingSectionRestrict. 56 // 57 // Returns: 58 // - a DocHitInfoIteratorEmbedding instance on success. 59 // - Any error from posting lists. 60 static libtextclassifier3::StatusOr< 61 std::unique_ptr<DocHitInfoIteratorEmbedding>> 62 Create(const PropertyProto::VectorProto* query, 63 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, 64 double score_low, double score_high, 65 EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map, 66 const EmbeddingIndex* embedding_index, 67 const DocumentStore* document_store, const SchemaStore* schema_store, 68 int64_t current_time_ms); 69 70 libtextclassifier3::Status Advance() override; 71 TrimRightMostNode()72 libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override { 73 return absl_ports::InvalidArgumentError( 74 "Query suggestions for the semanticSearch function are not supported"); 75 } 76 GetCallStats()77 CallStats GetCallStats() const override { 78 return CallStats( 79 /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_, 80 /*num_leaf_advance_calls_main_index_in=*/0, 81 /*num_leaf_advance_calls_integer_index_in=*/0, 82 /*num_leaf_advance_calls_no_index_in=*/0, 83 /*num_blocks_inspected_in=*/0); 84 } 85 ToString()86 std::string ToString() const override { return "embedding_iterator"; } 87 88 // PopulateMatchedTermsStats is not applicable to embedding search. PopulateMatchedTermsStats(std::vector<TermMatchInfo> * matched_terms_stats,SectionIdMask filtering_section_mask)89 void PopulateMatchedTermsStats( 90 std::vector<TermMatchInfo>* matched_terms_stats, 91 SectionIdMask filtering_section_mask) const override {} 92 93 private: DocHitInfoIteratorEmbedding(const PropertyProto::VectorProto * query,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,std::unique_ptr<EmbeddingScorer> embedding_scorer,double score_low,double score_high,EmbeddingQueryResults::EmbeddingQueryScoreMap * score_map,const EmbeddingIndex * embedding_index,std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor,const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms)94 explicit DocHitInfoIteratorEmbedding( 95 const PropertyProto::VectorProto* query, 96 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type, 97 std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low, 98 double score_high, 99 EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map, 100 const EmbeddingIndex* embedding_index, 101 std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor, 102 const DocumentStore* document_store, const SchemaStore* schema_store, 103 int64_t current_time_ms) 104 : query_(*query), 105 metric_type_(metric_type), 106 embedding_scorer_(std::move(embedding_scorer)), 107 score_low_(score_low), 108 score_high_(score_high), 109 score_map_(*score_map), 110 embedding_index_(*embedding_index), 111 posting_list_accessor_(std::move(posting_list_accessor)), 112 cached_embedding_hits_idx_(0), 113 current_allowed_sections_mask_(kSectionIdMaskAll), 114 no_more_hit_(false), 115 schema_type_id_(kInvalidSchemaTypeId), 116 document_store_(*document_store), 117 schema_store_(*schema_store), 118 current_time_ms_(current_time_ms), 119 num_advance_calls_(0) {} 120 121 // Advance to the next embedding hit of the current document. If the current 122 // document id is kInvalidDocumentId, the method will advance to the first 123 // embedding hit of the next document and update doc_hit_info_. 124 // 125 // This method also properly updates cached_embedding_hits_, 126 // cached_embedding_hits_idx_, current_allowed_sections_mask_, and 127 // no_more_hit_ to reflect the current state. 128 // 129 // Returns: 130 // - a const pointer to the next embedding hit on success. 131 // - nullptr, if there is no more hit for the current document, or no more 132 // hit in general if the current document id is kInvalidDocumentId. 133 // - Any error from posting lists. 134 libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit(); 135 136 // Similar to Advance(), this method advances the iterator to the next 137 // document, but it does not guarantee that the next document will have 138 // a matched embedding hit within the score range. 139 // 140 // Returns: 141 // - OK, if it is able to advance to a new document_id. 142 // - RESOUCE_EXHAUSTED, if we have run out of document_ids to iterate over. 143 // - Any error from posting lists. 144 libtextclassifier3::Status AdvanceToNextUnfilteredDocument(); 145 146 // Query information 147 const PropertyProto::VectorProto& query_; // Does not own 148 149 // Scoring arguments 150 SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_; 151 std::unique_ptr<EmbeddingScorer> embedding_scorer_; 152 double score_low_; 153 double score_high_; 154 155 // Score map 156 EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_; // Does not own 157 158 // Access to embeddings index data 159 const EmbeddingIndex& embedding_index_; 160 std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_; 161 162 // Cached data from the embeddings index 163 std::vector<EmbeddingHit> cached_embedding_hits_; 164 int cached_embedding_hits_idx_; 165 SectionIdMask current_allowed_sections_mask_; 166 bool no_more_hit_; 167 SchemaTypeId schema_type_id_; // The schema type id for the current document. 168 169 const DocumentStore& document_store_; 170 const SchemaStore& schema_store_; 171 int64_t current_time_ms_; 172 int num_advance_calls_; 173 }; 174 175 } // namespace lib 176 } // namespace icing 177 178 #endif // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_ 179