/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Adapted from https://github.com/sewenew/tokenizer

// @lint-ignore-every LICENSELINT
/**************************************************************************
Copyright (c) 2023 sewenew

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*************************************************************************/

#include <executorch/extension/llm/tokenizer/base64.h>
#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <executorch/runtime/core/result.h>
#include <cassert> // for assert
#include <cinttypes> // for PRIu64
#include <fstream>
#include <limits>

using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace executorch {
namespace extension {
namespace llm {

// ------------------------------Util start------------------------------------

static uint64_t _max_size() {
  return std::numeric_limits<uint64_t>::max();
}

static Re2UPtr _create_regex(const std::string& pattern) {
  assert(!pattern.empty());

  return std::make_unique<re2::RE2>("(" + pattern + ")");
}

static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) {
  std::string special_pattern;
  for (const auto& ele : special_encoder) {
    if (!special_pattern.empty()) {
      special_pattern += "|";
    }
    special_pattern += re2::RE2::QuoteMeta(ele.first);
  }

  if (special_pattern.empty()) {
    return nullptr;
  }

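  // The alternation built above matches any special token literally. For
  // example, hypothetical special tokens {"<|bos|>", "<|eos|>"} would yield
  // the pattern \<\|bos\|\>|\<\|eos\|\> once wrapped in a capture group by
  // _create_regex below, since RE2::QuoteMeta escapes every
  // non-alphanumeric character.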
  return _create_regex(special_pattern);
}

static Result<std::pair<std::string, uint64_t>> _parse(
    const std::string& line) {
  // Tiktoken format:
  // https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140
  // <base64 encoded token str> <rank>
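  // For example, the line "IQ== 0" maps the single-byte token "!" (base64
  // "IQ==") to rank 0.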
  auto pos = line.find(" ");
  ET_CHECK_OR_RETURN_ERROR(
      pos != std::string::npos,
      InvalidArgument,
      "invalid tiktoken line: %s",
      line.c_str());

  auto token = ET_UNWRAP(base64::decode({line.data(), pos}));
  uint64_t rank = 0;
  try {
    rank = std::stoul(line.substr(pos + 1));
  } catch (const std::exception&) {
    ET_CHECK_OR_RETURN_ERROR(
        false, InvalidArgument, "invalid encoder rank: %s", line.c_str());
  }

  return std::pair{std::move(token), rank};
}

static Result<Encoder> _load_encoder(const std::string& path) {
  std::ifstream file(path);
  ET_CHECK_OR_RETURN_ERROR(
      file, InvalidArgument, "failed to open encoder file: %s", path.c_str());

  Encoder encoder;
  std::string line;
  while (std::getline(file, line)) {
    auto [token, rank] = ET_UNWRAP(_parse(line));

    ET_CHECK_OR_RETURN_ERROR(
        encoder.emplace(std::move(token), rank).second,
        InvalidArgument,
        "duplicate item: %s",
        line.c_str());
  }

  return encoder;
}

static Result<Decoder> _build_decoder(const Encoder& encoder) {
  Decoder decoder;
  for (const auto& [k, v] : encoder) {
    decoder.emplace(v, k);
  }

  ET_CHECK_OR_RETURN_ERROR(
      encoder.size() == decoder.size(),
      InvalidArgument,
      "duplicate items in encoder");

  return decoder;
}

static std::vector<uint64_t> _byte_pair_merge(
    const std::string& piece,
    const std::unordered_map<std::string, uint64_t>& ranks,
    std::function<uint64_t(uint64_t, uint64_t)> func) {
  // This is a vector of (start, rank).
  // The rank is of the byte pair starting at position start.
  // The rank of the last item in the vector is not a valid value.
  std::vector<std::pair<uint64_t, uint64_t>> parts;
  parts.reserve(piece.size() + 1);
  for (auto idx = 0U; idx < piece.size() + 1; ++idx) {
    parts.emplace_back(idx, _max_size());
  }

  auto get_rank = [&piece, &ranks](
                      const std::vector<std::pair<uint64_t, uint64_t>>& parts,
                      uint64_t start_idx,
                      uint64_t skip) -> std::optional<uint64_t> {
    if (start_idx + skip + 2 < parts.size()) {
      auto s = parts[start_idx].first;
      auto e = parts[start_idx + skip + 2].first;
      auto key = piece.substr(s, e - s);
      auto iter = ranks.find(key);
      if (iter != ranks.end()) {
        return iter->second;
      }
    }
    return std::nullopt;
  };

  // We look up the ranks once in the beginning and iteratively update
  // them during each merge, which reduces the number of rank lookups.
  for (auto i = 0U; i < parts.size() - 2; ++i) {
    auto rank = get_rank(parts, i, 0);
    if (rank) {
      // The maximum uint64_t value is a sentinel and cannot be a valid rank.
      ET_CHECK_MSG(*rank != _max_size(), "rank is too large");
      parts[i].second = *rank;
    }
  }

  // If you have n parts and m merges, this does O(mn) work.
  // We could do something with a heap and do O(m log n) work.
  // It is important to consider that n is often small (<100), and as such
  // the cache-locality benefits outweigh the algorithmic complexity downsides
  // of the `parts` vector data structure above.

  // Note that we hash bytes, not token pairs. As long as we train BPE the way
  // we currently do, this is equivalent. An easy way to break this would be
  // to decouple merge priority from token index or to prevent specific token
  // merges.
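  // Worked example with hypothetical ranks {"ab": 0, "cd": 1} and piece
  // "abcd": `parts` starts as boundaries [0, 1, 2, 3, 4]. The lowest-ranked
  // pair "ab" merges first (dropping boundary 1), then "cd" (dropping
  // boundary 3), leaving [0, 2, 4], i.e. the tokens for "ab" and "cd".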
  while (true) {
    if (parts.size() == 1) {
      break;
    }

    // The maximum uint64_t value is a sentinel rank allowing us to
    // take the min more quickly.
    auto min_rank = std::make_pair<uint64_t, uint64_t>(_max_size(), 0);
    for (auto i = 0U; i < parts.size() - 1; ++i) {
      auto rank = parts[i].second;
      if (rank < min_rank.first) {
        min_rank.first = rank;
        min_rank.second = i;
      }
    }

    if (min_rank.first != _max_size()) {
      auto i = min_rank.second;

      // NOTE: We are about to remove parts[i + 1]. We do not do it
      // yet because there are cache-locality benefits to updating
      // parts[i] and parts[i - 1] before removing, which could thrash
      // the cache. Thus, we update the rank calculation to skip over
      // parts[i + 1] by invoking `get_rank` with `skip = 1`.
      auto rank = get_rank(parts, i, 1);
      if (rank) {
        parts[i].second = *rank;
      } else {
        parts[i].second = _max_size();
      }
      if (i > 0) {
        rank = get_rank(parts, i - 1, 1);
        if (rank) {
          parts[i - 1].second = *rank;
        } else {
          parts[i - 1].second = _max_size();
        }
      }

      parts.erase(parts.begin() + (i + 1));
    } else {
      break;
    }
  }
  std::vector<uint64_t> out;
  out.reserve(parts.size() - 1);
  for (auto i = 0U; i < parts.size() - 1; ++i) {
    auto s = parts[i].first;
    auto e = parts[i + 1].first;
    out.push_back(func(s, e));
  }
  return out;
}

static std::vector<uint64_t> _byte_pair_encode(
    const std::string& piece,
    const Encoder& encoder) {
  if (piece.size() == 1) {
    auto iter = encoder.find(piece);
    if (iter != encoder.end()) {
      return std::vector<uint64_t>({iter->second});
    } else {
      // TODO: is it possible?
      return {};
    }
  }

  return _byte_pair_merge(
      piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) {
        std::string key = piece.substr(start, stop - start);
        auto iter = encoder.find(key);
        if (iter != encoder.end()) {
          return iter->second;
        } else {
          // TODO: what if key does not exist? Should we return `unknown`?
          // assert(false); // ??
          return uint64_t(0);
        }
      });
}
// ------------------------------Util end------------------------------------
// -------------------------private method start-------------------------------

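// Splits `input` at the first allowed special token, if any. As an
// illustration, with a hypothetical allowed token "<|eot|>" and input
// "hi <|eot|> there", this returns {"<|eot|>", "hi "} and advances `input`
// to " there"; if no allowed special token is present, it returns
// {std::nullopt, input} unchanged.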
template <typename T>
std::pair<std::optional<std::string>, re2::StringPiece>
Tiktoken::_split_with_allowed_special_token(
    re2::StringPiece& input,
    const T& allowed_special) const {
  if (!_special_token_regex) {
    return std::make_pair(std::nullopt, input);
  }

#if __cplusplus >= 202002L
  auto start = input.begin();
#else
  const char* start = input.data();
#endif
  std::string special;
  while (true) {
    if (!re2::RE2::FindAndConsume(&input, *_special_token_regex, &special)) {
      // No special token.
      break;
    }

    if (allowed_special.count(special) == 1) {
      // Found an allowed special token, split the text with it.
#if __cplusplus >= 202002L
      return std::make_pair(
          special,
          re2::StringPiece(start, input.begin() - start - special.size()));
#else
      return std::make_pair(
          special,
          re2::StringPiece(start, (input.data() - start) - special.size()));
#endif
    } // else try to find the next special token
  }

  return std::make_pair(std::nullopt, input);
}

void Tiktoken::_encode(
    re2::StringPiece& input,
    std::vector<uint64_t>& ret,
    uint64_t& last_piece_token_len) const {
  std::string piece;
  assert(_regex);
  while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) {
    auto iter = _encoder.find(piece);
    if (iter != _encoder.end()) {
      last_piece_token_len = 1;
      ret.push_back(iter->second);
      continue;
    }
    auto tokens = _byte_pair_encode(piece, _encoder);
    last_piece_token_len = tokens.size();
    ret.insert(ret.end(), tokens.begin(), tokens.end());
  }
}

template <typename T>
std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
    const std::string& text,
    const T& allowed_special) const {
  std::vector<uint64_t> tokens;
  uint64_t last_piece_token_len = 0;
  re2::StringPiece input(text);
  while (true) {
    auto [special, sub_input] =
        _split_with_allowed_special_token(input, allowed_special);

    _encode(sub_input, tokens, last_piece_token_len);

    if (special) {
      uint64_t token = 0;
      try {
        token = _special_token_encoder.at(*special);
      } catch (const std::out_of_range&) {
        // Should never reach here, since the special-token pattern includes
        // all special tokens.
        ET_CHECK_MSG(false, "unknown special token: %s", special->c_str());
      }

      tokens.push_back(token);
      last_piece_token_len = 0;
    } else {
      break;
    }
  }

  // last_piece_token_len is how many tokens came from the last regex split.
  // This is used for determining unstable tokens, since you can't merge
  // across (stable) regex splits.
  return std::make_pair(tokens, last_piece_token_len);
}

Encoder Tiktoken::_build_special_token_encoder(ssize_t num_base_tokens) const {
  Encoder special_token_encoder;
  for (ssize_t i = 0; i < _special_tokens->size(); ++i) {
    special_token_encoder.emplace(_special_tokens->at(i), num_base_tokens + i);
  }
  return special_token_encoder;
}

// -------------------------private method end-------------------------------
// -------------------------public method start-------------------------------

Tiktoken::Tiktoken(
    std::unique_ptr<std::vector<std::string>> special_tokens,
    size_t bos_token_index,
    size_t eos_token_index)
    : Tokenizer(),
      _special_tokens(std::move(special_tokens)),
      _bos_token_index(bos_token_index),
      _eos_token_index(eos_token_index) {
  ET_CHECK_MSG(
      _bos_token_index < _special_tokens->size(),
      "invalid bos_token_index %zu",
      _bos_token_index);
  ET_CHECK_MSG(
      _eos_token_index < _special_tokens->size(),
      "invalid eos_token_index %zu",
      _eos_token_index);
}

Error Tiktoken::load(const std::string& path) {
  _encoder = ET_UNWRAP(_load_encoder(path));
  _special_token_encoder = _build_special_token_encoder(_encoder.size());

  _decoder = ET_UNWRAP(_build_decoder(_encoder));
  _special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder));

  _regex = _create_regex(_pattern);
  // Warm up re2, as it is slow on the first run; void the return value
  // since it's not needed. Refer to
  // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
  (void)_regex->ReverseProgramSize();

  _special_token_regex = _build_special_token_regex(_special_token_encoder);
  // Same as above, warm up re2.
  (void)_special_token_regex->ReverseProgramSize();

  // Initialize vocab_size_, bos_tok_, eos_tok_.
  vocab_size_ = _encoder.size() + _special_token_encoder.size();
  bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index));
  eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index));

  initialized_ = true;
  return Error::Ok;
}

Result<std::vector<uint64_t>>
Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const {
  if (!initialized_) {
    return Error::NotSupported;
  }
  auto res = _encode_with_special_token(text, _special_token_encoder).first;
  for (auto i = 0; i < bos; ++i) {
    res.insert(res.begin(), bos_tok_);
  }
  for (auto i = 0; i < eos; ++i) {
    res.push_back(eos_tok_);
  }
  return Result<std::vector<uint64_t>>(std::move(res));
}

Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) const {
  (void)prev;
  ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur));
  std::string ret;

  std::string token_bytes;
  auto iter = _decoder.find(cur);
  if (iter != _decoder.end()) {
    token_bytes = iter->second;
  } else {
    iter = _special_token_decoder.find(cur);
    if (iter != _special_token_decoder.end()) {
      token_bytes = iter->second;
    } else {
      ET_CHECK_MSG(false, "unknown token: %" PRIu64, cur);
    }
  }
  ret += token_bytes;

  return ret;
}
// -------------------------public method end-------------------------------
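
// Usage sketch (illustrative only; the special-token strings, indices, and
// file path below are assumptions, not part of this API's contract):
//
//   auto specials = std::make_unique<std::vector<std::string>>(
//       std::vector<std::string>{"<|begin_of_text|>", "<|end_of_text|>"});
//   Tiktoken tokenizer(
//       std::move(specials), /*bos_token_index=*/0, /*eos_token_index=*/1);
//   if (tokenizer.load("tokenizer.model") == Error::Ok) {
//     // Prepend one BOS token, append no EOS token.
//     auto tokens = tokenizer.encode("hello world", /*bos=*/1, /*eos=*/0);
//   }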

} // namespace llm
} // namespace extension
} // namespace executorch