/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Generate tokens in a loop.
#pragma once

#include <cstdint>
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/tensor/tensor.h>

namespace executorch {
namespace extension {
namespace llm {

class ET_EXPERIMENTAL TextTokenGenerator {
 public:
  TextTokenGenerator(
      Tokenizer* tokenizer,
      TextDecoderRunner* text_decoder_runner,
      bool use_kv_cache,
      std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
      Stats* stats)
      : tokenizer_(tokenizer),
        text_decoder_runner_(text_decoder_runner),
        eos_ids_(std::move(eos_ids)),
        use_kv_cache_(use_kv_cache),
        stats_(stats) {}

  /**
   * Token generation loop.
   * @param tokens prompt tokens as well as the first token generated by
   * prefill.
   * @param start_pos the start position of the new tokens, based on how many
   * prompt tokens are prefilled.
   * @param seq_len the total sequence length, including the prompt tokens, the
   * next token from prefill, and the new tokens.
   * @param token_callback what to do after a token is generated.
   * @return how many tokens are generated.
   */
  inline ::executorch::runtime::Result<int64_t> generate(
      std::vector<uint64_t> tokens,
      int64_t start_pos,
      int32_t seq_len,
      std::function<void(const std::string&)> token_callback) {
    ET_CHECK_MSG(
        !tokens.empty(), "Token generation loop shouldn't take empty tokens");
    int64_t pos = start_pos; // position in the sequence

    std::vector<uint64_t> token_data; // allocate space for the tokens
    std::vector<executorch::aten::SizesType> token_shape;

    // Token after prefill
    uint64_t cur_token = tokens.back();
    uint64_t prev_token = 0; // always overwritten before its first use below

    if (use_kv_cache_) {
      // Hard-code these to size 1, as the KV cache is locked to a static size
      // right now.
      token_data = {cur_token};
      token_shape = {1, 1};
    } else {
      token_data = tokens;
      token_shape = {1, static_cast<int>(tokens.size())};
    }

    // Initialize tensor wrappers over the locally owned buffers.
    auto tokens_managed = from_blob(
        token_data.data(), token_shape, executorch::aten::ScalarType::Long);
    auto start_pos_managed =
        from_blob(&pos, {1}, executorch::aten::ScalarType::Long);

    should_stop_ = false;

    // Generate our tokens.
    while (pos < seq_len - 1) {
      // Run the model.
      auto logits_res =
          text_decoder_runner_->step(tokens_managed, start_pos_managed);

      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
      executorch::aten::Tensor& logits_tensor = logits_res.get();

      prev_token = cur_token;

      stats_->on_sampling_begin();
      cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
      stats_->on_sampling_end();

      pos++;

      if (use_kv_cache_) {
        // Update the token tensor in place; token_data is never empty here.
        // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
        token_data[0] = cur_token;
      } else {
        // Without a KV cache, append the new token and grow the tensor.
        token_data.push_back(cur_token);
        ET_CHECK_OK_OR_RETURN_ERROR(resize_tensor_ptr(
            tokens_managed, {1, static_cast<int>(token_data.size())}));
      }

      // Decode the new token into a string piece and hand it to the callback.
      token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));

      if (should_stop_) {
        break;
      }

      // Data-dependent terminating condition: the current token is an EOS id.
      if (eos_ids_->find(cur_token) != eos_ids_->end()) {
        printf("\n");
        ET_LOG(Info, "\nReached the end of generation");
        break;
      }
    }
    return pos - start_pos;
  }

  /**
   * Stop the generation loop.
   */
  inline void stop() {
    should_stop_ = true;
  }

 private:
  Tokenizer* tokenizer_;
  TextDecoderRunner* text_decoder_runner_;
  std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
  bool use_kv_cache_;

  // state machine
  bool should_stop_ = false;

  // stats
  Stats* stats_;
};

} // namespace llm
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::TextTokenGenerator;
} // namespace executor
} // namespace torch
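
// A minimal usage sketch of the API above. This is an illustrative example,
// not part of the original header: `tokenizer`, `decoder_runner`, and `stats`
// are assumed to be already-constructed objects that outlive the generator,
// and the prefill step that produces `tokens` is elided.
//
//   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
//       std::unordered_set<uint64_t>{tokenizer->eos_tok()});
//   TextTokenGenerator generator(
//       tokenizer,
//       decoder_runner,
//       /*use_kv_cache=*/true,
//       std::move(eos_ids),
//       stats);
//
//   // `tokens` holds the prompt tokens plus the first token from prefill,
//   // so the new tokens start at position tokens.size() - 1.
//   auto result = generator.generate(
//       tokens,
//       /*start_pos=*/static_cast<int64_t>(tokens.size()) - 1,
//       /*seq_len=*/128,
//       [](const std::string& piece) { fputs(piece.c_str(), stdout); });
//   if (result.ok()) {
//     int64_t num_generated = result.get(); // tokens produced by the loop
//   }
//
// `stop()` can be called from another thread (or from the token callback) to
// end the loop after the current token is emitted.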