/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Adapted from https://github.com/sewenew/tokenizer

// @lint-ignore-every LICENSELINT
/**************************************************************************
   Copyright (c) 2023 sewenew

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
 *************************************************************************/

#include <executorch/extension/llm/tokenizer/base64.h>
#include <executorch/extension/llm/tokenizer/tiktoken.h>
#include <executorch/runtime/core/result.h>
#include <fstream>
#include <limits>

using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace executorch {
namespace extension {
namespace llm {

// ------------------------------Util start------------------------------------

static uint64_t _max_size() {
  return std::numeric_limits<uint64_t>::max();
}

static Re2UPtr _create_regex(const std::string& pattern) {
  assert(!pattern.empty());

  return std::make_unique<re2::RE2>("(" + pattern + ")");
}

static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) {
  std::string special_pattern;
  for (const auto& ele : special_encoder) {
    if (!special_pattern.empty()) {
      special_pattern += "|";
    }
    special_pattern += re2::RE2::QuoteMeta(ele.first);
  }

  if (special_pattern.empty()) {
    return nullptr;
  }

  return _create_regex(special_pattern);
}

static Result<std::pair<std::string, uint64_t>> _parse(
    const std::string& line) {
  // Tiktoken format
  // https://github.com/openai/tiktoken/blob/main/tiktoken/load.py#L140
  // <base64 encoded token str> <rank>
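  // A hypothetical example line: "IQ== 0", where the base64 payload "IQ=="
  // decodes to the one-byte token "!" and the trailing integer is its rank.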
  auto pos = line.find(" ");
  ET_CHECK_OR_RETURN_ERROR(
      pos != std::string::npos,
      InvalidArgument,
      "invalid tiktoken line: %s",
      line.c_str());

  auto token = ET_UNWRAP(base64::decode({line.data(), pos}));
  uint64_t rank = 0;
  try {
    rank = std::stoul(line.substr(pos + 1));
  } catch (const std::exception&) {
    ET_CHECK_OR_RETURN_ERROR(
        false, InvalidArgument, "invalid encoder rank: %s", line.c_str());
  }

  return std::pair{std::move(token), rank};
}

static Result<Encoder> _load_encoder(const std::string& path) {
  std::ifstream file(path);
  ET_CHECK_OR_RETURN_ERROR(
      file, InvalidArgument, "failed to open encoder file: %s", path.c_str());

  Encoder encoder;
  std::string line;
  while (std::getline(file, line)) {
    auto [token, rank] = ET_UNWRAP(_parse(line));

    ET_CHECK_OR_RETURN_ERROR(
        encoder.emplace(std::move(token), rank).second,
        InvalidArgument,
        "duplicate item: %s",
        line.c_str());
  }

  return encoder;
}

static Result<Decoder> _build_decoder(const Encoder& encoder) {
  Decoder decoder;
  for (const auto& [k, v] : encoder) {
    decoder.emplace(v, k);
  }

  ET_CHECK_OR_RETURN_ERROR(
      encoder.size() == decoder.size(),
      InvalidArgument,
      "duplicate items in encoder");

  return decoder;
}

static std::vector<uint64_t> _byte_pair_merge(
    const std::string& piece,
    const std::unordered_map<std::string, uint64_t>& ranks,
    std::function<uint64_t(uint64_t, uint64_t)> func) {
  // This is a vector of (start, rank).
  // The rank is of the byte pair starting at position start.
  // The rank of the last item in the vector is not a valid value.
  std::vector<std::pair<uint64_t, uint64_t>> parts;
  parts.reserve(piece.size() + 1);
  for (auto idx = 0U; idx < piece.size() + 1; ++idx) {
    parts.emplace_back(idx, _max_size());
  }

  auto get_rank = [&piece, &ranks](
                      const std::vector<std::pair<uint64_t, uint64_t>>& parts,
                      uint64_t start_idx,
                      uint64_t skip) -> std::optional<uint64_t> {
    if (start_idx + skip + 2 < parts.size()) {
      auto s = parts[start_idx].first;
      auto e = parts[start_idx + skip + 2].first;
      auto key = piece.substr(s, e - s);
      auto iter = ranks.find(key);
      if (iter != ranks.end()) {
        return iter->second;
      }
    }
    return std::nullopt;
  };

  // We look up the ranks once in the beginning and iteratively update
  // them during each merge, which reduces the number of rank lookups.
  for (auto i = 0U; i < parts.size() - 2; ++i) {
    auto rank = get_rank(parts, i, 0);
    if (rank) {
      // The maximum uint64_t value is a sentinel and cannot be a valid rank.
      ET_CHECK_MSG(*rank != _max_size(), "rank is too large");
      parts[i].second = *rank;
    }
  }

  // If you have n parts and m merges, this does O(mn) work.
  // We could do something with a heap and do O(m log n) work.
  // It is important to consider that n is often small (<100), and as such
  // the cache-locality benefits outweigh the algorithmic complexity downsides
  // of the `parts` vector data structure above.

  // Note that we hash bytes, not token pairs. As long as we train BPE the way
  // we currently do, this is equivalent. An easy way to break this would be
  // to decouple merge priority from token index or to prevent specific token
  // merges.
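
  // Illustrative walkthrough (with hypothetical ranks {"ab": 0, "bc": 1}):
  // for piece "abc", `parts` starts as boundaries [0, 1, 2, 3]. The pair
  // "ab" (rank 0) beats "bc" (rank 1), so boundary 1 is erased, leaving
  // [0, 2, 3]. The merged candidate "abc" is absent from `ranks`, so the
  // loop stops and the output is the ids for "ab" and "c".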
  while (true) {
    if (parts.size() == 1) {
      break;
    }

    // The maximum uint64_t value is a sentinel rank allowing us to
    // take the min more quickly.
    auto min_rank = std::make_pair<uint64_t, uint64_t>(_max_size(), 0);
    for (auto i = 0U; i < parts.size() - 1; ++i) {
      auto rank = parts[i].second;
      if (rank < min_rank.first) {
        min_rank.first = rank;
        min_rank.second = i;
      }
    }

    if (min_rank.first != _max_size()) {
      auto i = min_rank.second;

      // NOTE: We are about to remove parts[i + 1]. We do not do it
      // yet because there are cache-locality benefits to updating
      // parts[i] and parts[i-1] before removing, which could thrash
      // the cache. Thus, we update the rank calculation by skipping over
      // parts[i + 1], by invoking `get_rank` with `skip = 1`.
      auto rank = get_rank(parts, i, 1);
      if (rank) {
        parts[i].second = *rank;
      } else {
        parts[i].second = _max_size();
      }
      if (i > 0) {
        rank = get_rank(parts, i - 1, 1);
        if (rank) {
          parts[i - 1].second = *rank;
        } else {
          parts[i - 1].second = _max_size();
        }
      }

      parts.erase(parts.begin() + (i + 1));
    } else {
      break;
    }
  }
  std::vector<uint64_t> out;
  out.reserve(parts.size() - 1);
  for (auto i = 0U; i < parts.size() - 1; ++i) {
    auto s = parts[i].first;
    auto e = parts[i + 1].first;
    out.push_back(func(s, e));
  }
  return out;
}

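// Encodes one regex-split piece into token ids: a single byte is looked up
// directly in `encoder`; longer pieces go through `_byte_pair_merge` using
// the encoder's ranks.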
static std::vector<uint64_t> _byte_pair_encode(
    const std::string& piece,
    const Encoder& encoder) {
  if (piece.size() == 1) {
    auto iter = encoder.find(piece);
    if (iter != encoder.end()) {
      return std::vector<uint64_t>({iter->second});
    } else {
      // TODO: is it possible?
      return {};
    }
  }

  return _byte_pair_merge(
      piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) {
        std::string key = piece.substr(start, stop - start);
        auto iter = encoder.find(key);
        if (iter != encoder.end()) {
          return iter->second;
        } else {
          // TODO: what if key does not exist? Should we return `unknown`?
          // assert(false); // ??
          return uint64_t(0);
        }
      });
}
// ------------------------------Util end------------------------------------
// -------------------------private method start-------------------------------

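// Scans `input` for the next special token contained in `allowed_special`.
// Returns the matched token (or std::nullopt if none) together with the text
// preceding it; RE2::FindAndConsume advances `input` past the match.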
template <typename T>
std::pair<std::optional<std::string>, re2::StringPiece>
Tiktoken::_split_with_allowed_special_token(
    re2::StringPiece& input,
    const T& allowed_special) const {
  if (!_special_token_regex) {
    return std::make_pair(std::nullopt, input);
  }

#if __cplusplus >= 202002L
  auto start = input.begin();
#else
  const char* start = input.data();
#endif
  std::string special;
  while (true) {
    if (!re2::RE2::FindAndConsume(&input, *_special_token_regex, &special)) {
      // No special token.
      break;
    }

    if (allowed_special.count(special) == 1) {
      // Found an allowed special token, split the text with it.
#if __cplusplus >= 202002L
      return std::make_pair(
          special,
          re2::StringPiece(start, input.begin() - start - special.size()));
#else
      return std::make_pair(
          special,
          re2::StringPiece(start, (input.data() - start) - special.size()));
#endif
    } // else try to find the next special token
  }

  return std::make_pair(std::nullopt, input);
}

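// Runs the main tokenization regex over `input` and appends the token ids of
// each matched piece to `ret`. `last_piece_token_len` is left at the number
// of tokens produced by the final piece.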
void Tiktoken::_encode(
    re2::StringPiece& input,
    std::vector<uint64_t>& ret,
    uint64_t& last_piece_token_len) const {
  std::string piece;
  assert(_regex);
  while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) {
    auto iter = _encoder.find(piece);
    if (iter != _encoder.end()) {
      last_piece_token_len = 1;
      ret.push_back(iter->second);
      continue;
    }
    auto tokens = _byte_pair_encode(piece, _encoder);
    last_piece_token_len = tokens.size();
    ret.insert(ret.end(), tokens.begin(), tokens.end());
  }
}

template <typename T>
std::pair<std::vector<uint64_t>, uint64_t> Tiktoken::_encode_with_special_token(
    const std::string& text,
    const T& allowed_special) const {
  std::vector<uint64_t> tokens;
  uint64_t last_piece_token_len = 0;
  re2::StringPiece input(text);
  while (true) {
    auto [special, sub_input] =
        _split_with_allowed_special_token(input, allowed_special);

    _encode(sub_input, tokens, last_piece_token_len);

    if (special) {
      uint64_t token = 0;
      try {
        token = _special_token_encoder.at(*special);
      } catch (const std::out_of_range&) {
        // Should never get here, since the special-token pattern includes
        // all special tokens.
        ET_CHECK_MSG(false, "unknown special token: %s", special->c_str());
      }

      tokens.push_back(token);
      last_piece_token_len = 0;
    } else {
      break;
    }
  }

  // last_piece_token_len is how many tokens came from the last regex split.
  // This is used for determining unstable tokens, since you can't merge
  // across (stable) regex splits.
  return std::make_pair(tokens, last_piece_token_len);
}

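// Special tokens are assigned ids after the base vocabulary: the i-th special
// token gets id num_base_tokens + i, in declaration order.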
Encoder Tiktoken::_build_special_token_encoder(ssize_t num_base_tokens) const {
  Encoder special_token_encoder;
  for (ssize_t i = 0; i < _special_tokens->size(); ++i) {
    special_token_encoder.emplace(_special_tokens->at(i), num_base_tokens + i);
  }
  return special_token_encoder;
}

// -------------------------private method end-------------------------------
// -------------------------public method start-------------------------------

Tiktoken::Tiktoken(
    std::unique_ptr<std::vector<std::string>> special_tokens,
    size_t bos_token_index,
    size_t eos_token_index)
    : Tokenizer(),
      _special_tokens(std::move(special_tokens)),
      _bos_token_index(bos_token_index),
      _eos_token_index(eos_token_index) {
  ET_CHECK_MSG(
      _bos_token_index < _special_tokens->size(),
      "invalid bos_token_index %zu",
      _bos_token_index);
  ET_CHECK_MSG(
      _eos_token_index < _special_tokens->size(),
      "invalid eos_token_index %zu",
      _eos_token_index);
}

Error Tiktoken::load(const std::string& path) {
  _encoder = ET_UNWRAP(_load_encoder(path));
  _special_token_encoder = _build_special_token_encoder(_encoder.size());

  _decoder = ET_UNWRAP(_build_decoder(_encoder));
  _special_token_decoder = ET_UNWRAP(_build_decoder(_special_token_encoder));

  _regex = _create_regex(_pattern);
  // Warm up re2, as it is slow on the first run; void the return value since
  // it is not needed. Refer to
  // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
  (void)_regex->ReverseProgramSize();

  _special_token_regex = _build_special_token_regex(_special_token_encoder);
  // Same as above: warm up re2.
  (void)_special_token_regex->ReverseProgramSize();

  // Initialize vocab_size, bos_tok, eos_tok.
  vocab_size_ = _encoder.size() + _special_token_encoder.size();
  bos_tok_ = _special_token_encoder.at(_special_tokens->at(_bos_token_index));
  eos_tok_ = _special_token_encoder.at(_special_tokens->at(_eos_token_index));

  initialized_ = true;
  return Error::Ok;
}

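// Hypothetical usage: after load("cl100k_base.tiktoken") (an assumed file
// name), encode("hello world", 1, 0) returns the BPE ids for "hello world"
// with one BOS token prepended and no EOS token appended.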
Result<std::vector<uint64_t>>
Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) const {
  if (!initialized_) {
    return Error::NotSupported;
  }
  auto res = _encode_with_special_token(text, _special_token_encoder).first;
  for (auto i = 0; i < bos; ++i) {
    res.insert(res.begin(), bos_tok_);
  }
  for (auto i = 0; i < eos; ++i) {
    res.push_back(eos_tok_);
  }
  return Result<std::vector<uint64_t>>(std::move(res));
}

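// Decodes a single token id to its byte string, consulting the base decoder
// first and then the special-token decoder. `prev` is part of the
// Tokenizer::decode signature but is unused here.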
Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) const {
  (void)prev;
  ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(cur));
  std::string ret;

  std::string token_bytes;
  auto iter = _decoder.find(cur);
  if (iter != _decoder.end()) {
    token_bytes = iter->second;
  } else {
    iter = _special_token_decoder.find(cur);
    if (iter != _special_token_decoder.end()) {
      token_bytes = iter->second;
    } else {
      ET_CHECK_MSG(false, "unknown token: %" PRIu64, cur);
    }
  }
  ret += token_bytes;

  return ret;
}
// -------------------------public method end-------------------------------

} // namespace llm
} // namespace extension
} // namespace executorch