/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple multimodal LLM runner that includes pre-processing and
// post-processing logic.
#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>

#include <executorch/extension/llm/runner/multimodal_runner.h>

namespace example {

class ET_EXPERIMENTAL LlavaRunner
    : public ::executorch::extension::llm::MultimodalRunner {
 public:
  explicit LlavaRunner(
      const std::string& model_path,
      const std::string& tokenizer_path,
      const float temperature = 0.8f)
      : MultimodalRunner(model_path, tokenizer_path, temperature) {}

  /**
   * Check whether the model and tokenizer needed for generation are loaded.
   */
  bool is_loaded();

  /**
   * Load the model and tokenizer needed for generation.
   * @return The error status of loading.
   */
  ::executorch::runtime::Error load();

  /**
   * Generate a response to the given text prompt, conditioned on the given
   * images. Decoded tokens are streamed to token_callback, and timing
   * statistics are reported through stats_callback.
   * @return The error status of generation.
   */
  ::executorch::runtime::Error generate(
      std::vector<::executorch::extension::llm::Image> images,
      const std::string& prompt,
      int32_t seq_len = 1024,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback = {},
      bool echo = true);
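
  // A minimal usage sketch for generate() (illustrative only: the model
  // path, tokenizer path, image contents, and prompt below are placeholder
  // assumptions, and error handling is elided):
  //
  //   example::LlavaRunner runner("llava.pte", "tokenizer.bin");
  //   std::vector<::executorch::extension::llm::Image> images = {/* ... */};
  //   runner.generate(
  //       std::move(images),
  //       "What is on the table?",
  //       /*seq_len=*/1024,
  //       [](const std::string& piece) { /* stream each decoded piece */ });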

  /**
   * Prefill an LLaVA Module with the given image inputs.
   * @param images The image inputs to LLaVA.
   * @param start_pos The starting position of the input in the LLM's KV
   * cache. It's passed by reference and is updated inside this function.
   * @return The error status of prefilling images.
   */
  ::executorch::runtime::Error prefill_images(
      std::vector<::executorch::extension::llm::Image>& images,
      int64_t& start_pos);

  /**
   * Prefill an LLaVA Module with the given text input.
   * @param prompt The text prompt to LLaVA.
   * @param start_pos The starting position of the input in the LLM's KV
   * cache. It's passed by reference and is updated inside this function.
   * @param bos The number of BOS (begin of sequence) tokens to prepend.
   * @param eos The number of EOS (end of sequence) tokens to append.
   * @return The token generated by the LLaVA Module after prefilling the
   * prompt.
   */
  ::executorch::runtime::Result<uint64_t> prefill_prompt(
      const std::string& prompt,
      int64_t& start_pos,
      int8_t bos = 0,
      int8_t eos = 0);
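
  // Sketch of prefill_prompt() in isolation (illustrative; `runner` and the
  // prompt are placeholders). Per the contract above, the returned Result
  // holds the token the model produces right after the prefilled text, and
  // start_pos is updated in place:
  //
  //   int64_t start_pos = 0;
  //   auto next = runner.prefill_prompt("USER: hi", start_pos, /*bos=*/1);
  //   if (next.ok()) {
  //     uint64_t first_token = next.get();
  //   }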

  /**
   * Generate tokens from the given prompt, starting from the given position.
   * @param prompt The text prompt to LLaVA.
   * @param seq_len The total sequence length, including the prompt tokens and
   * new tokens.
   * @param start_pos The starting position of the input in the LLM's KV
   * cache.
   * @param token_callback What to do after a token is generated.
   * @param stats_callback What to do with Stats.
   * @param echo Whether to echo the input prompt or not.
   * @return The error status of generation.
   */
  ::executorch::runtime::Error generate_from_pos(
      const std::string& prompt,
      int32_t seq_len = 1024,
      int64_t start_pos = 0,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback = {},
      bool echo = true);
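
  // Sketch of the decomposed flow (an illustrative assumption about ordering,
  // not a prescribed sequence): prefill any leading text and the images
  // first, threading start_pos through by reference, then decode from that
  // position. `runner`, `images`, and the prompt strings are placeholders;
  // error handling is elided.
  //
  //   int64_t start_pos = 0;
  //   runner.prefill_prompt("USER: ", start_pos, /*bos=*/1);  // one BOS token
  //   runner.prefill_images(images, start_pos);
  //   runner.generate_from_pos(
  //       "What is on the table? ASSISTANT:", /*seq_len=*/1024, start_pos);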

 private:
  inline static const std::string kPresetPrompt =
      "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: ";
};

} // namespace example