// xref: /aosp_15_r20/external/icing/icing/index/embed/doc-hit-info-iterator-embedding.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
// Copyright (C) 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
15 #ifndef ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
16 #define ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
17 
18 #include <cstdint>
19 #include <memory>
20 #include <string>
21 #include <string_view>
22 #include <utility>
23 #include <vector>
24 
25 #include "icing/text_classifier/lib3/utils/base/status.h"
26 #include "icing/text_classifier/lib3/utils/base/statusor.h"
27 #include "icing/absl_ports/canonical_errors.h"
28 #include "icing/index/embed/embedding-hit.h"
29 #include "icing/index/embed/embedding-index.h"
30 #include "icing/index/embed/embedding-query-results.h"
31 #include "icing/index/embed/embedding-scorer.h"
32 #include "icing/index/embed/posting-list-embedding-hit-accessor.h"
33 #include "icing/index/iterator/doc-hit-info-iterator.h"
34 #include "icing/index/iterator/section-restrict-data.h"
35 #include "icing/proto/search.pb.h"
36 #include "icing/schema/schema-store.h"
37 #include "icing/schema/section.h"
38 #include "icing/store/document-filter-data.h"
39 #include "icing/store/document-store.h"
40 
41 namespace icing {
42 namespace lib {
43 
44 class DocHitInfoIteratorEmbedding
45     : public DocHitInfoIteratorHandlingSectionRestrict {
46  public:
47   // Create a DocHitInfoIterator for iterating through all docs which have an
48   // embedding matched with the provided query with a score in the range of
49   // [score_low, score_high], using the provided metric_type.
50   //
51   // The iterator will store the matched embedding scores in score_map to
52   // prepare for scoring.
53   //
54   // The iterator will handle the section restriction logic internally with the
55   // help of DocHitInfoIteratorHandlingSectionRestrict.
56   //
57   // Returns:
58   //   - a DocHitInfoIteratorEmbedding instance on success.
59   //   - Any error from posting lists.
60   static libtextclassifier3::StatusOr<
61       std::unique_ptr<DocHitInfoIteratorEmbedding>>
62   Create(const PropertyProto::VectorProto* query,
63          SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
64          double score_low, double score_high,
65          EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
66          const EmbeddingIndex* embedding_index,
67          const DocumentStore* document_store, const SchemaStore* schema_store,
68          int64_t current_time_ms);
69 
70   libtextclassifier3::Status Advance() override;
71 
TrimRightMostNode()72   libtextclassifier3::StatusOr<TrimmedNode> TrimRightMostNode() && override {
73     return absl_ports::InvalidArgumentError(
74         "Query suggestions for the semanticSearch function are not supported");
75   }
76 
GetCallStats()77   CallStats GetCallStats() const override {
78     return CallStats(
79         /*num_leaf_advance_calls_lite_index_in=*/num_advance_calls_,
80         /*num_leaf_advance_calls_main_index_in=*/0,
81         /*num_leaf_advance_calls_integer_index_in=*/0,
82         /*num_leaf_advance_calls_no_index_in=*/0,
83         /*num_blocks_inspected_in=*/0);
84   }
85 
ToString()86   std::string ToString() const override { return "embedding_iterator"; }
87 
88   // PopulateMatchedTermsStats is not applicable to embedding search.
PopulateMatchedTermsStats(std::vector<TermMatchInfo> * matched_terms_stats,SectionIdMask filtering_section_mask)89   void PopulateMatchedTermsStats(
90       std::vector<TermMatchInfo>* matched_terms_stats,
91       SectionIdMask filtering_section_mask) const override {}
92 
93  private:
DocHitInfoIteratorEmbedding(const PropertyProto::VectorProto * query,SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,std::unique_ptr<EmbeddingScorer> embedding_scorer,double score_low,double score_high,EmbeddingQueryResults::EmbeddingQueryScoreMap * score_map,const EmbeddingIndex * embedding_index,std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor,const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms)94   explicit DocHitInfoIteratorEmbedding(
95       const PropertyProto::VectorProto* query,
96       SearchSpecProto::EmbeddingQueryMetricType::Code metric_type,
97       std::unique_ptr<EmbeddingScorer> embedding_scorer, double score_low,
98       double score_high,
99       EmbeddingQueryResults::EmbeddingQueryScoreMap* score_map,
100       const EmbeddingIndex* embedding_index,
101       std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor,
102       const DocumentStore* document_store, const SchemaStore* schema_store,
103       int64_t current_time_ms)
104       : query_(*query),
105         metric_type_(metric_type),
106         embedding_scorer_(std::move(embedding_scorer)),
107         score_low_(score_low),
108         score_high_(score_high),
109         score_map_(*score_map),
110         embedding_index_(*embedding_index),
111         posting_list_accessor_(std::move(posting_list_accessor)),
112         cached_embedding_hits_idx_(0),
113         current_allowed_sections_mask_(kSectionIdMaskAll),
114         no_more_hit_(false),
115         schema_type_id_(kInvalidSchemaTypeId),
116         document_store_(*document_store),
117         schema_store_(*schema_store),
118         current_time_ms_(current_time_ms),
119         num_advance_calls_(0) {}
120 
121   // Advance to the next embedding hit of the current document. If the current
122   // document id is kInvalidDocumentId, the method will advance to the first
123   // embedding hit of the next document and update doc_hit_info_.
124   //
125   // This method also properly updates cached_embedding_hits_,
126   // cached_embedding_hits_idx_, current_allowed_sections_mask_, and
127   // no_more_hit_ to reflect the current state.
128   //
129   // Returns:
130   //   - a const pointer to the next embedding hit on success.
131   //   - nullptr, if there is no more hit for the current document, or no more
132   //     hit in general if the current document id is kInvalidDocumentId.
133   //   - Any error from posting lists.
134   libtextclassifier3::StatusOr<const EmbeddingHit*> AdvanceToNextEmbeddingHit();
135 
136   // Similar to Advance(), this method advances the iterator to the next
137   // document, but it does not guarantee that the next document will have
138   // a matched embedding hit within the score range.
139   //
140   // Returns:
141   //   - OK, if it is able to advance to a new document_id.
142   //   - RESOUCE_EXHAUSTED, if we have run out of document_ids to iterate over.
143   //   - Any error from posting lists.
144   libtextclassifier3::Status AdvanceToNextUnfilteredDocument();
145 
146   // Query information
147   const PropertyProto::VectorProto& query_;  // Does not own
148 
149   // Scoring arguments
150   SearchSpecProto::EmbeddingQueryMetricType::Code metric_type_;
151   std::unique_ptr<EmbeddingScorer> embedding_scorer_;
152   double score_low_;
153   double score_high_;
154 
155   // Score map
156   EmbeddingQueryResults::EmbeddingQueryScoreMap& score_map_;  // Does not own
157 
158   // Access to embeddings index data
159   const EmbeddingIndex& embedding_index_;
160   std::unique_ptr<PostingListEmbeddingHitAccessor> posting_list_accessor_;
161 
162   // Cached data from the embeddings index
163   std::vector<EmbeddingHit> cached_embedding_hits_;
164   int cached_embedding_hits_idx_;
165   SectionIdMask current_allowed_sections_mask_;
166   bool no_more_hit_;
167   SchemaTypeId schema_type_id_;  // The schema type id for the current document.
168 
169   const DocumentStore& document_store_;
170   const SchemaStore& schema_store_;
171   int64_t current_time_ms_;
172   int num_advance_calls_;
173 };
174 
175 }  // namespace lib
176 }  // namespace icing
177 
178 #endif  // ICING_INDEX_EMBED_DOC_HIT_INFO_ITERATOR_EMBEDDING_H_
179