1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 // A simple multimodal LLM runner that includes preprocessing and post 10 // processing logic. The module takes in a string as input and emits a string as 11 // output. 12 13 #pragma once 14 15 #include <cstdint> 16 #include <functional> 17 #include <memory> 18 #include <string> 19 #include <type_traits> 20 #include <unordered_map> 21 22 #include <executorch/extension/llm/runner/image.h> 23 #include <executorch/extension/llm/runner/image_prefiller.h> 24 #include <executorch/extension/llm/runner/stats.h> 25 #include <executorch/extension/llm/runner/text_decoder_runner.h> 26 #include <executorch/extension/llm/runner/text_prefiller.h> 27 #include <executorch/extension/llm/runner/text_token_generator.h> 28 #include <executorch/extension/llm/sampler/sampler.h> 29 #include <executorch/extension/llm/tokenizer/tokenizer.h> 30 #include <executorch/extension/module/module.h> 31 32 namespace executorch { 33 namespace extension { 34 namespace llm { 35 36 class ET_EXPERIMENTAL MultimodalRunner { 37 public: 38 explicit MultimodalRunner( 39 const std::string& model_path, 40 const std::string& tokenizer_path, 41 const float temperature = 0.8f) temperature_(temperature)42 : temperature_(temperature), 43 module_(std::make_unique<Module>(model_path, Module::LoadMode::File)), 44 tokenizer_path_(tokenizer_path) { 45 ET_LOG( 46 Info, 47 "Creating Multimodal LLM runner: model_path=%s, tokenizer_path=%s", 48 model_path.c_str(), 49 tokenizer_path.c_str()); 50 } 51 52 virtual bool is_loaded() = 0; 53 virtual ::executorch::runtime::Error load() = 0; 54 virtual ::executorch::runtime::Error generate( 55 std::vector<Image> images, 56 const std::string& prompt, 57 int32_t seq_len = 1024, 58 std::function<void(const std::string&)> token_callback = {}, 59 std::function<void(const Stats&)> stats_callback = {}, 60 bool echo = true) = 0; 61 62 /** 63 * Prefill an LLaVA Module with the given images input. 64 * @param images The image input to LLaVA. 65 * @param start_pos The starting position in KV cache of the input in the LLM. 66 * It's passed as reference and will be updated inside this function. 67 * @return The error status of prefilling images. 68 */ 69 virtual runtime::Error prefill_images( 70 std::vector<Image>& images, 71 int64_t& start_pos) = 0; 72 73 /** 74 * Prefill an LLaVA Module with the given text input. 75 * @param prompt The text prompt to LLaVA. 76 * @param start_pos The starting position in KV cache of the input in the LLM. 77 * It's passed as reference and will be updated inside this function. 78 * @param bos The number of BOS (begin of sequence) token. 79 * @param eos The number of EOS (end of sequence) token. 80 * @return The generated token of the LLaVA Module after prefill prompt. 81 */ 82 virtual runtime::Result<uint64_t> prefill_prompt( 83 const std::string& prompt, 84 int64_t& start_pos, 85 int8_t bos = 0, 86 int8_t eos = 0) = 0; 87 88 /** 89 * Generate tokens from the given prompt, starting from the given position. 90 * @param prompt The text prompt to LLaVA. 91 * @param seq_len The total sequence length, including the prompt tokens and 92 * new tokens. 93 * @param start_pos The starting position in KV cache of the input in the LLM. 94 * @param token_callback What to do after a token is generated. 95 * @param stats_callback What to do with Stats. 96 * @param echo Whether to echo the input prompt or not. 97 * @return The error code. 98 */ 99 virtual runtime::Error generate_from_pos( 100 const std::string& prompt, 101 int32_t seq_len = 1024, 102 int64_t start_pos = 0, 103 std::function<void(const std::string&)> token_callback = {}, 104 std::function<void(const ::executorch::extension::llm::Stats&)> 105 stats_callback = {}, 106 bool echo = true) = 0; 107 stop()108 inline void stop() { 109 text_token_generator_->stop(); 110 } 111 112 virtual ~MultimodalRunner() = default; 113 114 protected: 115 // metadata 116 int32_t vocab_size_; 117 int32_t bos_id_; 118 int32_t eos_id_; 119 int32_t n_bos_; 120 int32_t n_eos_; 121 int32_t max_seq_len_; 122 float temperature_; 123 124 // model 125 std::unordered_set<std::string> model_methods_; 126 std::unique_ptr<Module> module_; 127 std::unique_ptr<TextDecoderRunner> text_decoder_runner_; 128 std::unique_ptr<TextPrefiller> text_prefiller_; 129 std::unique_ptr<ImagePrefiller> image_prefiller_; 130 std::unique_ptr<TextTokenGenerator> text_token_generator_; 131 std::string tokenizer_path_; 132 std::unique_ptr<Tokenizer> tokenizer_; 133 134 // stats 135 Stats stats_; 136 }; 137 138 } // namespace llm 139 } // namespace extension 140 } // namespace executorch 141 142 namespace torch { 143 namespace executor { 144 // TODO(T197294990): Remove these deprecated aliases once all users have moved 145 // to the new `::executorch` namespaces. 146 using ::executorch::extension::llm::MultimodalRunner; 147 } // namespace executor 148 } // namespace torch 149