xref: /aosp_15_r20/external/icing/icing/scoring/section-weights.cc (revision 8b6cd535a057e39b3b86660c4aa06c99747c2136)
1 // Copyright (C) 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "icing/scoring/section-weights.h"
16 
17 #include <cfloat>
18 #include <unordered_map>
19 #include <utility>
20 
21 #include "icing/proto/scoring.pb.h"
22 #include "icing/schema/section.h"
23 #include "icing/util/logging.h"
24 
25 namespace icing {
26 namespace lib {
27 
28 namespace {
29 
30 // Normalizes all weights in the map to be in range [0.0, 1.0], where the max
31 // weight is normalized to 1.0. In the case that all weights are equal to 0.0,
32 // the normalized weight for each will be 0.0.
NormalizeSectionWeights(double max_weight,std::unordered_map<SectionId,double> & section_weights)33 inline void NormalizeSectionWeights(
34     double max_weight, std::unordered_map<SectionId, double>& section_weights) {
35   if (max_weight == 0.0) {
36     return;
37   }
38   for (auto& raw_weight : section_weights) {
39     raw_weight.second = raw_weight.second / max_weight;
40   }
41 }
42 }  // namespace
43 
44 libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>>
Create(const SchemaStore * schema_store,const ScoringSpecProto & scoring_spec)45 SectionWeights::Create(const SchemaStore* schema_store,
46                        const ScoringSpecProto& scoring_spec) {
47   ICING_RETURN_ERROR_IF_NULL(schema_store);
48 
49   std::unordered_map<SchemaTypeId, NormalizedSectionWeights>
50       schema_property_weight_map;
51   for (const TypePropertyWeights& type_property_weights :
52        scoring_spec.type_property_weights()) {
53     std::string_view schema_type = type_property_weights.schema_type();
54     auto schema_type_id_or = schema_store->GetSchemaTypeId(schema_type);
55     if (!schema_type_id_or.ok()) {
56       ICING_LOG(WARNING) << "No schema type id found for schema type: "
57                          << schema_type;
58       continue;
59     }
60     SchemaTypeId schema_type_id = schema_type_id_or.ValueOrDie();
61     auto section_metadata_list_or =
62         schema_store->GetSectionMetadata(schema_type.data());
63     if (!section_metadata_list_or.ok()) {
64       ICING_LOG(WARNING) << "No metadata found for schema type: "
65                          << schema_type;
66       continue;
67     }
68 
69     const std::vector<SectionMetadata>* metadata_list =
70         section_metadata_list_or.ValueOrDie();
71 
72     std::unordered_map<std::string, double> property_paths_weights;
73     for (const PropertyWeight& property_weight :
74          type_property_weights.property_weights()) {
75       double property_path_weight = property_weight.weight();
76 
77       // Return error on negative weights.
78       if (property_path_weight < 0.0) {
79         return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
80             "Property weight for property path \"%s\" is negative. Negative "
81             "weights are invalid.",
82             property_weight.path().c_str()));
83       }
84       property_paths_weights.insert(
85           {property_weight.path(), property_path_weight});
86     }
87     NormalizedSectionWeights normalized_section_weights =
88         ExtractNormalizedSectionWeights(property_paths_weights, *metadata_list);
89 
90     schema_property_weight_map.insert(
91         {schema_type_id,
92          {/*section_weights*/ std::move(
93               normalized_section_weights.section_weights),
94           /*default_weight*/ normalized_section_weights.default_weight}});
95   }
96   // Using `new` to access a non-public constructor.
97   return std::unique_ptr<SectionWeights>(
98       new SectionWeights(std::move(schema_property_weight_map)));
99 }
100 
GetNormalizedSectionWeight(SchemaTypeId schema_type_id,SectionId section_id) const101 double SectionWeights::GetNormalizedSectionWeight(SchemaTypeId schema_type_id,
102                                                   SectionId section_id) const {
103   auto schema_type_map = schema_section_weight_map_.find(schema_type_id);
104   if (schema_type_map == schema_section_weight_map_.end()) {
105     // Return default weight if the schema type has no weights specified.
106     return kDefaultSectionWeight;
107   }
108 
109   auto section_weight =
110       schema_type_map->second.section_weights.find(section_id);
111   if (section_weight == schema_type_map->second.section_weights.end()) {
112     // If there is no entry for SectionId, the weight is implicitly the
113     // normalized default weight.
114     return schema_type_map->second.default_weight;
115   }
116   return section_weight->second;
117 }
118 
119 inline SectionWeights::NormalizedSectionWeights
ExtractNormalizedSectionWeights(const std::unordered_map<std::string,double> & raw_weights,const std::vector<SectionMetadata> & metadata_list)120 SectionWeights::ExtractNormalizedSectionWeights(
121     const std::unordered_map<std::string, double>& raw_weights,
122     const std::vector<SectionMetadata>& metadata_list) {
123   double max_weight = -std::numeric_limits<double>::infinity();
124   std::unordered_map<SectionId, double> section_weights;
125   for (const SectionMetadata& section_metadata : metadata_list) {
126     std::string_view metadata_path = section_metadata.path;
127     double section_weight = kDefaultSectionWeight;
128     auto iter = raw_weights.find(metadata_path.data());
129     if (iter != raw_weights.end()) {
130       section_weight = iter->second;
131       section_weights.insert({section_metadata.id, section_weight});
132     }
133     // Replace max if we see new max weight.
134     max_weight = std::max(max_weight, section_weight);
135   }
136 
137   NormalizeSectionWeights(max_weight, section_weights);
138   // Set normalized default weight to 1.0 in case there is no section
139   // metadata and max_weight is -INF (we should not see this case).
140   double normalized_default_weight =
141       max_weight == -std::numeric_limits<double>::infinity()
142           ? kDefaultSectionWeight
143           : kDefaultSectionWeight / max_weight;
144   SectionWeights::NormalizedSectionWeights normalized_section_weights =
145       SectionWeights::NormalizedSectionWeights();
146   normalized_section_weights.section_weights = std::move(section_weights);
147   normalized_section_weights.default_weight = normalized_default_weight;
148   return normalized_section_weights;
149 }
150 }  // namespace lib
151 }  // namespace icing
152