// xref: /aosp_15_r20/external/executorch/examples/models/llava/main.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8*523fa7a6SAndroid Build Coastguard Worker 
#include <executorch/examples/models/llava/runner/llava_runner.h>
#include <gflags/gflags.h>

#include <array>
#include <cinttypes>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#ifndef LLAVA_NO_TORCH_DUMMY_IMAGE
#include <torch/torch.h>
#else
#include <algorithm> // std::fill
#endif

#if defined(ET_USE_THREADPOOL)
#include <executorch/extension/threadpool/cpuinfo_utils.h>
#include <executorch/extension/threadpool/threadpool.h>
#endif
21*523fa7a6SAndroid Build Coastguard Worker 
22*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(
23*523fa7a6SAndroid Build Coastguard Worker     model_path,
24*523fa7a6SAndroid Build Coastguard Worker     "llava.pte",
25*523fa7a6SAndroid Build Coastguard Worker     "Model serialized in flatbuffer format.");
26*523fa7a6SAndroid Build Coastguard Worker 
27*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(tokenizer_path, "tokenizer.bin", "Tokenizer stuff.");
28*523fa7a6SAndroid Build Coastguard Worker 
29*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(prompt, "The answer to the ultimate question is", "Prompt.");
30*523fa7a6SAndroid Build Coastguard Worker 
31*523fa7a6SAndroid Build Coastguard Worker DEFINE_string(
32*523fa7a6SAndroid Build Coastguard Worker     image_path,
33*523fa7a6SAndroid Build Coastguard Worker     "",
34*523fa7a6SAndroid Build Coastguard Worker     "The path to a .pt file, a serialized torch tensor for an image, longest edge resized to 336.");
35*523fa7a6SAndroid Build Coastguard Worker 
36*523fa7a6SAndroid Build Coastguard Worker DEFINE_double(
37*523fa7a6SAndroid Build Coastguard Worker     temperature,
38*523fa7a6SAndroid Build Coastguard Worker     0.8f,
39*523fa7a6SAndroid Build Coastguard Worker     "Temperature; Default is 0.8f. 0 = greedy argmax sampling (deterministic). Lower temperature = more deterministic");
40*523fa7a6SAndroid Build Coastguard Worker 
41*523fa7a6SAndroid Build Coastguard Worker DEFINE_int32(
42*523fa7a6SAndroid Build Coastguard Worker     seq_len,
43*523fa7a6SAndroid Build Coastguard Worker     1024,
44*523fa7a6SAndroid Build Coastguard Worker     "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
45*523fa7a6SAndroid Build Coastguard Worker 
46*523fa7a6SAndroid Build Coastguard Worker DEFINE_int32(
47*523fa7a6SAndroid Build Coastguard Worker     cpu_threads,
48*523fa7a6SAndroid Build Coastguard Worker     -1,
49*523fa7a6SAndroid Build Coastguard Worker     "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
50*523fa7a6SAndroid Build Coastguard Worker 
51*523fa7a6SAndroid Build Coastguard Worker using executorch::extension::llm::Image;
52*523fa7a6SAndroid Build Coastguard Worker 
main(int32_t argc,char ** argv)53*523fa7a6SAndroid Build Coastguard Worker int32_t main(int32_t argc, char** argv) {
54*523fa7a6SAndroid Build Coastguard Worker   gflags::ParseCommandLineFlags(&argc, &argv, true);
55*523fa7a6SAndroid Build Coastguard Worker 
56*523fa7a6SAndroid Build Coastguard Worker   // Create a loader to get the data of the program file. There are other
57*523fa7a6SAndroid Build Coastguard Worker   // DataLoaders that use mmap() or point32_t to data that's already in memory,
58*523fa7a6SAndroid Build Coastguard Worker   // and users can create their own DataLoaders to load from arbitrary sources.
59*523fa7a6SAndroid Build Coastguard Worker   const char* model_path = FLAGS_model_path.c_str();
60*523fa7a6SAndroid Build Coastguard Worker 
61*523fa7a6SAndroid Build Coastguard Worker   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
62*523fa7a6SAndroid Build Coastguard Worker 
63*523fa7a6SAndroid Build Coastguard Worker   const char* prompt = FLAGS_prompt.c_str();
64*523fa7a6SAndroid Build Coastguard Worker 
65*523fa7a6SAndroid Build Coastguard Worker   std::string image_path = FLAGS_image_path;
66*523fa7a6SAndroid Build Coastguard Worker 
67*523fa7a6SAndroid Build Coastguard Worker   double temperature = FLAGS_temperature;
68*523fa7a6SAndroid Build Coastguard Worker 
69*523fa7a6SAndroid Build Coastguard Worker   int32_t seq_len = FLAGS_seq_len;
70*523fa7a6SAndroid Build Coastguard Worker 
71*523fa7a6SAndroid Build Coastguard Worker   int32_t cpu_threads = FLAGS_cpu_threads;
72*523fa7a6SAndroid Build Coastguard Worker 
73*523fa7a6SAndroid Build Coastguard Worker #if defined(ET_USE_THREADPOOL)
74*523fa7a6SAndroid Build Coastguard Worker   uint32_t num_performant_cores = cpu_threads == -1
75*523fa7a6SAndroid Build Coastguard Worker       ? ::executorch::extension::cpuinfo::get_num_performant_cores()
76*523fa7a6SAndroid Build Coastguard Worker       : static_cast<uint32_t>(cpu_threads);
77*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
78*523fa7a6SAndroid Build Coastguard Worker       Info, "Resetting threadpool with num threads = %d", num_performant_cores);
79*523fa7a6SAndroid Build Coastguard Worker   if (num_performant_cores > 0) {
80*523fa7a6SAndroid Build Coastguard Worker     ::executorch::extension::threadpool::get_threadpool()
81*523fa7a6SAndroid Build Coastguard Worker         ->_unsafe_reset_threadpool(num_performant_cores);
82*523fa7a6SAndroid Build Coastguard Worker   }
83*523fa7a6SAndroid Build Coastguard Worker #endif
84*523fa7a6SAndroid Build Coastguard Worker   // create llama runner
85*523fa7a6SAndroid Build Coastguard Worker   example::LlavaRunner runner(model_path, tokenizer_path, temperature);
86*523fa7a6SAndroid Build Coastguard Worker 
87*523fa7a6SAndroid Build Coastguard Worker   // read image and resize the longest edge to 336
88*523fa7a6SAndroid Build Coastguard Worker   std::vector<uint8_t> image_data;
89*523fa7a6SAndroid Build Coastguard Worker 
90*523fa7a6SAndroid Build Coastguard Worker #ifdef LLAVA_NO_TORCH_DUMMY_IMAGE
91*523fa7a6SAndroid Build Coastguard Worker   // Work without torch using a random data
92*523fa7a6SAndroid Build Coastguard Worker   image_data.resize(3 * 240 * 336);
93*523fa7a6SAndroid Build Coastguard Worker   std::fill(image_data.begin(), image_data.end(), 0); // black
94*523fa7a6SAndroid Build Coastguard Worker   std::array<int32_t, 3> image_shape = {3, 240, 336};
95*523fa7a6SAndroid Build Coastguard Worker   std::vector<Image> images = {
96*523fa7a6SAndroid Build Coastguard Worker       {.data = image_data, .width = image_shape[2], .height = image_shape[1]}};
97*523fa7a6SAndroid Build Coastguard Worker #else //  LLAVA_NO_TORCH_DUMMY_IMAGE
98*523fa7a6SAndroid Build Coastguard Worker   //   cv::Mat image = cv::imread(image_path, cv::IMREAD_COLOR);
99*523fa7a6SAndroid Build Coastguard Worker   //   int longest_edge = std::max(image.rows, image.cols);
100*523fa7a6SAndroid Build Coastguard Worker   //   float scale_factor = 336.0f / longest_edge;
101*523fa7a6SAndroid Build Coastguard Worker   //   cv::Size new_size(image.cols * scale_factor, image.rows * scale_factor);
102*523fa7a6SAndroid Build Coastguard Worker   //   cv::Mat resized_image;
103*523fa7a6SAndroid Build Coastguard Worker   //   cv::resize(image, resized_image, new_size);
104*523fa7a6SAndroid Build Coastguard Worker   //   image_data.assign(resized_image.datastart, resized_image.dataend);
105*523fa7a6SAndroid Build Coastguard Worker   torch::Tensor image_tensor;
106*523fa7a6SAndroid Build Coastguard Worker   torch::load(image_tensor, image_path); // CHW
107*523fa7a6SAndroid Build Coastguard Worker   ET_LOG(
108*523fa7a6SAndroid Build Coastguard Worker       Info,
109*523fa7a6SAndroid Build Coastguard Worker       "image size(0): %" PRId64 ", size(1): %" PRId64 ", size(2): %" PRId64,
110*523fa7a6SAndroid Build Coastguard Worker       image_tensor.size(0),
111*523fa7a6SAndroid Build Coastguard Worker       image_tensor.size(1),
112*523fa7a6SAndroid Build Coastguard Worker       image_tensor.size(2));
113*523fa7a6SAndroid Build Coastguard Worker   image_data.assign(
114*523fa7a6SAndroid Build Coastguard Worker       image_tensor.data_ptr<uint8_t>(),
115*523fa7a6SAndroid Build Coastguard Worker       image_tensor.data_ptr<uint8_t>() + image_tensor.numel());
116*523fa7a6SAndroid Build Coastguard Worker   std::vector<Image> images = {
117*523fa7a6SAndroid Build Coastguard Worker       {.data = image_data,
118*523fa7a6SAndroid Build Coastguard Worker        .width = static_cast<int32_t>(image_tensor.size(2)),
119*523fa7a6SAndroid Build Coastguard Worker        .height = static_cast<int32_t>(image_tensor.size(1))}};
120*523fa7a6SAndroid Build Coastguard Worker #endif // LLAVA_NO_TORCH_DUMMY_IMAGE
121*523fa7a6SAndroid Build Coastguard Worker 
122*523fa7a6SAndroid Build Coastguard Worker   // generate
123*523fa7a6SAndroid Build Coastguard Worker   runner.generate(std::move(images), prompt, seq_len);
124*523fa7a6SAndroid Build Coastguard Worker   return 0;
125*523fa7a6SAndroid Build Coastguard Worker }
126