/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/**
 * @file
 *
 * This tool runs Llama2 7B with Qualcomm AI Engine Direct.
 *
 * Users can specify arguments such as the desired prompt, eval_mode, etc.
 */

#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
#include <executorch/examples/qualcomm/qaihub_scripts/llama/runner/runner.h>
#include <executorch/runtime/platform/log.h>

#include <gflags/gflags.h>

#include <fstream>

DEFINE_string(sharded_1_path, "", "Path to 1st sharded pte file");
DEFINE_string(sharded_2_path, "", "Path to 2nd sharded pte file");
DEFINE_string(sharded_3_path, "", "Path to 3rd sharded pte file");
DEFINE_string(sharded_4_path, "", "Path to 4th sharded pte file");

DEFINE_string(
    freq_cos_path, "", "Path to precomputed cosine position embeddings");
DEFINE_string(
    freq_sin_path, "", "Path to precomputed sine position embeddings");

DEFINE_string(output_path, "outputs", "ExecuTorch inference data output path.");
DEFINE_string(tokenizer_path, "tokenizer.bin", "Path to tokenizer binary file.");
DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
DEFINE_double(
    temperature,
    0.0f,
    "Temperature; default is 0.0f. 0 = greedy argmax sampling (deterministic); lower temperature = more deterministic.");
DEFINE_int32(
    eval_mode,
    0,
    "0: PromptProcessor / 1: TokenGenerator / 2: MixedMode (TBD)");
DEFINE_int32(
    seq_len,
    128,
    "Total number of tokens to generate (prompt + output). Defaults to 128. If the number of input tokens + seq_len exceeds max_seq_len, the output will be truncated to max_seq_len tokens.");
DEFINE_double(logits_scale, 0.0, "Scale used to dequantize output logits.");
DEFINE_int32(logits_offset, 0, "Offset used to dequantize output logits.");

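// Example invocation; the binary name, file paths, and the logits
// scale/offset values below are illustrative placeholders rather than
// canonical values:
//
//   qaihub_llama2_7b_runner \
//     --sharded_1_path llama2_7b_0.pte \
//     --sharded_2_path llama2_7b_1.pte \
//     --sharded_3_path llama2_7b_2.pte \
//     --sharded_4_path llama2_7b_3.pte \
//     --freq_cos_path freq_cos.raw \
//     --freq_sin_path freq_sin.raw \
//     --tokenizer_path tokenizer.bin \
//     --prompt "The answer to the ultimate question is" \
//     --eval_mode 0 \
//     --seq_len 128 \
//     --logits_scale 0.1 \
//     --logits_offset 100 \
//     --output_path outputs
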
int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  std::vector<std::string> models_path = {
      FLAGS_sharded_1_path,
      FLAGS_sharded_2_path,
      FLAGS_sharded_3_path,
      FLAGS_sharded_4_path};
  std::vector<std::string> pos_embs_path = {
      FLAGS_freq_cos_path, FLAGS_freq_sin_path};

  // create llama runner
  example::Runner runner(
      models_path,
      pos_embs_path,
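      // Assumption: {8, 8, 8, 8} is the number of decoder layers covered by
      // each sharded pte file (Llama2 7B's 32 layers split evenly over the
      // 4 shards).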
      {8, 8, 8, 8},
      FLAGS_tokenizer_path.c_str(),
      FLAGS_eval_mode,
      FLAGS_temperature,
      FLAGS_logits_scale,
      FLAGS_logits_offset);

  // generate tokens & store inference output
  std::ofstream fout(FLAGS_output_path.c_str());
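  // The callback below receives each decoded token piece as it is generated
  // and streams it straight to the output file.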
  runner.generate(
      FLAGS_prompt, "", FLAGS_seq_len, [&](const std::string& piece) {
        fout << piece;
      });
  fout.close();
  return 0;
}