/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post-processing
// logic. The module takes a string prompt as input and emits a string as
// output.
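//
// Example usage (a minimal sketch; the model and tokenizer paths below are
// placeholders, and printing with std::cout requires <iostream>):
//
//   example::Runner runner(
//       "/path/to/llama2.pte",
//       "/path/to/tokenizer.bin",
//       /*temperature=*/0.8f);
//   if (runner.load() != ::executorch::runtime::Error::Ok) {
//     // handle the load failure
//   }
//   runner.generate(
//       "Once upon a time,",
//       /*seq_len=*/128,
//       [](const std::string& piece) { std::cout << piece << std::flush; });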

#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

#include <executorch/extension/llm/runner/irunner.h>
#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/llm/runner/text_prefiller.h>
#include <executorch/extension/llm/runner/text_token_generator.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>

namespace example {

class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
 public:
  // Constructs a runner for the exported model at `model_path` (a .pte file)
  // using the tokenizer at `tokenizer_path`; `temperature` controls sampling.
  explicit Runner(
      const std::string& model_path,
      const std::string& tokenizer_path,
      const float temperature = 0.8f);

  // Returns true once the module, tokenizer, and helper runners are loaded.
  bool is_loaded() const;
  // Loads the module, tokenizer, and the metadata needed for generation.
  ::executorch::runtime::Error load();
  // Generates text from `prompt`, up to a total of `seq_len` tokens.
  // `token_callback` receives each decoded token piece as it is produced,
  // `stats_callback` receives timing statistics once generation finishes,
  // `echo` controls whether the prompt is echoed through the token callback,
  // and `warming` marks the run as a warmup.
  ::executorch::runtime::Error generate(
      const std::string& prompt,
      int32_t seq_len = 128,
      std::function<void(const std::string&)> token_callback = {},
      std::function<void(const ::executorch::extension::llm::Stats&)>
          stats_callback = {},
      bool echo = true,
      bool warming = false);
  // Runs a single warmup generation over `prompt` to prime the model.
  ::executorch::runtime::Error warmup(
      const std::string& prompt,
      int32_t seq_len = 128);
  // Requests that an in-flight generate() call stop after the current token.
  void stop();

 private:
  float temperature_;
  bool shouldStop_{false};

  // model
  std::unique_ptr<::executorch::extension::Module> module_;
  std::string tokenizer_path_;
  std::unique_ptr<::executorch::extension::llm::Tokenizer> tokenizer_;
  std::unordered_map<std::string, int64_t> metadata_;
  std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
      text_decoder_runner_;
  std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_;
  std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
      text_token_generator_;

  // stats
  ::executorch::extension::llm::Stats stats_;
};
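
// A minimal sketch of the remaining entry points, assuming a `runner` that has
// been constructed and loaded as in the example at the top of this header:
//
//   // Prime the model before timed runs; the run is reported as a warmup.
//   runner.warmup("The capital of France is", /*seq_len=*/32);
//
//   // Collect timing statistics through the stats callback.
//   runner.generate(
//       "The capital of France is",
//       /*seq_len=*/64,
//       /*token_callback=*/{},
//       [](const ::executorch::extension::llm::Stats& stats) {
//         // inspect or log `stats` here
//       });
//
//   // stop() can be called (e.g. from another thread) to end an in-flight
//   // generate() call early.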

} // namespace example