/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post processing logic.
// The module takes in a string as input and emits a string as output.

#include <executorch/examples/qualcomm/oss_scripts/llama2/runner/runner.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/log.h>

#include <algorithm>
#include <cinttypes>
#include <ctime>
#include <memory>
#include <sstream>

using executorch::aten::ScalarType;
using executorch::aten::SizesType;
using executorch::aten::Tensor;
using executorch::extension::from_blob;
using executorch::extension::Module;
using executorch::extension::TensorPtr;
using executorch::extension::llm::BPETokenizer;
using executorch::extension::llm::Sampler;
using executorch::extension::llm::time_in_ms;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;
using executorch::runtime::TensorInfo;

// TODO: Remove this usage of an internal-only function.
using executorch::runtime::internal::set_tensor_data;

namespace example {

namespace {
static constexpr auto kTopp = 0.9f;
void printReport(const Runner::Stats& stats);
std::string statsToJsonString(const Runner::Stats& stats);
} // namespace

Runner::Runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    : module_(std::make_unique<Module>(
          model_path,
          Module::LoadMode::MmapUseMlockIgnoreErrors)),
      tokenizer_path_(tokenizer_path),
      model_path_(model_path),
      temperature_(temperature) {
  ET_LOG(
      Info,
      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
      model_path.c_str(),
      tokenizer_path.c_str());
}

bool Runner::is_loaded() const {
  return module_->is_loaded() && tokenizer_ && sampler_;
}

Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  stats_.model_load_start_ms = time_in_ms();
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

  // Read out metadata from the model
  ET_LOG(Info, "Reading metadata from model");
  const auto method_names = module_->method_names();
  ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
  model_methods_ = method_names.get();
  vocab_size_ = getMetadataHelper<int64_t>("get_vocab_size", 32000);
  bos_id_ = getMetadataHelper<int64_t>("get_bos_id", 1);
  eos_id_ = getMetadataHelper<int64_t>("get_eos_id", 2);
  n_bos_ = getMetadataHelper<int64_t>("get_n_bos", 1);
  n_eos_ = getMetadataHelper<int64_t>("get_n_eos", 1);
  max_seq_len_ = getMetadataHelper<int64_t>("get_max_seq_len", 128);
  head_dim_ = getMetadataHelper<int64_t>("get_head_dim", 32);
  dim_ = getMetadataHelper<int64_t>("get_dim", 4096);

  // Load tokenizer
  tokenizer_ = std::make_unique<BPETokenizer>();
  tokenizer_->load(tokenizer_path_);
  if (tokenizer_->bos_tok() != bos_id_) {
    ET_LOG(
        Error,
        "Tokenizer's BOS id %lu does not match model's BOS id %ld, will override tokenizer's BOS.",
        tokenizer_->bos_tok(),
        bos_id_);
  }
  if (tokenizer_->eos_tok() != eos_id_) {
    ET_LOG(
        Error,
        "Tokenizer's EOS id %lu does not match model's EOS id %ld, will override tokenizer's EOS.",
        tokenizer_->eos_tok(),
        eos_id_);
  }
  // Create sampler
  sampler_ = std::make_unique<Sampler>(
      vocab_size_,
      temperature_,
      kTopp,
      static_cast<unsigned long long>(std::time(nullptr)));
  stats_.model_load_end_ms = time_in_ms();

  return Error::Ok;
}

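// Reads an optional metadata method exported with the model (e.g.
// "get_vocab_size") and returns its first output, falling back to default_val
// when the method is absent or produces no output.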
template <typename T>
T Runner::getMetadataHelper(std::string method_name, T default_val) {
  T res = default_val;
  if (model_methods_.count(method_name)) {
    Result<std::vector<EValue>> outputs = module_->execute(method_name);
    if (outputs.ok()) {
      std::vector<EValue> outs = outputs.get();
      if (outs.size() > 0) {
        res = outs[0].to<T>();
      }
    }
  } else {
    ET_LOG(
        Info,
        "The model does not contain %s method, using default value %lld",
        method_name.c_str(),
        (long long)default_val);
  }
  ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res);
  return res;
}

template <typename T>
int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
  T* logits = logits_tensor.mutable_data_ptr<T>();

  // Since the logits are for all tokens, get the last token probabilities
  T* logits_last = logits;
  return sampler_->sample(logits_last);
}

// Given an input token, set up the inputs for the model, execute a single
// step, and return the logits tensor.
Result<Tensor> Runner::run_model_step(
    int64_t input_token,
    TensorPtr& token,
    TensorPtr& start_pos,
    TensorPtr& atten_mask,
    std::vector<TensorPtr>& kv_tensors,
    std::vector<TensorPtr>& kv_outputs) {
  token->mutable_data_ptr<int32_t>()[0] = input_token;

  // inputs:[tokens, start_pos, atten_mask, k_cache, v_cache]
  std::vector<executorch::runtime::EValue> inputs = {
      token, start_pos, atten_mask};
  inputs.insert(inputs.end(), kv_tensors.begin(), kv_tensors.end());
  auto outputs_res = module_->forward(inputs);
  ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());

  // TODO: need to handle batch size != 1
  size_t v_offset = kv_outputs[0]->nbytes();
  size_t el_size = kv_outputs[0]->element_size();
  size_t k_input_step = (max_seq_len_ - 1) * el_size;
  int k_tensors_end = kv_tensors.size() / 2;
  // update k caches
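  // The freshly produced k values are scattered back into the read buffer at
  // stride k_input_step so that, once the read pointer is advanced by one
  // element below, each value lands at the end of its row. This effectively
  // slides the cache window forward without copying the whole cache.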
  for (int j = 0; j < k_tensors_end; ++j) {
    uint8_t* input_addr =
        static_cast<uint8_t*>(kv_tensors[j]->mutable_data_ptr());
    uint8_t* output_addr =
        static_cast<uint8_t*>(kv_outputs[j]->mutable_data_ptr());
    // fill the output k values back
    for (int src = 0, dst = k_input_step; src < kv_outputs[j]->nbytes();
         src += el_size, dst += k_input_step) {
      input_addr[dst] = output_addr[src];
    }
    char* new_inp_addr = io_mem_mgr_.update_k_caches_read(j, el_size);
    // inputs
    ET_CHECK_MSG(
        set_tensor_data(
            *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
        "Failed to set input tensor when updating k_cache");
  }
  // update v caches
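  // The v cache needs no element-wise copy: the backend has already written
  // this step's values at the current write pointer, so the read and write
  // pointers are each advanced by one cache row (v_offset) and re-registered
  // as the method's input/output buffers.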
  for (int j = k_tensors_end, v_idx = 0; j < kv_tensors.size(); ++j, ++v_idx) {
    // inputs
    char* new_inp_addr = io_mem_mgr_.update_v_caches_read(v_idx, v_offset);

    ET_CHECK_MSG(
        set_tensor_data(
            *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
        "Failed to set input tensor when updating v_cache");
    // outputs
    char* new_out_addr = io_mem_mgr_.update_v_caches_write(v_idx, v_offset);
    ET_CHECK_MSG(
        set_tensor_data(
            *kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok,
        "Failed to set output tensor when updating v_cache");
    ET_CHECK_MSG(
        module_->set_output(*kv_outputs[j], j + 1) == Error::Ok,
        "Failed to set llama output data pointer");
  }

  // Bump start_pos by 1
  start_pos->mutable_data_ptr<int32_t>()[0]++;

  // update atten_mask
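  // Unmask one more position, counting back from the right end of the mask,
  // so the entry just written to the kv cache becomes visible to attention.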
  atten_mask->mutable_data_ptr<float>()
      [atten_mask->numel() - 1 - start_pos->const_data_ptr<int32_t>()[0]] = 0;
  return outputs_res.get()[0].toTensor();
}

// TODO: add overloaded method for on-device tokenize
Error Runner::generate(
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
  ET_CHECK_MSG(is_loaded(), "Please invoke load method first");

  // First token time only measures the time it takes to encode the prompt and
  // return a response token.
  stats_.inference_start_ms = time_in_ms();
  shouldStop_ = false;

  // Set the sequence length to the max seq length if not provided
  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;

  Result<std::vector<uint64_t>> encode_res =
      tokenizer_->encode(prompt, n_bos_, 0);

  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());

  // encode the (string) prompt into tokens sequence
  std::vector<uint64_t> prompt_tokens = encode_res.get();
  int num_prompt_tokens = prompt_tokens.size();

  ET_CHECK_MSG(
      num_prompt_tokens < max_seq_len_,
      "Max seq length exceeded - please increase max seq len value in static_llama.py");

  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "Sequence length exceeded - please increase the seq_len value passed to generate()");

  int32_t pos = 0, prev_token, cur_token = prompt_tokens[0];
  std::vector<SizesType> token_shape = {1, 1};

  io_mem_mgr_.get_input_token_ptr()[0] = 0;
  std::vector<SizesType> start_pos_shape = {1, 1};

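  // Initialize the causal attention mask: every position starts masked with a
  // large negative bias (-255); only the rightmost slot, which corresponds to
  // the current token, is left open.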
  float* atten_mask_ptr =
      reinterpret_cast<float*>(io_mem_mgr_.get_atten_mask_ptr());
  std::fill(atten_mask_ptr, atten_mask_ptr + max_seq_len_, -255);
  atten_mask_ptr[max_seq_len_ - 1] = 0;

  std::vector<SizesType> atten_mask_shape = {1, max_seq_len_};

  std::vector<SizesType> logits_data_shape = {1, vocab_size_};

  std::vector<SizesType> hidden_states_data_shape = {1, 1, dim_};

  // initialize tensor wrappers
  auto token = from_blob(
      io_mem_mgr_.get_input_token_ptr(), token_shape, ScalarType::Int);
  auto start_pos = from_blob(
      io_mem_mgr_.get_pos_idx_ptr(), start_pos_shape, ScalarType::Int);
  auto atten_mask = from_blob(
      io_mem_mgr_.get_atten_mask_ptr(), atten_mask_shape, ScalarType::Float);

  std::vector<TensorPtr> kv_tensors, kv_outputs;

  Result<MethodMeta> method_meta = get_method_meta();
  size_t num_inputs = method_meta->num_inputs();
  int k_caches_num = (num_inputs - 3) / 2;

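  // Wrap the pre-allocated shared-buffer regions owned by io_mem_mgr_ as the
  // kv cache input/output tensors, so the backend reads and writes the caches
  // in place. Inputs 0-2 are token, start_pos and atten_mask; the remaining
  // inputs are the k caches followed by the v caches.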
  // TODO: need to handle batch size != 1
  // k caches init
  for (int input_index = 3, i = 0; input_index < k_caches_num + 3;
       ++input_index, ++i) {
    // inputs
    Result<TensorInfo> tensor_meta =
        method_meta->input_tensor_meta(input_index);

    auto tensor_shape = tensor_meta->sizes();
    std::vector<SizesType> sizes(
        tensor_shape.data(), tensor_shape.data() + tensor_shape.size());
    kv_tensors.emplace_back(from_blob(
        io_mem_mgr_.get_k_caches_read_ptr(i),
        sizes,
        tensor_meta->scalar_type()));

    // outputs
    Result<TensorInfo> out_tensor_meta = method_meta->output_tensor_meta(i + 1);
    tensor_shape = out_tensor_meta->sizes();
    sizes = std::vector<SizesType>{
        tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};
    kv_outputs.emplace_back(from_blob(
        io_mem_mgr_.get_k_caches_write_ptr(i),
        sizes,
        kv_tensors.back()->scalar_type()));
    ET_CHECK_MSG(
        module_->set_output(kv_outputs.back(), i + 1) == Error::Ok,
        "Failed to set output tensor for kv cache");
  }

  // v caches init
  for (int i = 0, input_index = k_caches_num + 3; input_index < num_inputs;
       ++input_index, ++i) {
    int output_index = i + k_caches_num + 1;
    // inputs
    Result<TensorInfo> tensor_meta =
        method_meta->input_tensor_meta(input_index);
    auto tensor_shape = tensor_meta->sizes();
    std::vector<SizesType> sizes(
        tensor_shape.data(), tensor_shape.data() + tensor_shape.size());

    kv_tensors.emplace_back(from_blob(
        io_mem_mgr_.get_v_caches_read_ptr(i),
        sizes,
        tensor_meta->scalar_type()));

    // outputs
    Result<TensorInfo> out_tensor_meta =
        method_meta->output_tensor_meta(output_index);
    tensor_shape = out_tensor_meta->sizes();
    sizes = std::vector<SizesType>{
        tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};

    kv_outputs.push_back(from_blob(
        io_mem_mgr_.get_v_caches_write_ptr(i),
        sizes,
        kv_tensors.back()->scalar_type()));
    ET_CHECK_MSG(
        module_->set_output(kv_outputs.back(), output_index) == Error::Ok,
        "Failed to set output tensor for llama block");
  }

  auto affine_logits = from_blob(
      reinterpret_cast<float*>(io_mem_mgr_.get_logit_ptr()),
      logits_data_shape,
      ScalarType::Float);
  ET_CHECK_MSG(
      module_->set_output(affine_logits) == Error::Ok,
      "Failed to set output tensor for affine module - logits");

  // Start consuming user's prompts and generating new tokens
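  // Prefill and decode share the same single-token forward pass: while pos is
  // still inside the prompt, the sampled token is discarded and the next
  // prompt token is forced instead (see below), so the kv cache fills up one
  // position per step.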
  std::string final_output;
  while (pos < seq_len - 1) {
    // Run the model
    auto logits_res = run_model_step(
        cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs);
    if (pos == num_prompt_tokens) {
      stats_.first_token_ms = time_in_ms();
    } else if (pos == num_prompt_tokens - 1) {
      stats_.prompt_eval_end_ms = time_in_ms();
    }

    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
    Tensor& logits_tensor = logits_res.get();
    prev_token = cur_token;
    long sample_start_time_ms = time_in_ms();

    cur_token = logitsToToken<float>(logits_tensor);
    stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;

    // advance the state machine
    if (pos < num_prompt_tokens - 1) {
      // prefill, force the next token to be the next prompt token
      cur_token = prompt_tokens[pos + 1];
    }
    pos++;

    // print the token as string, decode it with the Tokenizer object
    auto piece_res = tokenizer_->decode(prev_token, cur_token);
    ET_CHECK(piece_res.ok());

    if (token_callback) {
      token_callback(piece_res.get());
    }

    if (shouldStop_) {
      break;
    }

397     if (pos >= num_prompt_tokens && cur_token == eos_id_) {
398       ET_LOG(Info, "Reached to the end of generation");
399       break;
400     }
401   }
402   stats_.inference_end_ms = time_in_ms();
403 
404   if (pos == seq_len) {
405     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
406   }
407 
408   stats_.num_prompt_tokens = num_prompt_tokens;
409   stats_.num_generated_tokens = pos - num_prompt_tokens;
410   printReport(stats_);
411   if (stats_callback) {
412     stats_callback(stats_);
413   }
414 
415   return Error::Ok;
416 }

namespace {
void printReport(const Runner::Stats& stats) {
  printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

  ET_LOG(
      Info,
      "\tPrompt Tokens: %" PRIu64 "    Generated Tokens: %" PRIu64,
      stats.num_prompt_tokens,
      stats.num_generated_tokens);

  ET_LOG(
      Info,
      "\tModel Load Time:\t\t%f (seconds)",
      ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));
  double inference_time_ms =
      (double)(stats.inference_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_generated_tokens) /
          (double)(stats.inference_end_ms - stats.inference_start_ms) *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
  double prompt_eval_time =
      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_prompt_tokens) / prompt_eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  double eval_time =
      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
  ET_LOG(
      Info,
      "\t\tGenerated %" PRIu64
      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      stats.num_generated_tokens,
      eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      stats.num_generated_tokens / eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  // Time to first token is measured from the start of inference, excluding
  // model load time.
  ET_LOG(
      Info,
      "\tTime to first generated token:\t%f (seconds)",
      ((double)(stats.first_token_ms - stats.inference_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));

  ET_LOG(
      Info,
      "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
      stats.num_prompt_tokens + stats.num_generated_tokens,
      (double)stats.aggregate_sampling_time_ms /
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
}

std::string statsToJsonString(const Runner::Stats& stats) {
  std::stringstream ss;
  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
     << "\"generated_tokens\":" << stats.num_generated_tokens << ","
     << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
     << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
     << "\"inference_start_ms\":" << stats.inference_start_ms << ","
     << "\"inference_end_ms\":" << stats.inference_end_ms << ","
     << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
     << "\"first_token_ms\":" << stats.first_token_ms << ","
     << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
     << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
     << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
  return ss.str();
}
} // namespace


IoMemMgr::IoMemMgr(MethodMeta method_meta) {
  method_meta_ = std::make_unique<MethodMeta>(method_meta);
  init_io_info();
  compute_total_nbytes();
}

void IoMemMgr::init_io_info() {
  set_tensor_meta();
  for (auto info : io_info_.tensor_info) {
    info->size = info->tensor_meta->nbytes();
    info->rank = info->tensor_meta->sizes().size();
    info->shape.resize(info->rank);
    for (int i = 0; i < info->rank; i++) {
      info->shape[i] =
          static_cast<uint32_t>(info->tensor_meta->sizes().data()[i]);
    }
    info->dtype = info->tensor_meta->scalar_type();
    info->element_size = scalar_type_to_size[info->tensor_meta->scalar_type()];
  }
}

void IoMemMgr::set_tensor_meta() {
  io_info_.input_token.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->input_tensor_meta(0).get());
  io_info_.pos_idx.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->input_tensor_meta(1).get());
  io_info_.atten_mask.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->input_tensor_meta(2).get());

  io_info_.k_caches_read.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->input_tensor_meta(3).get());
  io_info_.k_caches_write.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->output_tensor_meta(1).get());

  io_info_.v_caches_read.tensor_meta = std::make_unique<TensorInfo>(
      method_meta_->input_tensor_meta(method_meta_->num_inputs() - 1).get());
  io_info_.v_caches_write.tensor_meta = std::make_unique<TensorInfo>(
      method_meta_->output_tensor_meta(method_meta_->num_outputs() - 1).get());

  io_info_.logit.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->output_tensor_meta(0).get());
}

void IoMemMgr::compute_total_nbytes() {
  total_nbytes_ = io_info_.input_token.size + io_info_.pos_idx.size +
      io_info_.atten_mask.size + io_info_.logit.size;
  size_t num_heads = (method_meta_->num_inputs() - 3) / 2;

  // To update the v cache via pointer shifting, the v caches need a buffer of
  // size (max_seq_len_ - 1) * head_dim_, which is equivalent to one more cache
  size_t num_v_cache = num_heads + 1;
  // To update the k cache via pointer shifting, the k buffer needs a size of
  // max_seq_len - 1
  size_t k_buffer = io_info_.k_caches_read.size / io_info_.k_caches_write.size;

  // k_caches_read needs a buffer with the size of head_dim_
  total_nbytes_ += num_heads * io_info_.k_caches_read.size + k_buffer;
  total_nbytes_ += num_heads * io_info_.k_caches_write.size;
  total_nbytes_ += num_v_cache * io_info_.v_caches_read.size;
  // Add a head dim size for the convenience of shifting the pointer from the
  // last unused v cache write
  total_nbytes_ += io_info_.v_caches_write.size;
}

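// Lay out every I/O buffer inside the single shared allocation: the scalar
// inputs and logits come first, followed by the k cache read regions, a k
// cache shift buffer, the k cache write regions, and finally the v cache
// regions (the write view of each v cache starts one full read-cache size
// after its read view). Returns false if the layout would exceed the bytes
// allocated.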
bool IoMemMgr::init_tensors() {
  size_t cur_pos = input_token_pos_;
  pos_idx_pos_ = cur_pos += io_info_.input_token.size;
  atten_mask_pos_ = cur_pos += io_info_.pos_idx.size;
  logit_pos_ = cur_pos += io_info_.atten_mask.size;
  set_input_token_ptr();
  set_pos_idx_ptr();
  set_atten_mask_ptr();
  set_logit_ptr();

  // set start point of kv caches
  cur_pos += io_info_.logit.size;

  size_t num_heads = (method_meta_->num_inputs() - 3) / 2;
  k_caches_read_pos_.resize(num_heads);
  k_caches_write_pos_.resize(num_heads);
  v_caches_read_pos_.resize(num_heads);
  v_caches_write_pos_.resize(num_heads);

  for (int i = 0; i < num_heads; i++) {
    set_k_caches_read(i, cur_pos);
    cur_pos += io_info_.k_caches_read.size;
  }
  // add the size of the k caches shift buffer
  cur_pos += io_info_.k_caches_read.size / io_info_.k_caches_write.size;
  for (int i = 0; i < num_heads; i++) {
    set_k_caches_write(i, cur_pos);
    cur_pos += io_info_.k_caches_write.size;
  }

  for (int i = 0; i < num_heads; i++) {
    set_v_caches_read(i, cur_pos);
    set_v_caches_write(i, cur_pos + io_info_.v_caches_read.size);
    cur_pos += io_info_.v_caches_read.size;
  }
  // add one more cache as the v caches shift buffer
  cur_pos += io_info_.v_caches_read.size;
  return cur_pos <= total_nbytes_;
}

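// Pre-register every shifted view of the kv cache buffers with the QNN
// backend, so that the shared (custom) memory is still recognized after the
// read/write pointers have been advanced up to seq_len times during
// generation.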
void IoMemMgr::set_all_shifted_ptrs(size_t seq_len) {
  auto iter_setter = [&](std::vector<size_t>& cache,
                         size_t shift_size,
                         InfoAttrs& tensor_info) {
    for (int i = 0; i < cache.size(); ++i) {
      size_t pos = cache[i] + shift_size;
      CustomMemTensorInfo info = {
          ptr_,
          ptr_ + pos,
          pos,
          tensor_info.size,
          tensor_info.shape.data(),
          tensor_info.rank,
          tensor_info.dtype};
      QnnExecuTorchAddCustomMemTensorInfo(info);
    }
  };
  for (int i = 0; i < seq_len; ++i) {
    iter_setter(
        k_caches_read_pos_,
        i * io_info_.k_caches_read.element_size,
        io_info_.k_caches_read);
    iter_setter(
        v_caches_read_pos_,
        i * io_info_.v_caches_write.size,
        io_info_.v_caches_read);
    iter_setter(
        v_caches_write_pos_,
        i * io_info_.v_caches_write.size,
        io_info_.v_caches_write);
  }
}


void Runner::stop() {
  shouldStop_ = true;
}

Result<MethodMeta> Runner::get_method_meta() {
  return module_->method_meta("forward");
}

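// Allocate the shared custom-memory pool for all I/O tensors, lay them out,
// and pre-register the shifted kv cache pointers; then recreate and reload the
// module so the registered shared buffer is picked up by the backend.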
Error Runner::mem_alloc(size_t alignment, size_t seq_len) {
  Result<MethodMeta> method_meta_result = get_method_meta();
  io_mem_mgr_ = IoMemMgr(method_meta_result.get());
  ET_CHECK_MSG(
      io_mem_mgr_.allocate(alignment),
      "IoMemMgr failed to allocate custom memory");

  ET_CHECK_MSG(
      io_mem_mgr_.init_tensors(),
      "IoMemMgr required more bytes than allocated bytes");

  io_mem_mgr_.set_all_shifted_ptrs(seq_len);
  // To register rpc_mem_handle from SharedBuffer
  // Reset and re-init again to trigger registered function
  module_.reset();
  module_ = std::make_unique<Module>(
      model_path_, Module::LoadMode::MmapUseMlockIgnoreErrors);
  ET_CHECK_MSG(load() == Error::Ok, "Runner failed to load method");

  return Error::Ok;
}

// explicit instantiation of template methods
template int64_t Runner::getMetadataHelper<int64_t>(
    std::string method_name,
    int64_t default_val);
template bool Runner::getMetadataHelper<bool>(
    std::string method_name,
    bool default_val);

} // namespace example