1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
10
11 namespace example {
12
13 using ::executorch::extension::llm::Tiktoken;
14
15 namespace {
// Number of entries each special-token table built in this file ends up with.
// Declared as size_t because it is only ever compared against
// std::vector::size(), which keeps those comparisons unsigned-vs-unsigned.
static constexpr size_t kSpecialTokensSize = 256;
// Positions in the special-token table that get_tiktoken_for_llama() hands to
// the Tiktoken constructor: index 0 holds "<|begin_of_text|>" and index 1
// holds "<|end_of_text|>" in both tables below.
static constexpr size_t kBOSTokenIndex = 0;
static constexpr size_t kEOSTokenIndex = 1;
19
20 static inline std::unique_ptr<std::vector<std::string>>
_get_default_special_tokens()21 _get_default_special_tokens() {
22 auto special_tokens =
23 std::make_unique<std::vector<std::string>>(std::vector<std::string>{
24 "<|begin_of_text|>",
25 "<|end_of_text|>",
26 "<|reserved_special_token_0|>",
27 "<|reserved_special_token_1|>",
28 "<|finetune_right_pad_id|>",
29 "<|step_id|>",
30 "<|start_header_id|>",
31 "<|end_header_id|>",
32 "<|eom_id|>",
33 "<|eot_id|>",
34 "<|python_tag|>"});
35 // pad the rest of the special tokens with reserved tokens
36 ssize_t reserved_special_token_num = 2;
37 while (special_tokens->size() < kSpecialTokensSize) {
38 special_tokens->emplace_back(
39 "<|reserved_special_token_" +
40 std::to_string(reserved_special_token_num++) + "|>");
41 }
42 return special_tokens;
43 }
44
45 static inline std::unique_ptr<std::vector<std::string>>
_get_multimodal_special_tokens()46 _get_multimodal_special_tokens() {
47 auto special_tokens =
48 std::make_unique<std::vector<std::string>>(std::vector<std::string>{
49 "<|begin_of_text|>",
50 "<|end_of_text|>",
51 "<|reserved_special_token_0|>",
52 "<|reserved_special_token_1|>",
53 "<|reserved_special_token_2|>",
54 "<|reserved_special_token_3|>",
55 "<|start_header_id|>",
56 "<|end_header_id|>",
57 "<|eom_id|>",
58 "<|eot_id|>",
59 "<|image|>"});
60
61 // pad the rest of the special tokens with reserved tokens except the last
62 // one
63 ssize_t reserved_special_token_num = 4;
64 while (special_tokens->size() < kSpecialTokensSize - 1) {
65 special_tokens->emplace_back(
66 "<|reserved_special_token_" +
67 std::to_string(reserved_special_token_num++) + "|>");
68 }
69
70 special_tokens->emplace_back("<|python_tag|>");
71
72 return special_tokens;
73 }
74
_get_special_tokens(Version version)75 std::unique_ptr<std::vector<std::string>> _get_special_tokens(Version version) {
76 switch (version) {
77 case Version::Multimodal:
78 return _get_multimodal_special_tokens();
79 default:
80 return _get_default_special_tokens();
81 }
82 }
83
84 } // namespace
85
get_tiktoken_for_llama(Version version)86 std::unique_ptr<Tiktoken> get_tiktoken_for_llama(Version version) {
87 return std::make_unique<Tiktoken>(
88 _get_special_tokens(version), kBOSTokenIndex, kEOSTokenIndex);
89 }
90
91 } // namespace example
92