xref: /aosp_15_r20/external/icing/icing/scoring/advanced_scoring/scoring-visitor.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2022 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
16 #define ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
17 
18 #include <cstdint>
19 #include <memory>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 
24 #include "icing/text_classifier/lib3/utils/base/status.h"
25 #include "icing/text_classifier/lib3/utils/base/statusor.h"
26 #include "icing/absl_ports/canonical_errors.h"
27 #include "icing/feature-flags.h"
28 #include "icing/index/embed/embedding-query-results.h"
29 #include "icing/join/join-children-fetcher.h"
30 #include "icing/legacy/core/icing-string-util.h"
31 #include "icing/proto/scoring.pb.h"
32 #include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
33 #include "icing/schema/schema-store.h"
34 #include "icing/scoring/advanced_scoring/score-expression.h"
35 #include "icing/scoring/bm25f-calculator.h"
36 #include "icing/scoring/section-weights.h"
37 #include "icing/store/document-store.h"
38 
39 namespace icing {
40 namespace lib {
41 
42 class ScoringVisitor : public AbstractSyntaxTreeVisitor {
43  public:
44   // join_children_fetcher, embedding_query_results, section_weights,
45   // bm25f_calculator, schema_type_alias_map are allowed to be nullptr if the
46   // corresponding scoring expression does not use them.
ScoringVisitor(double default_score,SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type,const DocumentStore * document_store,const SchemaStore * schema_store,SectionWeights * section_weights,Bm25fCalculator * bm25f_calculator,const JoinChildrenFetcher * join_children_fetcher,const EmbeddingQueryResults * embedding_query_results,const SchemaTypeAliasMap * schema_type_alias_map,const FeatureFlags * feature_flags,const std::unordered_set<ScoringFeatureType> * scoring_feature_types_enabled,int64_t current_time_ms)47   explicit ScoringVisitor(double default_score,
48                           SearchSpecProto::EmbeddingQueryMetricType::Code
49                               default_semantic_metric_type,
50                           const DocumentStore* document_store,
51                           const SchemaStore* schema_store,
52                           SectionWeights* section_weights,
53                           Bm25fCalculator* bm25f_calculator,
54                           const JoinChildrenFetcher* join_children_fetcher,
55                           const EmbeddingQueryResults* embedding_query_results,
56                           const SchemaTypeAliasMap* schema_type_alias_map,
57                           const FeatureFlags* feature_flags,
58                           const std::unordered_set<ScoringFeatureType>*
59                               scoring_feature_types_enabled,
60                           int64_t current_time_ms)
61       : default_score_(default_score),
62         default_semantic_metric_type_(default_semantic_metric_type),
63         document_store_(*document_store),
64         schema_store_(*schema_store),
65         section_weights_(section_weights),
66         bm25f_calculator_(bm25f_calculator),
67         join_children_fetcher_(join_children_fetcher),
68         embedding_query_results_(embedding_query_results),
69         schema_type_alias_map_(schema_type_alias_map),
70         feature_flags_(*feature_flags),
71         scoring_feature_types_enabled_(*scoring_feature_types_enabled),
72         current_time_ms_(current_time_ms) {}
73 
74   void VisitString(const StringNode* node) override;
75   void VisitText(const TextNode* node) override;
76   void VisitMember(const MemberNode* node) override;
77 
VisitFunction(const FunctionNode * node)78   void VisitFunction(const FunctionNode* node) override {
79     return VisitFunctionHelper(node, /*is_member_function=*/false);
80   }
81 
82   void VisitUnaryOperator(const UnaryOperatorNode* node) override;
83   void VisitNaryOperator(const NaryOperatorNode* node) override;
84 
85   // RETURNS:
86   //   - An ScoreExpression instance able to evaluate the expression on success.
87   //   - INVALID_ARGUMENT if the AST does not conform to supported expressions,
88   //   such as type errors.
89   //   - INTERNAL if there are inconsistencies.
90   libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
Expression()91   Expression() && {
92     if (has_pending_error()) {
93       return pending_error_;
94     }
95     if (stack_.size() != 1) {
96       return absl_ports::InternalError(IcingStringUtil::StringPrintf(
97           "Expect to get only one result from "
98           "ScoringVisitor, but got %zu. There must be inconsistencies.",
99           stack_.size()));
100     }
101     return std::move(stack_[0]);
102   }
103 
104  private:
105   // Visit function node. If is_member_function is true, a ThisExpression will
106   // be added as the first function argument.
107   void VisitFunctionHelper(const FunctionNode* node, bool is_member_function);
108 
has_pending_error()109   bool has_pending_error() const { return !pending_error_.ok(); }
110 
pop_stack()111   std::unique_ptr<ScoreExpression> pop_stack() {
112     std::unique_ptr<ScoreExpression> result = std::move(stack_.back());
113     stack_.pop_back();
114     return result;
115   }
116 
117   double default_score_;
118   const SearchSpecProto::EmbeddingQueryMetricType::Code
119       default_semantic_metric_type_;
120   const DocumentStore& document_store_;  // Does not own.
121   const SchemaStore& schema_store_;      // Does not own.
122 
123   SectionWeights* section_weights_;    // nullable, does not own.
124   Bm25fCalculator* bm25f_calculator_;  // nullable, does not own.
125   // A non-null join_children_fetcher_ indicates scoring in a join.
126   const JoinChildrenFetcher* join_children_fetcher_;  // nullable, does not own.
127   const EmbeddingQueryResults*
128       embedding_query_results_;                      // nullable, does not own.
129   const SchemaTypeAliasMap* schema_type_alias_map_;  // nullable, does not own.
130 
131   const FeatureFlags& feature_flags_;  // Does not own.
132   const std::unordered_set<ScoringFeatureType>&
133       scoring_feature_types_enabled_;  // Does not own.
134 
135   libtextclassifier3::Status pending_error_;
136   std::vector<std::unique_ptr<ScoreExpression>> stack_;
137   int64_t current_time_ms_;
138 };
139 
140 }  // namespace lib
141 }  // namespace icing
142 
143 #endif  // ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_
144