xref: /aosp_15_r20/external/executorch/examples/models/phi-3-mini/runner.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
9*523fa7a6SAndroid Build Coastguard Worker // A simple phi-3-mini runner that includes preprocessing and post processing
10*523fa7a6SAndroid Build Coastguard Worker // logic. The module takes in a string as input and emits a string as output.
11*523fa7a6SAndroid Build Coastguard Worker 
12*523fa7a6SAndroid Build Coastguard Worker #pragma once
13*523fa7a6SAndroid Build Coastguard Worker 
#include <cstddef>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/llm/tokenizer/tokenizer.h>
#include <executorch/extension/module/module.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
21*523fa7a6SAndroid Build Coastguard Worker 
22*523fa7a6SAndroid Build Coastguard Worker namespace example {
23*523fa7a6SAndroid Build Coastguard Worker 
24*523fa7a6SAndroid Build Coastguard Worker class Runner {
25*523fa7a6SAndroid Build Coastguard Worker  public:
26*523fa7a6SAndroid Build Coastguard Worker   explicit Runner(
27*523fa7a6SAndroid Build Coastguard Worker       const std::string& model_path,
28*523fa7a6SAndroid Build Coastguard Worker       const std::string& tokenizer_path,
29*523fa7a6SAndroid Build Coastguard Worker       const float temperature = 0.8f);
30*523fa7a6SAndroid Build Coastguard Worker 
31*523fa7a6SAndroid Build Coastguard Worker   /**
32*523fa7a6SAndroid Build Coastguard Worker    * Generates response for a given prompt.
33*523fa7a6SAndroid Build Coastguard Worker    *
34*523fa7a6SAndroid Build Coastguard Worker    * @param[in] prompt The prompt to generate a response for.
35*523fa7a6SAndroid Build Coastguard Worker    * @param[in] max_seq_len The maximum length of the sequence to generate,
36*523fa7a6SAndroid Build Coastguard Worker    * including prompt.
37*523fa7a6SAndroid Build Coastguard Worker    */
38*523fa7a6SAndroid Build Coastguard Worker   void generate(const std::string& prompt, std::size_t max_seq_len);
39*523fa7a6SAndroid Build Coastguard Worker 
40*523fa7a6SAndroid Build Coastguard Worker  private:
41*523fa7a6SAndroid Build Coastguard Worker   uint64_t logits_to_token(const exec_aten::Tensor& logits_tensor);
42*523fa7a6SAndroid Build Coastguard Worker   uint64_t prefill(std::vector<uint64_t>& tokens);
43*523fa7a6SAndroid Build Coastguard Worker   uint64_t run_model_step(uint64_t token);
44*523fa7a6SAndroid Build Coastguard Worker 
45*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<executorch::extension::Module> module_;
46*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_;
47*523fa7a6SAndroid Build Coastguard Worker   std::unique_ptr<executorch::extension::llm::Sampler> sampler_;
48*523fa7a6SAndroid Build Coastguard Worker };
49*523fa7a6SAndroid Build Coastguard Worker 
50*523fa7a6SAndroid Build Coastguard Worker } // namespace example
51