/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple multimodal LLM runner that includes preprocessing and
// post-processing logic. The module takes a text prompt and a list of images
// as input and emits generated text as output.

#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_set>
#include <vector>

#include <executorch/extension/llm/runner/image.h>
#include <executorch/extension/llm/runner/image_prefiller.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>

namespace executorch {
namespace extension {
namespace llm {

class ET_EXPERIMENTAL MultimodalRunner {
 public:
  explicit MultimodalRunner(
      const std::string& model_path,
      const std::string& tokenizer_path,
      const float temperature = 0.8f)
      : temperature_(temperature),
        module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
        tokenizer_path_(tokenizer_path) {
    ET_LOG(
        Info,
        "Creating Multimodal LLM runner: model_path=%s, tokenizer_path=%s",
        model_path.c_str(),
        tokenizer_path.c_str());
  }

  virtual bool is_loaded() = 0;
  virtual ::executorch::runtime::Error load() = 0;
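
  /**
   * Generate a text response given a list of images and a text prompt.
   * A minimal usage sketch, assuming a concrete subclass of MultimodalRunner
   * (e.g., a LLaVA runner); the "MyRunner" name, the image dimensions, and
   * the preprocessing helper are illustrative, not part of this API:
   *
   * @code
   *   MyRunner runner("model.pte", "tokenizer.bin");
   *   if (runner.load() != ::executorch::runtime::Error::Ok) {
   *     return;
   *   }
   *   Image img;
   *   img.data = get_preprocessed_rgb_bytes();  // hypothetical helper
   *   img.width = 336;
   *   img.height = 336;
   *   img.channels = 3;
   *   std::vector<Image> images = {std::move(img)};
   *   runner.generate(
   *       std::move(images),
   *       "What is in this image?",
   *       128,  // seq_len
   *       [](const std::string& piece) { std::cout << piece << std::flush; });
   * @endcode
   *
   * @param images The image inputs to the model.
   * @param prompt The text prompt.
   * @param seq_len The total sequence length, including the prompt tokens and
   * new tokens.
   * @param token_callback Invoked with each newly decoded token string.
   * @param stats_callback Invoked with timing Stats when generation finishes.
   * @param echo Whether to echo the input prompt in the token callback or not.
   * @return The error status of generation.
   */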
  virtual ::executorch::runtime::Error generate(
      std::vector<Image> images,
      const std::string& prompt,
      int32_t seq_len = 1024,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const Stats&)> stats_callback = {},
      bool echo = true) = 0;

  /**
   * Prefill an LLaVA Module with the given image inputs.
   * @param images The image inputs to LLaVA.
   * @param start_pos The starting position in KV cache of the input in the LLM.
   * It's passed as reference and will be updated inside this function.
   * @return The error status of prefilling images.
   */
  virtual runtime::Error prefill_images(
      std::vector<Image>& images,
      int64_t& start_pos) = 0;

  /**
   * Prefill an LLaVA Module with the given text input.
   * @param prompt The text prompt to LLaVA.
   * @param start_pos The starting position in KV cache of the input in the LLM.
   * It's passed as reference and will be updated inside this function.
   * @param bos The number of BOS (begin of sequence) tokens to prepend.
   * @param eos The number of EOS (end of sequence) tokens to append.
   * @return The generated token of the LLaVA Module after prefilling the
   * prompt.
   */
  virtual runtime::Result<uint64_t> prefill_prompt(
      const std::string& prompt,
      int64_t& start_pos,
      int8_t bos = 0,
      int8_t eos = 0) = 0;

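  // The prefill methods above and generate_from_pos() below can be sequenced
  // manually when the prompt is split around the image. A minimal sketch,
  // assuming "prefix" holds the text before the image and "user_prompt" the
  // text after it (both names are illustrative, not part of this API):
  //
  //   int64_t start_pos = 0;
  //   runner.prefill_prompt(prefix, start_pos, /*bos=*/1);
  //   runner.prefill_images(images, start_pos);
  //   runner.generate_from_pos(user_prompt, 1024, start_pos, token_callback);
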
  /**
   * Generate tokens from the given prompt, starting from the given position.
   * @param prompt The text prompt to LLaVA.
   * @param seq_len The total sequence length, including the prompt tokens and
   * new tokens.
   * @param start_pos The starting position in KV cache of the input in the LLM.
   * @param token_callback What to do after a token is generated.
   * @param stats_callback What to do with Stats.
   * @param echo Whether to echo the input prompt or not.
   * @return The error code.
   */
  virtual runtime::Error generate_from_pos(
      const std::string& prompt,
      int32_t seq_len = 1024,
      int64_t start_pos = 0,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback = {},
      bool echo = true) = 0;

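  /**
   * Request that token generation stop as soon as possible. A sketch of
   * early stopping from inside a token callback (the stop string is a
   * hypothetical example):
   *
   * @code
   *   runner.generate(images, prompt, 1024, [&](const std::string& piece) {
   *     if (piece == "<|eot_id|>") {
   *       runner.stop();
   *     }
   *   });
   * @endcode
   */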
  inline void stop() {
    text_token_generator_->stop();
  }

  virtual ~MultimodalRunner() = default;

 protected:
  // metadata
  int32_t vocab_size_;
  int32_t bos_id_;
  int32_t eos_id_;
  int32_t n_bos_;
  int32_t n_eos_;
  int32_t max_seq_len_;
  float temperature_;

  // model
  std::unordered_set<std::string> model_methods_;
  std::unique_ptr<Module> module_;
  std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
  std::unique_ptr<TextPrefiller> text_prefiller_;
  std::unique_ptr<ImagePrefiller> image_prefiller_;
  std::unique_ptr<TextTokenGenerator> text_token_generator_;
  std::string tokenizer_path_;
  std::unique_ptr<Tokenizer> tokenizer_;

  // stats
  Stats stats_;
};

} // namespace llm
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::MultimodalRunner;
} // namespace executor
} // namespace torch