1*523fa7a6SAndroid Build Coastguard Worker /* 2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates. 3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved. 4*523fa7a6SAndroid Build Coastguard Worker * 5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the 6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree. 7*523fa7a6SAndroid Build Coastguard Worker */ 8*523fa7a6SAndroid Build Coastguard Worker 9*523fa7a6SAndroid Build Coastguard Worker // A simple phi-3-mini runner that includes preprocessing and post processing 10*523fa7a6SAndroid Build Coastguard Worker // logic. The module takes in a string as input and emits a string as output. 11*523fa7a6SAndroid Build Coastguard Worker 12*523fa7a6SAndroid Build Coastguard Worker #pragma once 13*523fa7a6SAndroid Build Coastguard Worker 14*523fa7a6SAndroid Build Coastguard Worker #include <memory> 15*523fa7a6SAndroid Build Coastguard Worker #include <string> 16*523fa7a6SAndroid Build Coastguard Worker 17*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/sampler/sampler.h> 18*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/llm/tokenizer/tokenizer.h> 19*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/module/module.h> 20*523fa7a6SAndroid Build Coastguard Worker #include <executorch/runtime/core/exec_aten/exec_aten.h> 21*523fa7a6SAndroid Build Coastguard Worker 22*523fa7a6SAndroid Build Coastguard Worker namespace example { 23*523fa7a6SAndroid Build Coastguard Worker 24*523fa7a6SAndroid Build Coastguard Worker class Runner { 25*523fa7a6SAndroid Build Coastguard Worker public: 26*523fa7a6SAndroid Build Coastguard Worker explicit Runner( 27*523fa7a6SAndroid Build Coastguard Worker const std::string& model_path, 28*523fa7a6SAndroid Build Coastguard Worker const std::string& tokenizer_path, 29*523fa7a6SAndroid Build Coastguard Worker const float temperature = 0.8f); 30*523fa7a6SAndroid Build Coastguard Worker 31*523fa7a6SAndroid Build Coastguard Worker /** 32*523fa7a6SAndroid Build Coastguard Worker * Generates response for a given prompt. 33*523fa7a6SAndroid Build Coastguard Worker * 34*523fa7a6SAndroid Build Coastguard Worker * @param[in] prompt The prompt to generate a response for. 35*523fa7a6SAndroid Build Coastguard Worker * @param[in] max_seq_len The maximum length of the sequence to generate, 36*523fa7a6SAndroid Build Coastguard Worker * including prompt. 37*523fa7a6SAndroid Build Coastguard Worker */ 38*523fa7a6SAndroid Build Coastguard Worker void generate(const std::string& prompt, std::size_t max_seq_len); 39*523fa7a6SAndroid Build Coastguard Worker 40*523fa7a6SAndroid Build Coastguard Worker private: 41*523fa7a6SAndroid Build Coastguard Worker uint64_t logits_to_token(const exec_aten::Tensor& logits_tensor); 42*523fa7a6SAndroid Build Coastguard Worker uint64_t prefill(std::vector<uint64_t>& tokens); 43*523fa7a6SAndroid Build Coastguard Worker uint64_t run_model_step(uint64_t token); 44*523fa7a6SAndroid Build Coastguard Worker 45*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<executorch::extension::Module> module_; 46*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<executorch::extension::llm::Tokenizer> tokenizer_; 47*523fa7a6SAndroid Build Coastguard Worker std::unique_ptr<executorch::extension::llm::Sampler> sampler_; 48*523fa7a6SAndroid Build Coastguard Worker }; 49*523fa7a6SAndroid Build Coastguard Worker 50*523fa7a6SAndroid Build Coastguard Worker } // namespace example 51