1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 // A simple llama2 runner that includes preprocessing and post processing logic. 10 // The module takes in a string as input and emits a string as output. 11 12 #pragma once 13 14 #include <cstdint> 15 #include <functional> 16 #include <memory> 17 #include <string> 18 #include <unordered_map> 19 20 #include <executorch/extension/llm/runner/irunner.h> 21 #include <executorch/extension/llm/runner/stats.h> 22 #include <executorch/extension/llm/runner/text_decoder_runner.h> 23 #include <executorch/extension/llm/runner/text_prefiller.h> 24 #include <executorch/extension/llm/runner/text_token_generator.h> 25 #include <executorch/extension/llm/tokenizer/tokenizer.h> 26 #include <executorch/extension/module/module.h> 27 28 namespace example { 29 30 class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner { 31 public: 32 explicit Runner( 33 const std::string& model_path, 34 const std::string& tokenizer_path, 35 const float temperature = 0.8f); 36 37 bool is_loaded() const; 38 ::executorch::runtime::Error load(); 39 ::executorch::runtime::Error generate( 40 const std::string& prompt, 41 int32_t seq_len = 128, 42 std::function<void(const std::string&)> token_callback = {}, 43 std::function<void(const ::executorch::extension::llm::Stats&)> 44 stats_callback = {}, 45 bool echo = true, 46 bool warming = false); 47 ::executorch::runtime::Error warmup( 48 const std::string& prompt, 49 int32_t seq_len = 128); 50 void stop(); 51 52 private: 53 float temperature_; 54 bool shouldStop_{false}; 55 56 // model 57 std::unique_ptr<::executorch::extension::Module> module_; 58 std::string tokenizer_path_; 59 std::unique_ptr<::executorch::extension::llm::Tokenizer> tokenizer_; 60 std::unordered_map<std::string, int64_t> metadata_; 61 std::unique_ptr<::executorch::extension::llm::TextDecoderRunner> 62 text_decoder_runner_; 63 std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_; 64 std::unique_ptr<::executorch::extension::llm::TextTokenGenerator> 65 text_token_generator_; 66 67 // stats 68 ::executorch::extension::llm::Stats stats_; 69 }; 70 71 } // namespace example 72