1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
12
13 using ::executorch::runtime::Error;
14 using ::executorch::runtime::Result;
15
16 namespace executorch {
17 namespace extension {
18 namespace llm {
19
compare_tokens(const void * a,const void * b)20 static int compare_tokens(const void* a, const void* b) {
21 if (((TokenIndex*)a)->str == nullptr) {
22 return -1;
23 }
24 if (((TokenIndex*)b)->str == nullptr) {
25 return 1;
26 }
27 return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
28 }
29
BPETokenizer()30 BPETokenizer::BPETokenizer() : Tokenizer() {
31 for (int i = 0; i < 256; i++) {
32 byte_pieces_[i * 2] = (unsigned char)i;
33 byte_pieces_[i * 2 + 1] = '\0';
34 }
35 }
36
37 /**
38 * @brief Load the tokenizer from a file. The tokenizer file contains the
39 * vocabulary and scores. The format is: the first integer is the maximum
40 * token length, followed by a list of (word_len, word) pairs. Here we
41 * are reading all the vocabulary into memory and keep it sorted for fast
42 * lookup.
43 *
44 * @param tokenizer_path The path to the tokenizer file.
45 * @return Error
46 */
load(const std::string & tokenizer_path)47 Error BPETokenizer::load(const std::string& tokenizer_path) {
48 if (initialized_) {
49 ET_LOG(Info, "Tokenizer already initialized");
50 return Error::Ok;
51 }
52 // read in the file
53 FILE* file = fopen(tokenizer_path.c_str(), "rb");
54 if (!file) {
55 ET_LOG(Error, "couldn't load %s", tokenizer_path.c_str());
56 return Error::InvalidArgument;
57 }
58 int32_t metadata[4];
59 for (int i = 0; i < 4; i++) {
60 if (fread(metadata + i, sizeof(int32_t), 1, file) != 1) {
61 ET_LOG(
62 Error,
63 "Failed to read the metadata at position %d, the tokenizer file is not valid!",
64 i);
65 return Error::InvalidArgument;
66 }
67 }
68
69 // now we have two vocab_sizes one from the model and another from the
70 // tokenizer file.
71 int32_t tokenizer_vocab_size = metadata[0];
72 vocab_size_ = tokenizer_vocab_size;
73 bos_tok_ = metadata[1];
74 eos_tok_ = metadata[2];
75 max_token_length_ = metadata[3];
76
77 // allocate space for the vocabulary
78 vocab_ = std::make_unique<char*[]>(vocab_size_);
79 vocab_scores_ = std::make_unique<float[]>(vocab_size_);
80 sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);
81
82 // read in the vocabulary
83 for (int i = 0; i < vocab_size_; i++) {
84 if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
85 // This is allowed, we just pad the rest of the vocab with <pad> strings
86 std::string padding = "<pad>";
87 vocab_[i] = new char[padding.length() + 1];
88 strcpy(vocab_[i], padding.c_str());
89 vocab_[i][padding.length()] = '\0';
90 continue;
91 }
92 int32_t len;
93 if (fread(&len, sizeof(int32_t), 1, file) != 1) {
94 ET_LOG(Error, "Failed to read the length of the word at index %d", i);
95 return Error::InvalidArgument;
96 }
97 vocab_[i] = new char[len + 1];
98 if (fread(vocab_[i], len, 1, file) != 1) {
99 ET_LOG(
100 Error,
101 "Failed to read the word, total length %d, index %d\n",
102 len,
103 i);
104 return Error::InvalidArgument;
105 }
106 vocab_[i][len] = '\0'; // add the string terminating token
107 }
108 fclose(file);
109
110 for (int32_t i = 0; i < vocab_size_; i++) {
111 sorted_vocab_[i].str = vocab_[i];
112 sorted_vocab_[i].id = i;
113 }
114 qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);
115
116 initialized_ = true;
117 return Error::Ok;
118 }
119
~BPETokenizer()120 BPETokenizer::~BPETokenizer() {
121 for (int i = 0; i < vocab_size_; i++) {
122 delete[] vocab_[i];
123 }
124 }
125
126 /**
127 * @brief Decode a token into string.
128 *
129 * @param prev_token The previous token.
130 * @param token The current token.
131 * @return Result<std::string> A pointer to the string representation of the
132 * token.
133 */
decode(uint64_t prev_token,uint64_t token) const134 Result<std::string> BPETokenizer::decode(uint64_t prev_token, uint64_t token)
135 const {
136 ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token));
137 const char* piece = vocab_[token];
138 // following BOS token, sentencepiece decoder strips any leading
139 // whitespace
140 if (prev_token == bos_tok_ && piece[0] == ' ') {
141 piece++;
142 }
143 // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
144 // parse this and convert and return the actual byte
145 unsigned char byte_val;
146 if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
147 piece = (char*)byte_pieces_ + byte_val * 2;
148 }
149 std::string res(piece);
150 return res;
151 }
152
153 static int32_t
str_lookup(const char * str,TokenIndex * sorted_vocab,int32_t vocab_size)154 str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
155 // efficiently find the perfect match for str in vocab, return its index or -1
156 // if not found
157 TokenIndex tok = {.str = str}; // acts as the key to search for
158 TokenIndex* res = (TokenIndex*)bsearch(
159 &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
160 return res != nullptr ? res->id : -1;
161 }
162
163 /**
164 * @brief Encode a string into a sequence of tokens.
165 *
166 * @param text The string to be encoded.
167 * @param bos The number of BOS to prepend to the token list.
168 * @param eos The number of EOS to append to the token list.
169 * @param tokens The output tokens.
170 * @param n_tokens The number of tokens.
171 * @return Result<std::vector<uint64_t>>
172 */
173 Result<std::vector<uint64_t>>
encode(const std::string & text,int8_t bos,int8_t eos) const174 BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) const {
175 if (!initialized_) {
176 ET_LOG(Error, "Tokenizer not initialized");
177 return Error::NotSupported;
178 }
179 // encode the string text (input) into an upper-bound preallocated tokens[]
180 // array bos != 0 means prepend the BOS token (=1), eos != 0 means append the
181 // EOS token (=2)
182 if (text.empty()) {
183 ET_LOG(Error, "cannot encode empty text");
184 return Error::InvalidArgument;
185 }
186
187 // create a temporary buffer that will store merge candidates of always two
188 // consecutive tokens *2 for concat, +1 for null terminator +2 for UTF8 (in
189 // case max_token_length is 1)
190 char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
191 size_t str_len = 0;
192
193 // start at 0 tokens
194 std::vector<uint64_t> tokens;
195
196 // add optional BOS token, if desired
197 if (bos >= 0) {
198 while (bos--) {
199 tokens.push_back(bos_tok_);
200 }
201 } else {
202 ET_LOG(Error, "bos %d should be >= 0", bos);
203 return Error::InvalidArgument;
204 }
205
206 // add_dummy_prefix is true by default
207 // so prepend a dummy prefix token to the input string, but only if text != ""
208 // TODO: pretty sure this isn't correct in the general case but I don't have
209 // the energy to read more of the sentencepiece code to figure out what it's
210 // doing
211 const char* space = " ";
212 if (text[0] != '\0') {
213 int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
214 tokens.push_back(dummy_prefix);
215 }
216
217 // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
218 // Code point ↔ UTF-8 conversion
219 // First code point Last code point Byte 1 Byte 2 Byte 3 Byte 4
220 // U+0000 U+007F 0xxxxxxx
221 // U+0080 U+07FF 110xxxxx 10xxxxxx
222 // U+0800 U+FFFF 1110xxxx 10xxxxxx 10xxxxxx
223 // U+10000 U+10FFFF 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
224
225 // process the raw (UTF-8) byte sequence of the input string
226 for (const char* c = text.c_str(); *c != '\0'; c++) {
227 // reset buffer if the current byte is ASCII or a leading byte
228 // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
229 // rest 0x80 is 10000000 in UTF-8, all continuation bytes start with "10" in
230 // first two bits so in English this is: "if this byte is not a continuation
231 // byte"
232 if ((*c & 0xC0) != 0x80) {
233 // this byte must be either a leading byte (11...) or an ASCII char
234 // (0x...)
235 // => reset our location, as we're starting a new UTF-8 codepoint
236 str_len = 0;
237 }
238
239 // append the current byte to the buffer
240 str_buffer[str_len++] =
241 *c; // ++ is post-increment, incremented after this line
242 str_buffer[str_len] = '\0';
243
244 // while the next character is a continuation byte, continue appending
245 // but if there are too many of them, just stop to avoid overruning
246 // str_buffer size.
247 if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
248 continue;
249 }
250
251 // ok c+1 is not a continuation byte, so we've read in a full codepoint
252 int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
253 if (id != -1) {
254 // we found this codepoint in vocab, add it as a token
255 tokens.push_back(id);
256 } else {
257 // byte_fallback encoding: just encode each byte as a token
258 // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>
259 // so the individual bytes only start at index 3
260 for (int i = 0; i < str_len; i++) {
261 tokens.push_back((unsigned char)str_buffer[i] + 3);
262 }
263 }
264 str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
265 }
266
267 // merge the best consecutive pair each iteration, according the scores in
268 // vocab_scores
269 while (1) {
270 float best_score = -1e10;
271 int best_id = -1;
272 int best_idx = -1;
273
274 for (int i = 0; i < tokens.size() - 1; i++) {
275 // check if we can merge the pair (tokens[i], tokens[i+1])
276 snprintf(
277 str_buffer,
278 max_token_length_ * 2 + 3,
279 "%s%s",
280 vocab_[tokens[i]],
281 vocab_[tokens[i + 1]]);
282 int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
283 if (id != -1 && vocab_scores_[id] > best_score) {
284 // this merge pair exists in vocab! record its score and position
285 best_score = vocab_scores_[id];
286 best_id = id;
287 best_idx = i;
288 }
289 }
290
291 if (best_idx == -1) {
292 break; // we couldn't find any more pairs to merge, so we're done
293 }
294
295 // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
296 tokens[best_idx] = best_id;
297 // delete token at position best_idx+1, shift the entire sequence back 1
298 for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
299 tokens[i] = tokens[i + 1];
300 }
301 tokens.pop_back(); // token length decreased
302 }
303
304 // add optional EOS (=2) token, if desired
305 if (eos >= 0) {
306 while (eos--) {
307 tokens.push_back(eos_tok_);
308 }
309 } else {
310 ET_LOG(Error, "eos %d should be >= 0", eos);
311 return Error::InvalidArgument;
312 }
313
314 delete[] str_buffer;
315 return Result(tokens);
316 }
317
318 } // namespace llm
319 } // namespace extension
320 } // namespace executorch
321