/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/**
 * @file
 *
 * This tool runs ExecuTorch model files with the Qualcomm AI Engine Direct
 * backend.
 *
 * Users can specify arguments such as the prompt, temperature, etc.
 */

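// Example invocation (binary name, paths, and values are illustrative;
// adjust them to your build output and on-device layout):
//
//   ./qnn_llama_runner \
//       --model_path=qnn_llama2.pte \
//       --tokenizer_path=tokenizer.bin \
//       --prompt="The answer to the ultimate question is" \
//       --seq_len=128 \
//       --temperature=0.8 \
//       --output_folder_path=outputs
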
#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
#include <executorch/examples/qualcomm/oss_scripts/llama2/runner/runner.h>
#include <executorch/runtime/platform/log.h>

#include <gflags/gflags.h>

#include <fstream>
#include <vector>

DEFINE_string(
    model_path,
    "qnn_llama2.pte",
    "Model serialized in flatbuffer format.");

DEFINE_string(
    output_folder_path,
    "outputs",
    "ExecuTorch inference data output path.");

DEFINE_string(tokenizer_path, "tokenizer.bin", "Path to the tokenizer binary.");

DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");

DEFINE_double(
    temperature,
    0.8f,
    "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");

DEFINE_int32(
    seq_len,
    128,
    "Total number of tokens to generate (prompt + output). If the number of input tokens + seq_len exceeds max_seq_len, the output will be truncated to max_seq_len tokens.");

using executorch::runtime::Error;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);

  const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
  const char* prompt = FLAGS_prompt.c_str();
  double temperature = FLAGS_temperature;
  int32_t seq_len = FLAGS_seq_len;

  // create llama runner
  example::Runner runner(FLAGS_model_path, tokenizer_path, temperature);
  ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method");

  // MethodMeta describes the memory requirements of the method.
  Result<MethodMeta> method_meta = runner.get_method_meta();
  ET_CHECK_MSG(
      method_meta.ok(),
      "Failed to get method_meta 0x%x",
      (unsigned int)method_meta.error());
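  // Pre-allocate the memory the runner needs for inference; the required
  // buffer sizes depend on the requested sequence length.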
  ET_CHECK_MSG(
      runner.mem_alloc(MemoryAllocator::kDefaultAlignment, seq_len) ==
          Error::Ok,
      "Runner failed to allocate memory");

  // Generate tokens. The prompt comes from the command-line arguments;
  // pos_ids and atten_mask are inferred inside the runner. Each decoded
  // piece is appended to inference_output by the callback.
  std::string inference_output;
  runner.generate(prompt, seq_len, [&](const std::string& piece) {
    inference_output += piece;
  });

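  // Persist the generated text so it can be inspected or post-processed on
  // the host; the file is written as <output_folder_path>/output_<i>_0.raw.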
  size_t inference_index = 0;
  auto output_file_name = FLAGS_output_folder_path + "/output_" +
      std::to_string(inference_index++) + "_0.raw";
  std::ofstream fout(output_file_name.c_str());
  fout << inference_output;
  fout.close();

  return 0;
}