xref: /aosp_15_r20/external/executorch/examples/models/llama/tokenizer/llama_tiktoken.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
10 
11 namespace example {
12 
13 using ::executorch::extension::llm::Tiktoken;
14 
15 namespace {
16 static constexpr int32_t kSpecialTokensSize = 256;
17 static constexpr size_t kBOSTokenIndex = 0;
18 static constexpr size_t kEOSTokenIndex = 1;
19 
20 static inline std::unique_ptr<std::vector<std::string>>
_get_default_special_tokens()21 _get_default_special_tokens() {
22   auto special_tokens =
23       std::make_unique<std::vector<std::string>>(std::vector<std::string>{
24           "<|begin_of_text|>",
25           "<|end_of_text|>",
26           "<|reserved_special_token_0|>",
27           "<|reserved_special_token_1|>",
28           "<|finetune_right_pad_id|>",
29           "<|step_id|>",
30           "<|start_header_id|>",
31           "<|end_header_id|>",
32           "<|eom_id|>",
33           "<|eot_id|>",
34           "<|python_tag|>"});
35   // pad the rest of the special tokens with reserved tokens
36   ssize_t reserved_special_token_num = 2;
37   while (special_tokens->size() < kSpecialTokensSize) {
38     special_tokens->emplace_back(
39         "<|reserved_special_token_" +
40         std::to_string(reserved_special_token_num++) + "|>");
41   }
42   return special_tokens;
43 }
44 
45 static inline std::unique_ptr<std::vector<std::string>>
_get_multimodal_special_tokens()46 _get_multimodal_special_tokens() {
47   auto special_tokens =
48       std::make_unique<std::vector<std::string>>(std::vector<std::string>{
49           "<|begin_of_text|>",
50           "<|end_of_text|>",
51           "<|reserved_special_token_0|>",
52           "<|reserved_special_token_1|>",
53           "<|reserved_special_token_2|>",
54           "<|reserved_special_token_3|>",
55           "<|start_header_id|>",
56           "<|end_header_id|>",
57           "<|eom_id|>",
58           "<|eot_id|>",
59           "<|image|>"});
60 
61   // pad the rest of the special tokens with reserved tokens except the last
62   // one
63   ssize_t reserved_special_token_num = 4;
64   while (special_tokens->size() < kSpecialTokensSize - 1) {
65     special_tokens->emplace_back(
66         "<|reserved_special_token_" +
67         std::to_string(reserved_special_token_num++) + "|>");
68   }
69 
70   special_tokens->emplace_back("<|python_tag|>");
71 
72   return special_tokens;
73 }
74 
_get_special_tokens(Version version)75 std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
76   switch (version) {
77     case Version::Multimodal:
78       return _get_multimodal_special_tokens();
79     default:
80       return _get_default_special_tokens();
81   }
82 }
83 
84 } // namespace
85 
get_tiktoken_for_llama(Version version)86 std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
87   return std::make_unique<Tiktoken>(
88       _get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
89 }
90 
91 } // namespace example
92