/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/model_converter.h"

#include "absl/status/status.h"
#include "absl/strings/str_replace.h"
#include "src/sentencepiece_model.pb.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/decoder_config_generated.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/double_array_trie_builder.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/encoder_config_generated.h"
#include "tensorflow_lite_support/custom_ops/kernel/sentencepiece/sentencepiece_constants.h"

namespace tflite {
namespace ops {
namespace custom {
namespace sentencepiece {

std::tuple<std::vector<uint32_t>, std::vector<int8_t>>
DecodePrecompiledCharsmap(
    const ::sentencepiece::NormalizerSpec& normalizer_spec) {
  // This function "undoes" encoding done by
  // sentencepiece::normalizer::Normalizer::EncodePrecompiledCharsMap.
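  // The precompiled_charsmap blob is laid out as
  //   [uint32 trie_size][trie data: trie_size bytes][normalized strings].
  // Split it back into the trie nodes and the replacement-string bytes below.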
  const char* precompiled_map = normalizer_spec.precompiled_charsmap().data();
  const uint32_t trie_size =
      *reinterpret_cast<const uint32_t*>(precompiled_map);
  const uint32_t* trie_ptr =
      reinterpret_cast<const uint32_t*>(precompiled_map + sizeof(uint32_t));
  const int8_t* normalized_ptr = reinterpret_cast<const int8_t*>(
      precompiled_map + sizeof(uint32_t) + trie_size);
  const int normalized_size = normalizer_spec.precompiled_charsmap().length() -
                              sizeof(uint32_t) - trie_size;
  return std::make_tuple(
      std::vector<uint32_t>(trie_ptr, trie_ptr + trie_size / sizeof(uint32_t)),
      std::vector<int8_t>(normalized_ptr, normalized_ptr + normalized_size));
}

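// Parses a serialized ::sentencepiece::ModelProto and repacks it into the
// EncoderConfig flatbuffer used by the SentencePiece tokenizer custom op:
// a double-array trie over the pieces, their scores, and the decoded
// normalization data.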
tflite::support::StatusOr<std::string> ConvertSentencepieceModelToFlatBuffer(
    const std::string& model_config_str, int encoding_offset) {
  ::sentencepiece::ModelProto model_config;
  if (!model_config.ParseFromString(model_config_str)) {
    return absl::InvalidArgumentError(
        "Invalid configuration, can't parse SentencePiece model config " +
        model_config.InitializationErrorString());
  }
  // Convert sentencepieces.
  std::vector<std::string> pieces;
  pieces.reserve(model_config.pieces_size());
  std::vector<float> scores;
  scores.reserve(model_config.pieces_size());
  std::vector<int> ids;
  ids.reserve(model_config.pieces_size());
  float min_score = 0.0;
  int index = 0;
  for (const auto& piece : model_config.pieces()) {
    switch (piece.type()) {
      case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
      case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
        pieces.push_back(piece.piece());
        ids.push_back(index);
        if (piece.score() < min_score) {
          min_score = piece.score();
        }
        break;
      case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
      case ::sentencepiece::ModelProto::SentencePiece::CONTROL:
        // Ignore unknown and control codes.
        break;
      default:
        return absl::InvalidArgumentError("Invalid SentencePiece piece type " +
                                          piece.piece());
    }
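    // A score is recorded for every piece, including skipped ones, so that
    // the scores vector stays index-aligned with the piece ids in the trie.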
    scores.push_back(piece.score());
    ++index;
  }
  flatbuffers::FlatBufferBuilder builder(1024);
  const auto pieces_trie_vector = builder.CreateVector(BuildTrie(pieces, ids));
  const auto pieces_score_vector = builder.CreateVector(scores);
  TrieBuilder pieces_trie_builder(builder);
  pieces_trie_builder.add_nodes(pieces_trie_vector);
  const auto pieces_trie_fbs = pieces_trie_builder.Finish();

  // Converting normalization.
  const auto [normalization_trie, normalization_strings] =
      DecodePrecompiledCharsmap(model_config.normalizer_spec());
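  // The decoded charsmap gives a trie of source prefixes together with the
  // blob of replacement strings they normalize to.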
  const auto normalization_trie_vector =
      builder.CreateVector(normalization_trie);
  TrieBuilder normalization_trie_builder(builder);
  normalization_trie_builder.add_nodes(normalization_trie_vector);
  const auto normalization_trie_fbs = normalization_trie_builder.Finish();
  const auto normalization_strings_fbs =
      builder.CreateVector(normalization_strings);

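  // Assemble the EncoderConfig table from the trainer/normalizer specs and
  // the vectors built above.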
  EncoderConfigBuilder ecb(builder);
  ecb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
  ecb.add_start_code(model_config.trainer_spec().bos_id());
  ecb.add_end_code(model_config.trainer_spec().eos_id());
  ecb.add_unknown_code(model_config.trainer_spec().unk_id());
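  // Unknown tokens are scored kUnkPenalty below the lowest piece score.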
  ecb.add_unknown_penalty(min_score - kUnkPenalty);
  ecb.add_encoding_offset(encoding_offset);
  ecb.add_pieces(pieces_trie_fbs);
  ecb.add_pieces_scores(pieces_score_vector);
  ecb.add_remove_extra_whitespaces(
      model_config.normalizer_spec().remove_extra_whitespaces());
  ecb.add_add_dummy_prefix(model_config.normalizer_spec().add_dummy_prefix());
  ecb.add_escape_whitespaces(
      model_config.normalizer_spec().escape_whitespaces());
  ecb.add_normalized_prefixes(normalization_trie_fbs);
  ecb.add_normalized_replacements(normalization_strings_fbs);
  FinishEncoderConfigBuffer(builder, ecb.Finish());
  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                     builder.GetSize());
}

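// Builds the DecoderConfig flatbuffer. Pieces are converted to their decoded
// text form up front, so the runtime decoder only has to look them up and
// concatenate them.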
tflite::support::StatusOr<std::string>
ConvertSentencepieceModelToFlatBufferForDecoder(
    const std::string& model_config_str, int encoding_offset) {
  ::sentencepiece::ModelProto model_config;
  if (!model_config.ParseFromString(model_config_str)) {
    return absl::InvalidArgumentError(
        "Invalid configuration, can't parse SentencePiece model config " +
        model_config.InitializationErrorString());
  }
  flatbuffers::FlatBufferBuilder builder(1024);
  // Collect sentencepieces.
  std::vector<std::string> pieces;
  for (const auto& piece : model_config.pieces()) {
    // In the original library, all piece processing is done during decoding.
    // Because it is independent of context and parameters, we can do it in
    // advance here.
    switch (piece.type()) {
      case ::sentencepiece::ModelProto::SentencePiece::NORMAL:
      case ::sentencepiece::ModelProto::SentencePiece::USER_DEFINED:
        pieces.push_back(
            absl::StrReplaceAll(piece.piece(), {{kSpaceSymbol, " "}}));
        break;
      case ::sentencepiece::ModelProto::SentencePiece::UNKNOWN:
        pieces.push_back(
            kDefaultUnknownSymbol);  // Always decode with the default unknown.
        break;
      default:
        pieces.push_back("");
    }
  }
  const auto pieces_fbs = builder.CreateVectorOfStrings(pieces);
  DecoderConfigBuilder decb(builder);

  decb.add_version(EncoderVersion::EncoderVersion_SENTENCE_PIECE);
  decb.add_encoding_offset(encoding_offset);
  decb.add_decode_pieces(pieces_fbs);
  decb.add_remove_dummy_prefix(
      model_config.normalizer_spec().add_dummy_prefix());

  FinishDecoderConfigBuffer(builder, decb.Finish());
  return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
                     builder.GetSize());
}

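// Reads the vocabulary size back out of a serialized EncoderConfig
// flatbuffer: the number of piece scores plus the encoding offset.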
int GetVocabularySize(const std::string& model_string) {
  const EncoderConfig* config = GetEncoderConfig(model_string.data());
  return config->pieces_scores()->size() + config->encoding_offset();
}

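// Convenience wrapper called from Python: asserts on conversion failure
// instead of returning a status (see TODO below).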
std::string ConvertSentencepieceModel(const std::string& model_string) {
  const auto result = ConvertSentencepieceModelToFlatBuffer(model_string);
  // TODO(mgubin): Propagate error to the Python code and throw correct
  // exception.
  assert(result.status().ok());
  return result.value();
}

std::string ConvertSentencepieceModelForDecoder(
    const std::string& model_string) {
  const auto result =
      ConvertSentencepieceModelToFlatBufferForDecoder(model_string);
  // TODO(mgubin): Propagate error to the Python code and throw correct
  // exception.
  assert(result.status().ok());
  return result.value();
}

}  // namespace sentencepiece
}  // namespace custom
}  // namespace ops
}  // namespace tflite