1*523fa7a6SAndroid Build Coastguard Worker /*
2*523fa7a6SAndroid Build Coastguard Worker * Copyright (c) Meta Platforms, Inc. and affiliates.
3*523fa7a6SAndroid Build Coastguard Worker * All rights reserved.
4*523fa7a6SAndroid Build Coastguard Worker *
5*523fa7a6SAndroid Build Coastguard Worker * This source code is licensed under the BSD-style license found in the
6*523fa7a6SAndroid Build Coastguard Worker * LICENSE file in the root directory of this source tree.
7*523fa7a6SAndroid Build Coastguard Worker */
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker #include <gflags/gflags.h>
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Worker #include <executorch/examples/models/llama/runner/runner.h>
12*523fa7a6SAndroid Build Coastguard Worker
13*523fa7a6SAndroid Build Coastguard Worker #if defined(ET_USE_THREADPOOL)
14*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/threadpool/cpuinfo_utils.h>
15*523fa7a6SAndroid Build Coastguard Worker #include <executorch/extension/threadpool/threadpool.h>
16*523fa7a6SAndroid Build Coastguard Worker #endif
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(
19*523fa7a6SAndroid Build Coastguard Worker model_path,
20*523fa7a6SAndroid Build Coastguard Worker "llama2.pte",
21*523fa7a6SAndroid Build Coastguard Worker "Model serialized in flatbuffer format.");
22*523fa7a6SAndroid Build Coastguard Worker
23*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
26*523fa7a6SAndroid Build Coastguard Worker
27*523fa7a6SAndroid Build Coastguard Worker DEFINE_double(
28*523fa7a6SAndroid Build Coastguard Worker temperature,
29*523fa7a6SAndroid Build Coastguard Worker 0.8f,
30*523fa7a6SAndroid Build Coastguard Worker "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
31*523fa7a6SAndroid Build Coastguard Worker
32*523fa7a6SAndroid Build Coastguard Worker DEFINE_int32(
33*523fa7a6SAndroid Build Coastguard Worker seq_len,
34*523fa7a6SAndroid Build Coastguard Worker 128,
35*523fa7a6SAndroid Build Coastguard Worker "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
36*523fa7a6SAndroid Build Coastguard Worker
37*523fa7a6SAndroid Build Coastguard Worker DEFINE_int32(
38*523fa7a6SAndroid Build Coastguard Worker cpu_threads,
39*523fa7a6SAndroid Build Coastguard Worker -1,
40*523fa7a6SAndroid Build Coastguard Worker "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
41*523fa7a6SAndroid Build Coastguard Worker
42*523fa7a6SAndroid Build Coastguard Worker DEFINE_bool(warmup, false, "Whether to run a warmup run.");
43*523fa7a6SAndroid Build Coastguard Worker
main(int32_t argc,char ** argv)44*523fa7a6SAndroid Build Coastguard Worker int32_t main(int32_t argc, char** argv) {
45*523fa7a6SAndroid Build Coastguard Worker gflags::ParseCommandLineFlags(&argc, &argv, true);
46*523fa7a6SAndroid Build Coastguard Worker
47*523fa7a6SAndroid Build Coastguard Worker // Create a loader to get the data of the program file. There are other
48*523fa7a6SAndroid Build Coastguard Worker // DataLoaders that use mmap() or point32_t to data that's already in memory,
49*523fa7a6SAndroid Build Coastguard Worker // and users can create their own DataLoaders to load from arbitrary sources.
50*523fa7a6SAndroid Build Coastguard Worker const char* model_path = FLAGS_model_path.c_str();
51*523fa7a6SAndroid Build Coastguard Worker
52*523fa7a6SAndroid Build Coastguard Worker const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
53*523fa7a6SAndroid Build Coastguard Worker
54*523fa7a6SAndroid Build Coastguard Worker const char* prompt = FLAGS_prompt.c_str();
55*523fa7a6SAndroid Build Coastguard Worker
56*523fa7a6SAndroid Build Coastguard Worker double temperature = FLAGS_temperature;
57*523fa7a6SAndroid Build Coastguard Worker
58*523fa7a6SAndroid Build Coastguard Worker int32_t seq_len = FLAGS_seq_len;
59*523fa7a6SAndroid Build Coastguard Worker
60*523fa7a6SAndroid Build Coastguard Worker int32_t cpu_threads = FLAGS_cpu_threads;
61*523fa7a6SAndroid Build Coastguard Worker
62*523fa7a6SAndroid Build Coastguard Worker bool warmup = FLAGS_warmup;
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker #if defined(ET_USE_THREADPOOL)
65*523fa7a6SAndroid Build Coastguard Worker uint32_t num_performant_cores = cpu_threads == -1
66*523fa7a6SAndroid Build Coastguard Worker ? ::executorch::extension::cpuinfo::get_num_performant_cores()
67*523fa7a6SAndroid Build Coastguard Worker : static_cast<uint32_t>(cpu_threads);
68*523fa7a6SAndroid Build Coastguard Worker ET_LOG(
69*523fa7a6SAndroid Build Coastguard Worker Info, "Resetting threadpool with num threads = %d", num_performant_cores);
70*523fa7a6SAndroid Build Coastguard Worker if (num_performant_cores > 0) {
71*523fa7a6SAndroid Build Coastguard Worker ::executorch::extension::threadpool::get_threadpool()
72*523fa7a6SAndroid Build Coastguard Worker ->_unsafe_reset_threadpool(num_performant_cores);
73*523fa7a6SAndroid Build Coastguard Worker }
74*523fa7a6SAndroid Build Coastguard Worker #endif
75*523fa7a6SAndroid Build Coastguard Worker // create llama runner
76*523fa7a6SAndroid Build Coastguard Worker example::Runner runner(model_path, tokenizer_path, temperature);
77*523fa7a6SAndroid Build Coastguard Worker
78*523fa7a6SAndroid Build Coastguard Worker if (warmup) {
79*523fa7a6SAndroid Build Coastguard Worker runner.warmup(prompt, seq_len);
80*523fa7a6SAndroid Build Coastguard Worker }
81*523fa7a6SAndroid Build Coastguard Worker // generate
82*523fa7a6SAndroid Build Coastguard Worker runner.generate(prompt, seq_len);
83*523fa7a6SAndroid Build Coastguard Worker
84*523fa7a6SAndroid Build Coastguard Worker return 0;
85*523fa7a6SAndroid Build Coastguard Worker }
86