/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace executorch {
namespace extension {
namespace llm {

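// Comparator shared by qsort()/bsearch() over TokenIndex entries: orders
// entries by their token string via strcmp(), with a null string sorting first.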
static int compare_tokens(const void* a, const void* b) {
  if (((TokenIndex*)a)->str == nullptr) {
    return -1;
  }
  if (((TokenIndex*)b)->str == nullptr) {
    return 1;
  }
  return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}

BPETokenizer::BPETokenizer() : Tokenizer() {
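  // Pre-build a table of 256 single-byte strings: entry i holds byte i
  // followed by '\0'. decode() indexes into this table when a token encodes
  // a raw byte such as <0x0A>.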
  for (int i = 0; i < 256; i++) {
    byte_pieces_[i * 2] = (unsigned char)i;
    byte_pieces_[i * 2 + 1] = '\0';
  }
}

/**
 * @brief Load the tokenizer from a file. The tokenizer file contains the
 * vocabulary and scores. The format is: four int32 metadata fields
 * (vocab_size, bos_tok, eos_tok, max_token_length), followed by one entry
 * per token consisting of a float score, an int32 length, and the raw token
 * bytes. All of the vocabulary is read into memory and kept sorted for fast
 * lookup.
 *
 * @param tokenizer_path The path to the tokenizer file.
 * @return Error
 */
Error BPETokenizer::load(const std::string& tokenizer_path) {
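  // On-disk layout consumed below (fields are read in host byte order):
  //   int32 vocab_size | int32 bos_tok | int32 eos_tok | int32 max_token_length
  //   then, repeated vocab_size times:
  //     float32 score | int32 len | byte token[len]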
  if (initialized_) {
    ET_LOG(Info, "Tokenizer already initialized");
    return Error::Ok;
  }
  // read in the file
  FILE* file = fopen(tokenizer_path.c_str(), "rb");
  if (!file) {
    ET_LOG(Error, "couldn't load %s", tokenizer_path.c_str());
    return Error::InvalidArgument;
  }
  int32_t metadata[4];
  for (int i = 0; i < 4; i++) {
    if (fread(metadata + i, sizeof(int32_t), 1, file) != 1) {
      ET_LOG(
          Error,
          "Failed to read the metadata at position %d, the tokenizer file is not valid!",
          i);
      fclose(file);
      return Error::InvalidArgument;
    }
  }

  // The tokenizer file carries its own vocab size; it may differ from the
  // vocab size the model was exported with.
  int32_t tokenizer_vocab_size = metadata[0];
  vocab_size_ = tokenizer_vocab_size;
  bos_tok_ = metadata[1];
  eos_tok_ = metadata[2];
  max_token_length_ = metadata[3];

  // allocate space for the vocabulary
  vocab_ = std::make_unique<char*[]>(vocab_size_);
  vocab_scores_ = std::make_unique<float[]>(vocab_size_);
  sorted_vocab_ = std::make_unique<TokenIndex[]>(vocab_size_);

  // read in the vocabulary
  for (int i = 0; i < vocab_size_; i++) {
    if (fread(vocab_scores_.get() + i, sizeof(float), 1, file) != 1) {
      // This is allowed, we just pad the rest of the vocab with <pad> strings
      std::string padding = "<pad>";
      vocab_[i] = new char[padding.length() + 1];
      strcpy(vocab_[i], padding.c_str());
      vocab_[i][padding.length()] = '\0';
      continue;
    }
    int32_t len;
    if (fread(&len, sizeof(int32_t), 1, file) != 1) {
      ET_LOG(Error, "Failed to read the length of the word at index %d", i);
      fclose(file);
      return Error::InvalidArgument;
    }
    vocab_[i] = new char[len + 1];
    if (fread(vocab_[i], len, 1, file) != 1) {
      ET_LOG(
          Error,
          "Failed to read the word, total length %d, index %d\n",
          len,
          i);
      fclose(file);
      return Error::InvalidArgument;
    }
    vocab_[i][len] = '\0'; // add the string terminating token
  }
  fclose(file);

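  // Build an index over the vocabulary sorted by token string so that
  // str_lookup() can binary-search it during encoding.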
  for (int32_t i = 0; i < vocab_size_; i++) {
    sorted_vocab_[i].str = vocab_[i];
    sorted_vocab_[i].id = i;
  }
  qsort(sorted_vocab_.get(), vocab_size_, sizeof(TokenIndex), compare_tokens);

  initialized_ = true;
  return Error::Ok;
}
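
// A minimal usage sketch (illustrative only: the file path, token values, and
// Result accessors below are assumptions, and error handling is elided):
//
//   BPETokenizer tokenizer;
//   tokenizer.load("tokenizer.bin");
//   Result<std::vector<uint64_t>> ids =
//       tokenizer.encode("Hello world", /*bos=*/1, /*eos=*/0);
//   Result<std::string> piece = tokenizer.decode(ids.get()[0], ids.get()[1]);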

BPETokenizer::~BPETokenizer() {
  for (int i = 0; i < vocab_size_; i++) {
    delete[] vocab_[i];
  }
}

/**
 * @brief Decode a token into a string.
 *
 * @param prev_token The previous token.
 * @param token The current token.
 * @return Result<std::string> The string representation of the token.
 */
Result<std::string> BPETokenizer::decode(uint64_t prev_token, uint64_t token)
    const {
  ET_CHECK_OK_OR_RETURN_ERROR(Tokenizer::decode_verify(token));
  const char* piece = vocab_[token];
  // following BOS token, sentencepiece decoder strips any leading
  // whitespace
  if (prev_token == bos_tok_ && piece[0] == ' ') {
    piece++;
  }
  // careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
  // parse this and convert and return the actual byte
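  // e.g. vocab entry "<0x0A>" maps to byte_pieces_[0x0A * 2], which is a
  // one-character "\n" string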
  unsigned char byte_val;
  if (sscanf(piece, "<0x%02hhX>", &byte_val) == 1) {
    piece = (char*)byte_pieces_ + byte_val * 2;
  }
  std::string res(piece);
  return res;
}

static int32_t
str_lookup(const char* str, TokenIndex* sorted_vocab, int32_t vocab_size) {
  // efficiently find the perfect match for str in vocab, return its index or -1
  // if not found
  TokenIndex tok = {.str = str}; // acts as the key to search for
  TokenIndex* res = (TokenIndex*)bsearch(
      &tok, sorted_vocab, vocab_size, sizeof(TokenIndex), compare_tokens);
  return res != nullptr ? res->id : -1;
}

/**
 * @brief Encode a string into a sequence of tokens.
 *
 * @param text The string to be encoded.
 * @param bos The number of BOS tokens to prepend to the token list.
 * @param eos The number of EOS tokens to append to the token list.
 * @return Result<std::vector<uint64_t>> The encoded token ids.
 */
Result<std::vector<uint64_t>>
BPETokenizer::encode(const std::string& text, int8_t bos, int8_t eos) const {
  if (!initialized_) {
    ET_LOG(Error, "Tokenizer not initialized");
    return Error::NotSupported;
  }
  // Encode the input string into a sequence of token ids; bos and eos give
  // the number of BOS/EOS tokens to prepend and append, respectively.
  if (text.empty()) {
    ET_LOG(Error, "cannot encode empty text");
    return Error::InvalidArgument;
  }

  // Create a temporary buffer that will store merge candidates of always two
  // consecutive tokens: *2 for the concatenation, +1 for the null terminator,
  // +2 for UTF-8 (in case max_token_length is 1).
  char* str_buffer = new char[max_token_length_ * 2 + 1 + 2];
  size_t str_len = 0;

  // start at 0 tokens
  std::vector<uint64_t> tokens;

  // add optional BOS token(s), if desired
  if (bos >= 0) {
    while (bos--) {
      tokens.push_back(bos_tok_);
    }
  } else {
    ET_LOG(Error, "bos %d should be >= 0", bos);
    delete[] str_buffer;
    return Error::InvalidArgument;
  }

  // add_dummy_prefix is true by default, so prepend a dummy prefix token
  // (a single space) to the input string, but only if text != "".
  // TODO: this may not be correct in the general case; verify against the
  // sentencepiece implementation.
  const char* space = " ";
  if (text[0] != '\0') {
    int dummy_prefix = str_lookup(space, sorted_vocab_.get(), vocab_size_);
    tokens.push_back(dummy_prefix);
  }

  // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
  // Code point ↔ UTF-8 conversion
  // First code point  Last code point  Byte 1    Byte 2    Byte 3    Byte 4
  // U+0000            U+007F           0xxxxxxx
  // U+0080            U+07FF           110xxxxx  10xxxxxx
  // U+0800            U+FFFF           1110xxxx  10xxxxxx  10xxxxxx
  // U+10000           U+10FFFF         11110xxx  10xxxxxx  10xxxxxx  10xxxxxx

  // process the raw (UTF-8) byte sequence of the input string
  for (const char* c = text.c_str(); *c != '\0'; c++) {
    // Reset the buffer if the current byte is ASCII or a leading byte.
    // 0xC0 is 11000000, so (*c & 0xC0) keeps the first 2 bits and zeros the
    // rest. 0x80 is 10000000; in UTF-8 all continuation bytes start with "10"
    // in the first two bits, so in English this is: "if this byte is not a
    // continuation byte".
    if ((*c & 0xC0) != 0x80) {
      // this byte must be either a leading byte (11...) or an ASCII char
      // (0x...)
      // => reset our location, as we're starting a new UTF-8 codepoint
      str_len = 0;
    }

    // append the current byte to the buffer
    str_buffer[str_len++] =
        *c; // ++ is post-increment, incremented after this line
    str_buffer[str_len] = '\0';

    // while the next character is a continuation byte, continue appending,
    // but if there are too many of them, just stop to avoid overrunning the
    // str_buffer size.
    if ((*(c + 1) & 0xC0) == 0x80 && str_len < 4) {
      continue;
    }

    // ok c+1 is not a continuation byte, so we've read in a full codepoint
    int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
    if (id != -1) {
      // we found this codepoint in vocab, add it as a token
      tokens.push_back(id);
    } else {
      // byte_fallback encoding: just encode each byte as a token
      // +3 is here because the first 3 vocab elements are <unk>, <s>, </s>,
      // so the individual bytes only start at index 3
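      // e.g. the byte 'A' (0x41 = 65) becomes token id 65 + 3 = 68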
      for (int i = 0; i < str_len; i++) {
        tokens.push_back((unsigned char)str_buffer[i] + 3);
      }
    }
    str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
  }

  // merge the best consecutive pair each iteration, according to the scores
  // in vocab_scores
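  // Illustrative example (hypothetical vocabulary): if "hi" was looked up as
  // the two tokens ["h", "i"] and the concatenation "hi" exists in the vocab
  // with the best score among adjacent pairs, the loop below replaces the
  // pair with that single token, repeating until no adjacent pair can merge.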
  while (1) {
    float best_score = -1e10;
    int best_id = -1;
    int best_idx = -1;

    for (int i = 0; i < tokens.size() - 1; i++) {
      // check if we can merge the pair (tokens[i], tokens[i+1])
      snprintf(
          str_buffer,
          max_token_length_ * 2 + 3,
          "%s%s",
          vocab_[tokens[i]],
          vocab_[tokens[i + 1]]);
      int id = str_lookup(str_buffer, sorted_vocab_.get(), vocab_size_);
      if (id != -1 && vocab_scores_[id] > best_score) {
        // this merge pair exists in vocab! record its score and position
        best_score = vocab_scores_[id];
        best_id = id;
        best_idx = i;
      }
    }

    if (best_idx == -1) {
      break; // we couldn't find any more pairs to merge, so we're done
    }

    // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
    tokens[best_idx] = best_id;
    // delete token at position best_idx+1, shift the entire sequence back 1
    for (int i = best_idx + 1; i < tokens.size() - 1; i++) {
      tokens[i] = tokens[i + 1];
    }
    tokens.pop_back(); // token length decreased
  }

  // add optional EOS token(s), if desired
  if (eos >= 0) {
    while (eos--) {
      tokens.push_back(eos_tok_);
    }
  } else {
    ET_LOG(Error, "eos %d should be >= 0", eos);
    delete[] str_buffer;
    return Error::InvalidArgument;
  }

  delete[] str_buffer;
  return Result(tokens);
}

} // namespace llm
} // namespace extension
} // namespace executorch