1 // Copyright (C) 2021 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "icing/scoring/section-weights.h"
16
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <limits>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
20
21 #include "icing/proto/scoring.pb.h"
22 #include "icing/schema/section.h"
23 #include "icing/util/logging.h"
24
25 namespace icing {
26 namespace lib {
27
28 namespace {
29
30 // Normalizes all weights in the map to be in range [0.0, 1.0], where the max
31 // weight is normalized to 1.0. In the case that all weights are equal to 0.0,
32 // the normalized weight for each will be 0.0.
NormalizeSectionWeights(double max_weight,std::unordered_map<SectionId,double> & section_weights)33 inline void NormalizeSectionWeights(
34 double max_weight, std::unordered_map<SectionId, double>& section_weights) {
35 if (max_weight == 0.0) {
36 return;
37 }
38 for (auto& raw_weight : section_weights) {
39 raw_weight.second = raw_weight.second / max_weight;
40 }
41 }
42 } // namespace
43
44 libtextclassifier3::StatusOr<std::unique_ptr<SectionWeights>>
Create(const SchemaStore * schema_store,const ScoringSpecProto & scoring_spec)45 SectionWeights::Create(const SchemaStore* schema_store,
46 const ScoringSpecProto& scoring_spec) {
47 ICING_RETURN_ERROR_IF_NULL(schema_store);
48
49 std::unordered_map<SchemaTypeId, NormalizedSectionWeights>
50 schema_property_weight_map;
51 for (const TypePropertyWeights& type_property_weights :
52 scoring_spec.type_property_weights()) {
53 std::string_view schema_type = type_property_weights.schema_type();
54 auto schema_type_id_or = schema_store->GetSchemaTypeId(schema_type);
55 if (!schema_type_id_or.ok()) {
56 ICING_LOG(WARNING) << "No schema type id found for schema type: "
57 << schema_type;
58 continue;
59 }
60 SchemaTypeId schema_type_id = schema_type_id_or.ValueOrDie();
61 auto section_metadata_list_or =
62 schema_store->GetSectionMetadata(schema_type.data());
63 if (!section_metadata_list_or.ok()) {
64 ICING_LOG(WARNING) << "No metadata found for schema type: "
65 << schema_type;
66 continue;
67 }
68
69 const std::vector<SectionMetadata>* metadata_list =
70 section_metadata_list_or.ValueOrDie();
71
72 std::unordered_map<std::string, double> property_paths_weights;
73 for (const PropertyWeight& property_weight :
74 type_property_weights.property_weights()) {
75 double property_path_weight = property_weight.weight();
76
77 // Return error on negative weights.
78 if (property_path_weight < 0.0) {
79 return absl_ports::InvalidArgumentError(IcingStringUtil::StringPrintf(
80 "Property weight for property path \"%s\" is negative. Negative "
81 "weights are invalid.",
82 property_weight.path().c_str()));
83 }
84 property_paths_weights.insert(
85 {property_weight.path(), property_path_weight});
86 }
87 NormalizedSectionWeights normalized_section_weights =
88 ExtractNormalizedSectionWeights(property_paths_weights, *metadata_list);
89
90 schema_property_weight_map.insert(
91 {schema_type_id,
92 {/*section_weights*/ std::move(
93 normalized_section_weights.section_weights),
94 /*default_weight*/ normalized_section_weights.default_weight}});
95 }
96 // Using `new` to access a non-public constructor.
97 return std::unique_ptr<SectionWeights>(
98 new SectionWeights(std::move(schema_property_weight_map)));
99 }
100
GetNormalizedSectionWeight(SchemaTypeId schema_type_id,SectionId section_id) const101 double SectionWeights::GetNormalizedSectionWeight(SchemaTypeId schema_type_id,
102 SectionId section_id) const {
103 auto schema_type_map = schema_section_weight_map_.find(schema_type_id);
104 if (schema_type_map == schema_section_weight_map_.end()) {
105 // Return default weight if the schema type has no weights specified.
106 return kDefaultSectionWeight;
107 }
108
109 auto section_weight =
110 schema_type_map->second.section_weights.find(section_id);
111 if (section_weight == schema_type_map->second.section_weights.end()) {
112 // If there is no entry for SectionId, the weight is implicitly the
113 // normalized default weight.
114 return schema_type_map->second.default_weight;
115 }
116 return section_weight->second;
117 }
118
119 inline SectionWeights::NormalizedSectionWeights
ExtractNormalizedSectionWeights(const std::unordered_map<std::string,double> & raw_weights,const std::vector<SectionMetadata> & metadata_list)120 SectionWeights::ExtractNormalizedSectionWeights(
121 const std::unordered_map<std::string, double>& raw_weights,
122 const std::vector<SectionMetadata>& metadata_list) {
123 double max_weight = -std::numeric_limits<double>::infinity();
124 std::unordered_map<SectionId, double> section_weights;
125 for (const SectionMetadata& section_metadata : metadata_list) {
126 std::string_view metadata_path = section_metadata.path;
127 double section_weight = kDefaultSectionWeight;
128 auto iter = raw_weights.find(metadata_path.data());
129 if (iter != raw_weights.end()) {
130 section_weight = iter->second;
131 section_weights.insert({section_metadata.id, section_weight});
132 }
133 // Replace max if we see new max weight.
134 max_weight = std::max(max_weight, section_weight);
135 }
136
137 NormalizeSectionWeights(max_weight, section_weights);
138 // Set normalized default weight to 1.0 in case there is no section
139 // metadata and max_weight is -INF (we should not see this case).
140 double normalized_default_weight =
141 max_weight == -std::numeric_limits<double>::infinity()
142 ? kDefaultSectionWeight
143 : kDefaultSectionWeight / max_weight;
144 SectionWeights::NormalizedSectionWeights normalized_section_weights =
145 SectionWeights::NormalizedSectionWeights();
146 normalized_section_weights.section_weights = std::move(section_weights);
147 normalized_section_weights.default_weight = normalized_default_weight;
148 return normalized_section_weights;
149 }
150 } // namespace lib
151 } // namespace icing
152