/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Generate tokens in a loop.
#pragma once

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/tensor/tensor.h>

#include <cstdint>
#include <cstdio>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

namespace executorch {
namespace extension {
namespace llm {

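/**
 * Runs the decode loop: repeatedly steps the decoder, samples the next token
 * from the logits, and streams the decoded text to a callback until an EOS
 * token is generated, stop() is called, or the sequence reaches seq_len.
 * Holds non-owning pointers to the tokenizer, decoder runner, and stats;
 * takes ownership of the EOS id set.
 */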
class ET_EXPERIMENTAL TextTokenGenerator {
 public:
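  /**
   * @param tokenizer non-owning; decodes sampled token ids into text.
   * @param text_decoder_runner non-owning; executes one decoder step at a
   * time.
   * @param use_kv_cache if true, only the latest token is fed to the model
   * on each step.
   * @param eos_ids token ids that terminate generation; ownership is
   * transferred.
   * @param stats non-owning; records sampling begin/end times.
   */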
  TextTokenGenerator(
      Tokenizer* tokenizer,
      TextDecoderRunner* text_decoder_runner,
      bool use_kv_cache,
      std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
      Stats* stats)
      : tokenizer_(tokenizer),
        text_decoder_runner_(text_decoder_runner),
        eos_ids_(std::move(eos_ids)),
        use_kv_cache_(use_kv_cache),
        stats_(stats) {}

  /**
   * Token generation loop. An example invocation is sketched in the comment
   * after the class definition below.
   * @param tokens the prompt tokens plus the first token generated by
   * prefill.
   * @param start_pos the start position of the new tokens, i.e. how many
   * prompt tokens have been prefilled.
   * @param seq_len the total sequence length, including the prompt tokens,
   * the next token from prefill, and the newly generated tokens.
   * @param token_callback invoked with the decoded text of each generated
   * token.
   * @return the number of tokens generated.
   */
  inline ::executorch::runtime::Result<int64_t> generate(
      std::vector<uint64_t> tokens,
      int64_t start_pos,
      int32_t seq_len,
      std::function<void(const std::string&)> token_callback) {
    ET_CHECK_MSG(
        !tokens.empty(), "Token generation loop shouldn't take empty tokens");
    int64_t pos = start_pos; // position in the sequence

    std::vector<uint64_t> token_data; // allocate space for the tokens
    std::vector<executorch::aten::SizesType> token_shape;

    // Token after prefill
    uint64_t cur_token = tokens.back();
    uint64_t prev_token;

    if (use_kv_cache_) {
      // With a KV cache, only the latest token is fed to the model each
      // step, so the input shape is hard-coded to {1, 1} (the cache is
      // currently locked to a static size).
      token_data = {cur_token};
      token_shape = {1, 1};
    } else {
      // Without a KV cache, the full sequence so far is re-fed every step.
      token_data = tokens;
      token_shape = {1, static_cast<int>(tokens.size())};
    }

    // Initialize tensor wrappers. Note that start_pos_managed wraps `pos`
    // directly, so incrementing `pos` below also advances the position input
    // passed to the model.
    auto tokens_managed = from_blob(
        token_data.data(), token_shape, executorch::aten::ScalarType::Long);
    auto start_pos_managed =
        from_blob(&pos, {1}, executorch::aten::ScalarType::Long);

    should_stop_ = false;

    // Generate our tokens
    while (pos < seq_len - 1) {
      // Run the model
      auto logits_res =
          text_decoder_runner_->step(tokens_managed, start_pos_managed);

      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
      executorch::aten::Tensor& logits_tensor = logits_res.get();

      prev_token = cur_token;

      stats_->on_sampling_begin();
      cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
      stats_->on_sampling_end();

      pos++;

      if (use_kv_cache_) {
        // update the token tensor. token_data will not be empty.
        // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
        token_data[0] = cur_token;
      } else {
        // append the new token and grow the input tensor accordingly
        token_data.push_back(cur_token);
        ET_CHECK_OK_OR_RETURN_ERROR(resize_tensor_ptr(
            tokens_managed, {1, static_cast<int>(token_data.size())}));
      }

      // Decode the new token into text and hand it to the callback. decode()
      // also takes the previous token, which the tokenizer may need as
      // context (e.g. to reconstruct leading whitespace).
      token_callback(ET_UNWRAP(tokenizer_->decode(prev_token, cur_token)));

      if (should_stop_) {
        break;
      }

      // Data-dependent terminating condition: stop as soon as an EOS token
      // is generated.
      if (eos_ids_->find(cur_token) != eos_ids_->end()) {
        printf("\n");
        ET_LOG(Info, "\nReached the end of generation");
        break;
      }
    }
    return pos - start_pos;
  }

  /**
   * Stop the generation loop.
   */
  inline void stop() {
    should_stop_ = true;
  }

 private:
  Tokenizer* tokenizer_;
  TextDecoderRunner* text_decoder_runner_;
  std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
  bool use_kv_cache_;

  // state machine
  bool should_stop_ = false;

  // stats
  Stats* stats_;
};
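
// Example usage (a minimal sketch, not taken from this file). It assumes
// caller-owned `tokenizer`, `decoder_runner`, and `stats` objects, a
// `prompt_tokens` vector already prefilled through the model, and a
// `first_token` sampled from the prefill logits; the EOS id 2 and seq_len
// 128 are illustrative placeholders.
//
//   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
//       std::unordered_set<uint64_t>{2});
//   TextTokenGenerator generator(
//       &tokenizer,
//       &decoder_runner,
//       /*use_kv_cache=*/true,
//       std::move(eos_ids),
//       &stats);
//
//   std::vector<uint64_t> tokens = prompt_tokens;
//   tokens.push_back(first_token);
//   // generate() returns Result<int64_t>: the number of new tokens.
//   auto num_generated = generator.generate(
//       tokens,
//       /*start_pos=*/static_cast<int64_t>(prompt_tokens.size()),
//       /*seq_len=*/128,
//       [](const std::string& piece) { fputs(piece.c_str(), stdout); });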

} // namespace llm
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::TextTokenGenerator;
} // namespace executor
} // namespace torch