xref: /aosp_15_r20/external/executorch/examples/models/phi-3-mini/main.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker  * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker  * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker  *
5*523fa7a6SAndroid Build Coastguard Worker  * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker  * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker  */
8*523fa7a6SAndroid Build Coastguard Worker 
#include <cstdio>

#include <gflags/gflags.h>

#include <executorch/examples/models/phi-3-mini/runner.h>
12*523fa7a6SAndroid Build Coastguard Worker 
13*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(
14*523fa7a6SAndroid Build Coastguard Worker     model_path,
15*523fa7a6SAndroid Build Coastguard Worker     "phi-3-mini.pte",
16*523fa7a6SAndroid Build Coastguard Worker     "File path for model serialized in flatbuffer format.");
17*523fa7a6SAndroid Build Coastguard Worker 
18*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(tokenizer_path, "tokenizer.bin", "File path for tokenizer.");
19*523fa7a6SAndroid Build Coastguard Worker 
20*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(prompt, "Tell me a story", "Prompt.");
21*523fa7a6SAndroid Build Coastguard Worker 
22*523fa7a6SAndroid Build Coastguard Worker DEFINE_double(
23*523fa7a6SAndroid Build Coastguard Worker     temperature,
24*523fa7a6SAndroid Build Coastguard Worker     0.8f,
25*523fa7a6SAndroid Build Coastguard Worker     "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
26*523fa7a6SAndroid Build Coastguard Worker 
27*523fa7a6SAndroid Build Coastguard Worker DEFINE_int32(
28*523fa7a6SAndroid Build Coastguard Worker     seq_len,
29*523fa7a6SAndroid Build Coastguard Worker     128,
30*523fa7a6SAndroid Build Coastguard Worker     "Total number of tokens to generate (prompt + output).");
31*523fa7a6SAndroid Build Coastguard Worker 
main(int32_t argc,char ** argv)32*523fa7a6SAndroid Build Coastguard Worker int main(int32_t argc, char** argv) {
33*523fa7a6SAndroid Build Coastguard Worker   gflags::ParseCommandLineFlags(&argc, &argv, true);
34*523fa7a6SAndroid Build Coastguard Worker 
35*523fa7a6SAndroid Build Coastguard Worker   const char* model_path = FLAGS_model_path.c_str();
36*523fa7a6SAndroid Build Coastguard Worker 
37*523fa7a6SAndroid Build Coastguard Worker   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
38*523fa7a6SAndroid Build Coastguard Worker 
39*523fa7a6SAndroid Build Coastguard Worker   const char* prompt = FLAGS_prompt.c_str();
40*523fa7a6SAndroid Build Coastguard Worker 
41*523fa7a6SAndroid Build Coastguard Worker   double temperature = FLAGS_temperature;
42*523fa7a6SAndroid Build Coastguard Worker 
43*523fa7a6SAndroid Build Coastguard Worker   int32_t seq_len = FLAGS_seq_len;
44*523fa7a6SAndroid Build Coastguard Worker 
45*523fa7a6SAndroid Build Coastguard Worker   example::Runner runner(model_path, tokenizer_path, temperature);
46*523fa7a6SAndroid Build Coastguard Worker 
47*523fa7a6SAndroid Build Coastguard Worker   runner.generate(prompt, seq_len);
48*523fa7a6SAndroid Build Coastguard Worker 
49*523fa7a6SAndroid Build Coastguard Worker   return 0;
50*523fa7a6SAndroid Build Coastguard Worker }
51