1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h"

#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_replace.h"
#include "src/sentencepiece_model.pb.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h"
25 
26 namespace tflite {
27 namespace ops {
28 namespace custom {
29 namespace sentencepiece {
30 
31 std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
DecodePrecompiledCharsmap(const::sentencepiece::NormalizerSpec & normalizer_spec)32 DecodePrecompiledCharsmap(
33     const ::sentencepiece::NormalizerSpec& normalizer_spec) {
34   // This function "undoes" encoding done by
35   // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
36   const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
37   const uint32_t trie_size =
38       *reinterpret_cast<const uint32_t*>(precompiled_map);
39   const uint32_t* trie_ptr =
40       reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
41   const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
42       precompiled_map + sizeof(uint32_t) + trie_size);
43   const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
44                               sizeof(uint32_t) - trie_size;
45   return std::make_tuple(
46       std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
47       std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
48 }
49 
ConvertSentencepieceModelToFlatBuffer(const std::string & model_config_str,int encoding_offset)50 tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
51     const std::string& model_config_str, int encoding_offset) {
52   ::sentencepiece::ModelProto model_config;
53   if (!model_config.ParseFromString(model_config_str)) {
54     return absl::InvalidArgumentError(
55         "Invalid configuration, can't parse SentencePiece model config " +
56         model_config.InitializationErrorString());
57   }
58   // Convert sentencepieces.
59   std::vector<std::string> pieces;
60   pieces.reserve(model_config.pieces_size());
61   std::vector<float> scores;
62   scores.reserve(model_config.pieces_size());
63   std::vector<int> ids;
64   ids.reserve(model_config.pieces_size());
65   float min_score = 0.0;
66   int index = 0;
67   for (const auto& piece : model_config.pieces()) {
68     switch (piece.type()) {
69       case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
70       case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
71         pieces.push_back(piece.piece());
72         ids.push_back(index);
73         if (piece.score() < min_score) {
74           min_score = piece.score();
75         }
76         break;
77       case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
78       case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
79         // Ignore unknown and control codes.
80         break;
81       default:
82         return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
83                                           piece.piece());
84     }
85     scores.push_back(piece.score());
86     ++index;
87   }
88   flatbuffers::FlatBufferBuilder builder(1024);
89   const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
90   const auto pieces_score_vector = builder.CreateVector(scores);
91   TrieBuilder pieces_trie_builder(builder);
92   pieces_trie_builder.add_nodes(pieces_trie_vector);
93   const auto pieces_trie_fbs = pieces_trie_builder.Finish();
94 
95   // Converting normalization.
96   const auto [normalization_trie, normalization_strings] =
97       DecodePrecompiledCharsmap(model_config.normalizer_spec());
98   const auto normalization_trie_vector =
99       builder.CreateVector(normalization_trie);
100   TrieBuilder normalization_trie_builder(builder);
101   normalization_trie_builder.add_nodes(normalization_trie_vector);
102   const auto normalization_trie_fbs = normalization_trie_builder.Finish();
103   const auto normalization_strings_fbs =
104       builder.CreateVector(normalization_strings);
105 
106   EncoderConfigBuilder ecb(builder);
107   ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
108   ecb.add_start_code(model_config.trainer_spec().bos_id());
109   ecb.add_end_code(model_config.trainer_spec().eos_id());
110   ecb.add_unknown_code(model_config.trainer_spec().unk_id());
111   ecb.add_unknown_penalty(min_score - kUnkPenalty);
112   ecb.add_encoding_offset(encoding_offset);
113   ecb.add_pieces(pieces_trie_fbs);
114   ecb.add_pieces_scores(pieces_score_vector);
115   ecb.add_remove_extra_whitespaces(
116       model_config.normalizer_spec().remove_extra_whitespaces());
117   ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
118   ecb.add_escape_whitespaces(
119       model_config.normalizer_spec().escape_whitespaces());
120   ecb.add_normalized_prefixes(normalization_trie_fbs);
121   ecb.add_normalized_replacements(normalization_strings_fbs);
122   FinishEncoderConfigBuffer(builder, ecb.Finish());
123   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
124                      builder.GetSize());
125 }
126 
127 tflite::support::StatusOr<std::string>
ConvertSentencepieceModelToFlatBufferForDecoder(const std::string & model_config_str,int encoding_offset)128 ConvertSentencepieceModelToFlatBufferForDecoder(
129     const std::string& model_config_str, int encoding_offset) {
130   ::sentencepiece::ModelProto model_config;
131   if (!model_config.ParseFromString(model_config_str)) {
132     return absl::InvalidArgumentError(
133         "Invalid configuration, can't parse SentencePiece model config " +
134         model_config.InitializationErrorString());
135   }
136   flatbuffers::FlatBufferBuilder builder(1024);
137   // Collect sentencepieces.
138   std::vector<std::string> pieces;
139   for (const auto& piece : model_config.pieces()) {
140     // In the original library all pieces processing is done during decoding.
141     // Because it is independent from context or parameters we can do it in
142     // advance here.
143     switch (piece.type()) {
144       case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
145       case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
146         pieces.push_back(
147             absl::StrReplaceAll(piece.piece(), {{kSpaceSymbol, " "}}));
148         break;
149       case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
150         pieces.push_back(
151             kDefaultUnknownSymbol);  // Always decode with the default unknown.
152         break;
153       default:
154         pieces.push_back("");
155     }
156   }
157   const auto pieces_fbs = builder.CreateVectorOfStrings(pieces);
158   DecoderConfigBuilder decb(builder);
159 
160   decb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
161   decb.add_encoding_offset(encoding_offset);
162   decb.add_decode_pieces(pieces_fbs);
163   decb.add_remove_dummy_prefix(
164       model_config.normalizer_spec().add_dummy_prefix());
165 
166   FinishDecoderConfigBuffer(builder, decb.Finish());
167   return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
168                      builder.GetSize());
169 }
170 
GetVocabularySize(const std::string & model_string)171 int GetVocabularySize(const std::string& model_string) {
172   const EncoderConfig* config = GetEncoderConfig(model_string.data());
173   return config->pieces_scores()->size() + config->encoding_offset();
174 }
175 
ConvertSentencepieceModel(const std::string & model_string)176 std::string ConvertSentencepieceModel(const std::string& model_string) {
177   const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
178   // TODO(mgubin): Propogate error to the Python code and throw correct
179   // exception.
180   assert(result.status().ok());
181   return result.value();
182 }
183 
ConvertSentencepieceModelForDecoder(const std::string & model_string)184 std::string ConvertSentencepieceModelForDecoder(
185     const std::string& model_string) {
186   const auto result =
187       ConvertSentencepieceModelToFlatBufferForDecoder(model_string);
188   // TODO(mgubin): Propogate error to the Python code and throw correct
189   // exception.
190   assert(result.status().ok());
191   return result.value();
192 }
193 
194 }  // namespace sentencepiece
195 }  // namespace custom
196 }  // namespace ops
197 }  // namespace tflite
198