1 // Copyright (C) 2022 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #ifndef ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ 16 #define ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ 17 18 #include <cstdint> 19 #include <memory> 20 #include <unordered_set> 21 #include <utility> 22 #include <vector> 23 24 #include "icing/text_classifier/lib3/utils/base/status.h" 25 #include "icing/text_classifier/lib3/utils/base/statusor.h" 26 #include "icing/absl_ports/canonical_errors.h" 27 #include "icing/feature-flags.h" 28 #include "icing/index/embed/embedding-query-results.h" 29 #include "icing/join/join-children-fetcher.h" 30 #include "icing/legacy/core/icing-string-util.h" 31 #include "icing/proto/scoring.pb.h" 32 #include "icing/query/advanced_query_parser/abstract-syntax-tree.h" 33 #include "icing/schema/schema-store.h" 34 #include "icing/scoring/advanced_scoring/score-expression.h" 35 #include "icing/scoring/bm25f-calculator.h" 36 #include "icing/scoring/section-weights.h" 37 #include "icing/store/document-store.h" 38 39 namespace icing { 40 namespace lib { 41 42 class ScoringVisitor : public AbstractSyntaxTreeVisitor { 43 public: 44 // join_children_fetcher, embedding_query_results, section_weights, 45 // bm25f_calculator, schema_type_alias_map are allowed to be nullptr if the 46 // corresponding scoring expression does not use them. ScoringVisitor(double default_score,SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type,const DocumentStore * document_store,const SchemaStore * schema_store,SectionWeights * section_weights,Bm25fCalculator * bm25f_calculator,const JoinChildrenFetcher * join_children_fetcher,const EmbeddingQueryResults * embedding_query_results,const SchemaTypeAliasMap * schema_type_alias_map,const FeatureFlags * feature_flags,const std::unordered_set<ScoringFeatureType> * scoring_feature_types_enabled,int64_t current_time_ms)47 explicit ScoringVisitor(double default_score, 48 SearchSpecProto::EmbeddingQueryMetricType::Code 49 default_semantic_metric_type, 50 const DocumentStore* document_store, 51 const SchemaStore* schema_store, 52 SectionWeights* section_weights, 53 Bm25fCalculator* bm25f_calculator, 54 const JoinChildrenFetcher* join_children_fetcher, 55 const EmbeddingQueryResults* embedding_query_results, 56 const SchemaTypeAliasMap* schema_type_alias_map, 57 const FeatureFlags* feature_flags, 58 const std::unordered_set<ScoringFeatureType>* 59 scoring_feature_types_enabled, 60 int64_t current_time_ms) 61 : default_score_(default_score), 62 default_semantic_metric_type_(default_semantic_metric_type), 63 document_store_(*document_store), 64 schema_store_(*schema_store), 65 section_weights_(section_weights), 66 bm25f_calculator_(bm25f_calculator), 67 join_children_fetcher_(join_children_fetcher), 68 embedding_query_results_(embedding_query_results), 69 schema_type_alias_map_(schema_type_alias_map), 70 feature_flags_(*feature_flags), 71 scoring_feature_types_enabled_(*scoring_feature_types_enabled), 72 current_time_ms_(current_time_ms) {} 73 74 void VisitString(const StringNode* node) override; 75 void VisitText(const TextNode* node) override; 76 void VisitMember(const MemberNode* node) override; 77 VisitFunction(const FunctionNode * node)78 void VisitFunction(const FunctionNode* node) override { 79 return VisitFunctionHelper(node, /*is_member_function=*/false); 80 } 81 82 void VisitUnaryOperator(const UnaryOperatorNode* node) override; 83 void VisitNaryOperator(const NaryOperatorNode* node) override; 84 85 // RETURNS: 86 // - An ScoreExpression instance able to evaluate the expression on success. 87 // - INVALID_ARGUMENT if the AST does not conform to supported expressions, 88 // such as type errors. 89 // - INTERNAL if there are inconsistencies. 90 libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>> Expression()91 Expression() && { 92 if (has_pending_error()) { 93 return pending_error_; 94 } 95 if (stack_.size() != 1) { 96 return absl_ports::InternalError(IcingStringUtil::StringPrintf( 97 "Expect to get only one result from " 98 "ScoringVisitor, but got %zu. There must be inconsistencies.", 99 stack_.size())); 100 } 101 return std::move(stack_[0]); 102 } 103 104 private: 105 // Visit function node. If is_member_function is true, a ThisExpression will 106 // be added as the first function argument. 107 void VisitFunctionHelper(const FunctionNode* node, bool is_member_function); 108 has_pending_error()109 bool has_pending_error() const { return !pending_error_.ok(); } 110 pop_stack()111 std::unique_ptr<ScoreExpression> pop_stack() { 112 std::unique_ptr<ScoreExpression> result = std::move(stack_.back()); 113 stack_.pop_back(); 114 return result; 115 } 116 117 double default_score_; 118 const SearchSpecProto::EmbeddingQueryMetricType::Code 119 default_semantic_metric_type_; 120 const DocumentStore& document_store_; // Does not own. 121 const SchemaStore& schema_store_; // Does not own. 122 123 SectionWeights* section_weights_; // nullable, does not own. 124 Bm25fCalculator* bm25f_calculator_; // nullable, does not own. 125 // A non-null join_children_fetcher_ indicates scoring in a join. 126 const JoinChildrenFetcher* join_children_fetcher_; // nullable, does not own. 127 const EmbeddingQueryResults* 128 embedding_query_results_; // nullable, does not own. 129 const SchemaTypeAliasMap* schema_type_alias_map_; // nullable, does not own. 130 131 const FeatureFlags& feature_flags_; // Does not own. 132 const std::unordered_set<ScoringFeatureType>& 133 scoring_feature_types_enabled_; // Does not own. 134 135 libtextclassifier3::Status pending_error_; 136 std::vector<std::unique_ptr<ScoreExpression>> stack_; 137 int64_t current_time_ms_; 138 }; 139 140 } // namespace lib 141 } // namespace icing 142 143 #endif // ICING_SCORING_ADVANCED_SCORING_SCORING_VISITOR_H_ 144