xref: /aosp_15_r20/external/federated-compute/fcp/dictionary/dictionary.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/dictionary/dictionary.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24 
25 #include "fcp/base/monitoring.h"
26 #include "fcp/dictionary/dictionary.pb.h"
27 #include "absl/container/node_hash_map.h"
28 #include "absl/status/status.h"
29 #include "absl/status/statusor.h"
30 #include "absl/strings/string_view.h"
31 
32 namespace fcp {
33 namespace dictionary {
34 
35 // Bidirectional map defined as hash_map from strings to int32_t paired with
36 // a vector of those keys for reverse lookup.
37 typedef std::pair<absl::node_hash_map<std::string, int32_t>,
38                   std::vector<std::string>>
39     HashVectorBimap;
40 
41 namespace {
42 
43 // Map a string to an ID, using a bidirectional map (an std::pair containing
44 // two data structures for string -> int and for int -> string lookups).
MapLookup(const HashVectorBimap & bimap,const std::string & tag)45 int32_t MapLookup(const HashVectorBimap& bimap, const std::string& tag) {
46   auto map_idx = bimap.first.find(tag);
47   return map_idx == bimap.first.end() ? Dictionary::kNotFound : map_idx->second;
48 }
49 // Lookup a token given its ID.
MapReverseLookup(const HashVectorBimap & bimap,int32_t id)50 std::string MapReverseLookup(const HashVectorBimap& bimap, int32_t id) {
51   if (id < 0 || id >= bimap.second.size()) {
52     return "";
53   }
54   return bimap.second[id];
55 }
56 
57 // Return the size of an stl-like data structure.
GetSize(const HashVectorBimap & bimap)58 int32_t GetSize(const HashVectorBimap& bimap) {
59   return static_cast<int32_t>(bimap.first.size());
60 }
61 
GetMaxSpecialId(const DictionaryDescription::SpecialIds & special_ids)62 int32_t GetMaxSpecialId(const DictionaryDescription::SpecialIds& special_ids) {
63   int32_t max_special_id = -1;
64   max_special_id = std::max(max_special_id, special_ids.bos());
65   max_special_id = std::max(max_special_id, special_ids.eos());
66   max_special_id = std::max(max_special_id, special_ids.unk());
67   return max_special_id;
68 }
69 
70 // Dictionary implementation powered by templated utility functions above.
71 template <typename Bimap>
72 class DictionaryImpl : public Dictionary {
73  public:
DictionaryImpl(std::unique_ptr<Bimap> bimap,const DictionaryDescription::SpecialIds & special_ids,const DictionaryDescription::OutputBlocklistIds & output_blocklist_ids)74   DictionaryImpl(
75       std::unique_ptr<Bimap> bimap,
76       const DictionaryDescription::SpecialIds& special_ids,
77       const DictionaryDescription::OutputBlocklistIds& output_blocklist_ids)
78       : bimap_(std::move(bimap)),
79         special_ids_(special_ids),
80         max_special_id_(GetMaxSpecialId(special_ids)) {
81     // Validate special ids.
82     FCP_CHECK(special_ids.has_bos() == (special_ids.bos() >= 0));
83     FCP_CHECK(special_ids.has_eos() == (special_ids.eos() >= 0));
84     FCP_CHECK(special_ids.has_unk() == (special_ids.unk() >= 0));
85 
86     // Token numbering starts at max(special_ids) + 1.
87     output_blocklist_ids_.reserve(max_special_id_ + 1 +
88                                   output_blocklist_ids.id_size());
89     for (int32_t id = 0; id <= max_special_id_; ++id) {
90       output_blocklist_ids_.push_back(id);
91     }
92     for (int32_t id : output_blocklist_ids.id()) {
93       output_blocklist_ids_.push_back(id);
94     }
95   }
96 
Size() const97   int32_t Size() const override {
98     return GetSize(*bimap_) + max_special_id_ + 1;
99   }
100 
TokenToId(const std::string & tag) const101   int32_t TokenToId(const std::string& tag) const override {
102     int32_t id = MapLookup(*bimap_, tag);
103     if (id == kNotFound) {
104       return special_ids_.unk();
105     } else {
106       return id + max_special_id_ + 1;
107     }
108   }
109 
IdToToken(int32_t id) const110   std::string IdToToken(int32_t id) const override {
111     return MapReverseLookup(*bimap_, id - (max_special_id_ + 1));
112   }
113 
IsSpecialId(int32_t token_id) const114   bool IsSpecialId(int32_t token_id) const override {
115     return token_id <= max_special_id_;
116   }
117 
GetSortedOutputBlocklistIds() const118   const std::vector<int32_t>& GetSortedOutputBlocklistIds() const override {
119     return output_blocklist_ids_;
120   }
121 
GetSpecialIds() const122   const DictionaryDescription::SpecialIds& GetSpecialIds() const override {
123     return special_ids_;
124   }
125 
126  private:
127   const std::unique_ptr<Bimap> bimap_;
128   const DictionaryDescription::SpecialIds special_ids_;
129   int32_t max_special_id_;
130   std::vector<int32_t> output_blocklist_ids_;
131 };
132 
IsOutputBlocklistIdsSortedAndUnique(const DictionaryDescription & description)133 absl::Status IsOutputBlocklistIdsSortedAndUnique(
134     const DictionaryDescription& description) {
135   // All blocklist ids must be greater than max_special_id.
136   const int32_t max_special_id = GetMaxSpecialId(description.special_ids());
137 
138   // Make sure output blocklist IDs are sorted in ascending order and unique.
139   if (description.has_output_blocklist_ids()) {
140     for (int i = 0; i < description.output_blocklist_ids().id_size(); i++) {
141       if (description.output_blocklist_ids().id(i) <= max_special_id) {
142         return absl::InvalidArgumentError(
143             "output_blocklist_ids should not overlap with special ids");
144       }
145       if (!(i == 0 || description.output_blocklist_ids().id(i) >
146                           description.output_blocklist_ids().id(i - 1))) {
147         return absl::InvalidArgumentError(
148             "output_blocklist_ids not unique or sorted");
149       }
150     }
151   }
152   return absl::OkStatus();
153 }
154 
155 }  // anonymous namespace
156 
Create(const DictionaryDescription & description)157 absl::StatusOr<std::unique_ptr<Dictionary>> Dictionary::Create(
158     const DictionaryDescription& description) {
159   if (!description.has_vocabulary()) {
160     return absl::InvalidArgumentError(
161         "Cannot create a dictionary that does not have vocabulary set");
162   }
163   // Make sure output blocklist IDs are sorted in ascending order and unique.
164   FCP_RETURN_IF_ERROR(IsOutputBlocklistIdsSortedAndUnique(description));
165 
166   if (description.vocabulary().has_index()) {
167     auto bimap = std::make_unique<HashVectorBimap>();
168     int i = 0;
169     bimap->second.reserve(description.vocabulary().index().token_size());
170     for (const std::string& token : description.vocabulary().index().token()) {
171       FCP_CHECK(!token.empty());
172       bimap->first[token] = i++;
173       bimap->second.push_back(token);
174     }
175     return std::unique_ptr<Dictionary>(new DictionaryImpl<HashVectorBimap>(
176         std::move(bimap), description.special_ids(),
177     description.output_blocklist_ids()));
178   } else {
179     return absl::InvalidArgumentError(
180         "Invalid DictionaryDescription: no vocabulary specified.");
181   }
182 }
183 }  // namespace dictionary
184 }  // namespace fcp
185