/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple LLaVA runner that includes preprocessing and post-processing logic.
// The runner takes in a prompt string as well as a list of images as input and
// emits a string as output.

#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/examples/models/llava/runner/llava_text_decoder_runner.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>

#include <cstdio>
#include <ctime>
#include <memory>
#include <sstream>
#include <unordered_set>
#include <vector>

namespace llm = ::executorch::extension::llm;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace example {

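// The runner counts as loaded only when every sub-component has been
// instantiated and both the decoder and image-prefiller methods have been
// loaded into the underlying Module.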
bool LlavaRunner::is_loaded() {
  bool instantiated = tokenizer_ && text_decoder_runner_ && text_prefiller_ &&
      image_prefiller_ && text_token_generator_;
  if (!instantiated) {
    return false;
  }
  return text_decoder_runner_->is_method_loaded() &&
      image_prefiller_->is_method_loaded();
}

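// Instantiates the tokenizer, decoder runner, text and image prefillers, and
// the token generator, recording the model load interval in stats_. Safe to
// call repeatedly; returns early if everything is already loaded.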
Error LlavaRunner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  stats_.model_load_start_ms = llm::time_in_ms();

  // Load the tokenizer.
  tokenizer_ = std::make_unique<llm::BPETokenizer>();
  ET_CHECK_OK_OR_RETURN_ERROR(tokenizer_->load(tokenizer_path_));

  // Load the text decoder runner.
  text_decoder_runner_ = std::make_unique<LlavaTextDecoderRunner>(
      module_.get(), tokenizer_->vocab_size(), temperature_);
  ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());

  // Load the text prefiller.
  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
      text_decoder_runner_.get(),
      /*use_kv_cache=*/true,
      /*enable_parallel_prefill=*/true);

  // Load the image prefiller.
  image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());
  ET_CHECK_OK_OR_RETURN_ERROR(image_prefiller_->load());

  // Load the text token generator; generation stops at the tokenizer's EOS.
  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
      tokenizer_.get(),
      text_decoder_runner_.get(),
      /*use_kv_cache=*/true,
      std::make_unique<std::unordered_set<uint64_t>>(
          std::unordered_set<uint64_t>{tokenizer_->eos_tok()}),
      &stats_);

  stats_.model_load_end_ms = llm::time_in_ms();
  return Error::Ok;
}

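// Prefills each image through the image encoder. start_pos is passed by
// reference and advanced by the number of image tokens written to the KV
// cache, so the following text prefill resumes at the correct position.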
Error LlavaRunner::prefill_images(
    std::vector<llm::Image>& images,
    int64_t& start_pos) {
  for (auto& image : images) {
    // start_pos is updated inside image prefill.
    ET_UNWRAP(image_prefiller_->prefill(image, start_pos));
  }
  return Error::Ok;
}

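// Tokenizes a text prompt and prefills it starting at start_pos. The bos and
// eos arguments select whether BOS/EOS tokens are added around the encoded
// prompt. Returns the token the model predicts after the prompt; start_pos
// is advanced by the number of prompt tokens consumed.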
Result<uint64_t> LlavaRunner::prefill_prompt(
    const std::string& prompt,
    int64_t& start_pos,
    int8_t bos,
    int8_t eos) {
  std::vector<uint64_t> prompt_tokens =
      ET_UNWRAP(tokenizer_->encode(prompt, bos, eos));

  return text_prefiller_->prefill(prompt_tokens, start_pos);
}

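// Prefills the user prompt at start_pos (which already accounts for the
// preset prompt and image tokens), then decodes up to seq_len tokens,
// streaming each decoded piece through token_callback.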
Error LlavaRunner::generate_from_pos(
    const std::string& prompt,
    int32_t seq_len,
    int64_t start_pos,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const ::executorch::extension::llm::Stats&)>
        stats_callback,
    bool echo) {
  // Prefill the user prompt. No BOS because the preset prompt already has it.
  if (echo) {
    token_callback(prompt);
  }

  uint64_t prefill_next_token =
      ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos=*/0));
  stats_.first_token_ms = llm::time_in_ms();
  stats_.prompt_eval_end_ms = llm::time_in_ms();
  stats_.num_prompt_tokens = start_pos;

  // Generate tokens, seeded with the token predicted during prefill.
  int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
      {prefill_next_token}, start_pos, seq_len, token_callback));

  // Bookkeeping.
  stats_.num_generated_tokens = num_generated_tokens;
  if (stats_callback) {
    stats_callback(stats_);
  }
  return Error::Ok;
}

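// End-to-end generation: load if needed, prefill the preset system prompt
// (with BOS), prefill the images, then hand off to generate_from_pos for the
// user prompt and token-by-token decoding.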
Error LlavaRunner::generate(
    std::vector<llm::Image> images,
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const llm::Stats&)> stats_callback,
    bool echo) {
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
  if (!is_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(load());
  }

  ET_LOG(
      Info,
      "RSS after loading model: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Wrap token_callback so each piece is also printed to stdout.
  std::function<void(const std::string&)> wrapped_callback =
      [token_callback](const std::string& piece) {
        llm::safe_printf(piece.c_str());
        fflush(stdout);
        if (token_callback) {
          token_callback(piece);
        }
      };

  int64_t pos = 0;
  stats_.inference_start_ms = llm::time_in_ms();

  // Prefill the preset prompt (adds BOS) and propagate any error; the
  // predicted token is discarded, only the advanced position matters here.
  ET_UNWRAP(prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos=*/0));

  // Prefill the images and propagate any error.
  ET_CHECK_OK_OR_RETURN_ERROR(prefill_images(images, pos));

  ET_LOG(
      Info,
      "RSS after prompt and image prefill: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  // Generate tokens
  Error err = generate_from_pos(
      prompt, seq_len, pos, wrapped_callback, stats_callback, echo);

  stats_.inference_end_ms = llm::time_in_ms();
  ::executorch::llm::print_report(stats_);

  ET_LOG(
      Info,
      "RSS after finishing text generation: %f MiB (0 if unsupported)",
      llm::get_rss_bytes() / 1024.0 / 1024.0);

  return err;
}

} // namespace example
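
// A minimal usage sketch, assuming the constructor declared in
// llava_runner.h takes a .pte model path, a tokenizer path, and a sampling
// temperature (check the header for the exact signature), and that the
// images have already been preprocessed into llm::Image objects:
//
//   example::LlavaRunner runner(
//       "/path/to/llava.pte", "/path/to/tokenizer.bin", /*temperature=*/0.8f);
//   std::vector<llm::Image> images = /* preprocessed image(s) */;
//   runner.generate(
//       std::move(images),
//       "What is in this image?",
//       /*seq_len=*/768,
//       /*token_callback=*/nullptr,  // pieces are printed to stdout regardless
//       /*stats_callback=*/nullptr,
//       /*echo=*/true);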