xref: /aosp_15_r20/external/icing/icing/scoring/advanced_scoring/score-expression-util.h (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2024 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #ifndef ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_UTIL_H_
16 #define ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_UTIL_H_
17 
18 #include <cstdint>
19 #include <memory>
20 #include <string_view>
21 #include <unordered_set>
22 #include <utility>
23 #include <vector>
24 
25 #include "icing/text_classifier/lib3/utils/base/statusor.h"
26 #include "icing/absl_ports/canonical_errors.h"
27 #include "icing/feature-flags.h"
28 #include "icing/index/embed/embedding-query-results.h"
29 #include "icing/join/join-children-fetcher.h"
30 #include "icing/proto/scoring.pb.h"
31 #include "icing/query/advanced_query_parser/abstract-syntax-tree.h"
32 #include "icing/query/advanced_query_parser/lexer.h"
33 #include "icing/query/advanced_query_parser/parser.h"
34 #include "icing/schema/schema-store.h"
35 #include "icing/scoring/advanced_scoring/score-expression.h"
36 #include "icing/scoring/advanced_scoring/scoring-visitor.h"
37 #include "icing/scoring/bm25f-calculator.h"
38 #include "icing/scoring/section-weights.h"
39 #include "icing/store/document-store.h"
40 #include "icing/util/status-macros.h"
41 
42 namespace icing {
43 namespace lib {
44 namespace score_expression_util {
45 
46 // Returns a ScoreExpression instance for the given scoring expression.
47 //
48 // join_children_fetcher, embedding_query_results, section_weights,
49 // bm25f_calculator, schema_type_alias_map are allowed to be nullptr if the
50 // corresponding scoring expression does not use them.
51 //
52 // Returns:
53 //   - A ScoreExpression instance on success.
54 //   - Any syntax or semantics error from Lexer, Parser, or ScoringVisitor.
55 inline libtextclassifier3::StatusOr<std::unique_ptr<ScoreExpression>>
GetScoreExpression(std::string_view scoring_expression,double default_score,SearchSpecProto::EmbeddingQueryMetricType::Code default_semantic_metric_type,const DocumentStore * document_store,const SchemaStore * schema_store,int64_t current_time_ms,const JoinChildrenFetcher * join_children_fetcher,const EmbeddingQueryResults * embedding_query_results,SectionWeights * section_weights,Bm25fCalculator * bm25f_calculator,const SchemaTypeAliasMap * schema_type_alias_map,const FeatureFlags * feature_flags,const std::unordered_set<ScoringFeatureType> * scoring_feature_types_enabled)56 GetScoreExpression(std::string_view scoring_expression, double default_score,
57                    SearchSpecProto::EmbeddingQueryMetricType::Code
58                        default_semantic_metric_type,
59                    const DocumentStore* document_store,
60                    const SchemaStore* schema_store, int64_t current_time_ms,
61                    const JoinChildrenFetcher* join_children_fetcher,
62                    const EmbeddingQueryResults* embedding_query_results,
63                    SectionWeights* section_weights,
64                    Bm25fCalculator* bm25f_calculator,
65                    const SchemaTypeAliasMap* schema_type_alias_map,
66                    const FeatureFlags* feature_flags,
67                    const std::unordered_set<ScoringFeatureType>*
68                        scoring_feature_types_enabled) {
69   ICING_RETURN_ERROR_IF_NULL(document_store);
70   ICING_RETURN_ERROR_IF_NULL(schema_store);
71   ICING_RETURN_ERROR_IF_NULL(feature_flags);
72 
73   Lexer lexer(scoring_expression, Lexer::Language::SCORING);
74   ICING_ASSIGN_OR_RETURN(std::vector<Lexer::LexerToken> lexer_tokens,
75                          std::move(lexer).ExtractTokens());
76   Parser parser = Parser::Create(std::move(lexer_tokens));
77   ICING_ASSIGN_OR_RETURN(std::unique_ptr<Node> tree_root,
78                          parser.ConsumeScoring());
79   ScoringVisitor visitor(
80       default_score, default_semantic_metric_type, document_store, schema_store,
81       section_weights, bm25f_calculator, join_children_fetcher,
82       embedding_query_results, schema_type_alias_map, feature_flags,
83       scoring_feature_types_enabled, current_time_ms);
84   tree_root->Accept(&visitor);
85 
86   ICING_ASSIGN_OR_RETURN(std::unique_ptr<ScoreExpression> expression,
87                          std::move(visitor).Expression());
88   if (expression->type() != ScoreExpressionType::kDouble) {
89     return absl_ports::InvalidArgumentError(
90         "The root scoring expression is not of double type.");
91   }
92   return expression;
93 }
94 
95 inline std::unique_ptr<std::unordered_set<ScoringFeatureType>>
GetEnabledScoringFeatureTypes(const ScoringSpecProto & scoring_spec)96 GetEnabledScoringFeatureTypes(const ScoringSpecProto& scoring_spec) {
97   auto scoring_feature_types_enabled =
98       std::make_unique<std::unordered_set<ScoringFeatureType>>();
99   for (int feature_type : scoring_spec.scoring_feature_types_enabled()) {
100     scoring_feature_types_enabled->insert(
101         static_cast<ScoringFeatureType>(feature_type));
102   }
103   return scoring_feature_types_enabled;
104 }
105 
106 }  // namespace score_expression_util
107 
108 }  // namespace lib
109 }  // namespace icing
110 
111 #endif  // ICING_SCORING_ADVANCED_SCORING_SCORE_EXPRESSION_UTIL_H_
112