/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple LLaVA runner that includes preprocessing and post processing logic.
// The runner takes in a prompt string as well as a list of images as input and
// emits a string as output.

#include <executorch/examples/models/llava/runner/llava_image_prefiller.h>
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <executorch/examples/models/llava/runner/llava_text_decoder_runner.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>

#include <ctime>
#include <memory>
#include <sstream>
#include <vector>

namespace llm = ::executorch::extension::llm;
using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace example {
28
is_loaded()29 bool LlavaRunner::is_loaded() {
30 bool instantiated = tokenizer_ && text_decoder_runner_ && text_prefiller_ &&
31 image_prefiller_ && text_token_generator_;
32 if (!instantiated) {
33 return false;
34 }
35 return text_decoder_runner_->is_method_loaded() &&
36 image_prefiller_->is_method_loaded();
37 }
38
load()39 Error LlavaRunner::load() {
40 if (is_loaded()) {
41 return Error::Ok;
42 }
43 stats_.model_load_start_ms = llm::time_in_ms();
44
45 // Load the tokenizer
46 tokenizer_ = std::make_unique<llm::BPETokenizer>();
47 tokenizer_->load(tokenizer_path_);
48
49 // Load the text decoder runner
50 text_decoder_runner_ = std::make_unique<LlavaTextDecoderRunner>(
51 module_.get(), tokenizer_->vocab_size(), temperature_);
52 text_decoder_runner_->load();
53
54 // Load the text prefiller
55 text_prefiller_ = std::make_unique<llm::TextPrefiller>(
56 text_decoder_runner_.get(),
57 /*use_kv_cache=*/true,
58 /*enable_parallel_prefill=*/true);
59
60 // Load the image prefiller
61 image_prefiller_ = std::make_unique<LlavaImagePrefiller>(module_.get());
62 image_prefiller_->load();
63
64 // Load the text token generator
65 text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
66 tokenizer_.get(),
67 text_decoder_runner_.get(),
68 /*use_kv_cache=*/true,
69 std::make_unique<std::unordered_set<uint64_t>>(
70 std::unordered_set<uint64_t>{tokenizer_->eos_tok()}),
71 &stats_);
72
73 stats_.model_load_end_ms = llm::time_in_ms();
74 return Error::Ok;
75 }
76
prefill_images(std::vector<llm::Image> & images,int64_t & start_pos)77 Error LlavaRunner::prefill_images(
78 std::vector<llm::Image>& images,
79 int64_t& start_pos) {
80 for (auto& image : images) {
81 // pos is updated inside image prefill.
82 ET_UNWRAP(image_prefiller_->prefill(image, start_pos));
83 }
84 return Error::Ok;
85 }
86
prefill_prompt(const std::string & prompt,int64_t & start_pos,int8_t bos,int8_t eos)87 Result<uint64_t> LlavaRunner::prefill_prompt(
88 const std::string& prompt,
89 int64_t& start_pos,
90 int8_t bos,
91 int8_t eos) {
92 std::vector<uint64_t> prompt_tokens =
93 ET_UNWRAP(tokenizer_->encode(prompt, bos, eos));
94
95 return text_prefiller_->prefill(prompt_tokens, start_pos);
96 }
97
generate_from_pos(const std::string & prompt,int32_t seq_len,int64_t start_pos,std::function<void (const std::string &)> token_callback,std::function<void (const::executorch::extension::llm::Stats &)> stats_callback,bool echo)98 Error LlavaRunner::generate_from_pos(
99 const std::string& prompt,
100 int32_t seq_len,
101 int64_t start_pos,
102 std::function<void(const std::string&)> token_callback,
103 std::function<void(const ::executorch::extension::llm::Stats&)>
104 stats_callback,
105 bool echo) {
106 // prefill user prompt. No BOS because preset prompt already has it.
107 if (echo) {
108 token_callback(prompt);
109 }
110
111 uint64_t prefill_next_token =
112 ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0));
113 stats_.first_token_ms = llm::time_in_ms();
114 stats_.prompt_eval_end_ms = llm::time_in_ms();
115 stats_.num_prompt_tokens = start_pos;
116
117 // Generate tokens
118 int64_t num_generated_tokens = ET_UNWRAP(text_token_generator_->generate(
119 {prefill_next_token}, start_pos, seq_len, token_callback));
120
121 // Bookkeeping
122 stats_.num_generated_tokens = num_generated_tokens;
123 if (stats_callback) {
124 stats_callback(stats_);
125 }
126 return Error::Ok;
127 }
128
generate(std::vector<llm::Image> images,const std::string & prompt,int32_t seq_len,std::function<void (const std::string &)> token_callback,std::function<void (const llm::Stats &)> stats_callback,bool echo)129 Error LlavaRunner::generate(
130 std::vector<llm::Image> images,
131 const std::string& prompt,
132 int32_t seq_len,
133 std::function<void(const std::string&)> token_callback,
134 std::function<void(const llm::Stats&)> stats_callback,
135 bool echo) {
136 ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
137 if (!is_loaded()) {
138 ET_CHECK_OK_OR_RETURN_ERROR(load());
139 }
140
141 ET_LOG(
142 Info,
143 "RSS after loading model: %f MiB (0 if unsupported)",
144 llm::get_rss_bytes() / 1024.0 / 1024.0);
145
146 // Wrap the token_callback with print function
147 std::function<void(const std::string&)> wrapped_callback =
148 [token_callback](const std::string& piece) {
149 llm::safe_printf(piece.c_str());
150 fflush(stdout);
151 if (token_callback) {
152 token_callback(piece);
153 }
154 };
155
156 int64_t pos = 0;
157 stats_.inference_start_ms = llm::time_in_ms();
158
159 // prefill preset prompt
160 prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0);
161
162 // prefill images
163 prefill_images(images, pos);
164
165 ET_LOG(
166 Info,
167 "RSS after prompt and image prefill: %f MiB (0 if unsupported)",
168 llm::get_rss_bytes() / 1024.0 / 1024.0);
169
170 // Generate tokens
171 Error err = generate_from_pos(
172 prompt, seq_len, pos, wrapped_callback, stats_callback, echo);
173
174 stats_.inference_end_ms = llm::time_in_ms();
175 ::executorch::llm::print_report(stats_);
176
177 ET_LOG(
178 Info,
179 "RSS after finishing text generation: %f MiB (0 if unsupported)",
180 llm::get_rss_bytes() / 1024.0 / 1024.0);
181
182 return err;
183 }

} // namespace example