/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// A simple llama2 runner that includes preprocessing and post-processing
// logic. The module takes in a string as input and emits a string as output.
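//
// A minimal usage sketch (illustrative only; the paths, alignment, and
// sequence length below are placeholders, not values taken from this file):
//
//   example::Runner runner("llama2_qnn.pte", "tokenizer.bin", 0.8f);
//   // Allocates the shared IO buffers and (re)loads the module.
//   runner.mem_alloc(/*alignment=*/4096, /*seq_len=*/128);
//   runner.generate(
//       "Hello",
//       /*seq_len=*/128,
//       [](const std::string& piece) { /* stream decoded text */ },
//       [](const Runner::Stats& stats) { /* inspect timing */ });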

#include <executorch/examples/qualcomm/oss_scripts/llama2/runner/runner.h>
#include <executorch/extension/evalue_util/print_evalue.h>
#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/llm/tokenizer/bpe_tokenizer.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <executorch/runtime/platform/log.h>

#include <cinttypes>
#include <ctime>
#include <memory>
#include <sstream>

using executorch::aten::ScalarType;
using executorch::aten::SizesType;
using executorch::aten::Tensor;
using executorch::extension::from_blob;
using executorch::extension::Module;
using executorch::extension::TensorPtr;
using executorch::extension::llm::BPETokenizer;
using executorch::extension::llm::Sampler;
using executorch::extension::llm::time_in_ms;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;
using executorch::runtime::TensorInfo;

// TODO: Remove this usage of an internal-only function.
using executorch::runtime::internal::set_tensor_data;

namespace example {

namespace {
static constexpr auto kTopp = 0.9f;
void printReport(const Runner::Stats& stats);
std::string statsToJsonString(const Runner::Stats& stats);
} // namespace

Runner::Runner(
    const std::string& model_path,
    const std::string& tokenizer_path,
    const float temperature)
    : module_(std::make_unique<Module>(
          model_path,
          Module::LoadMode::MmapUseMlockIgnoreErrors)),
      tokenizer_path_(tokenizer_path),
      model_path_(model_path),
      temperature_(temperature) {
  ET_LOG(
      Info,
      "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
      model_path.c_str(),
      tokenizer_path.c_str());
}

bool Runner::is_loaded() const {
  return module_->is_loaded() && tokenizer_ && sampler_;
}

Error Runner::load() {
  if (is_loaded()) {
    return Error::Ok;
  }
  stats_.model_load_start_ms = time_in_ms();
  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

  // Read out metadata from the model
  ET_LOG(Info, "Reading metadata from model");
  const auto method_names = module_->method_names();
  ET_CHECK_MSG(method_names.ok(), "Failed to read method names from model");
  model_methods_ = method_names.get();
  vocab_size_ = getMetadataHelper<int64_t>("get_vocab_size", 32000);
  bos_id_ = getMetadataHelper<int64_t>("get_bos_id", 1);
  eos_id_ = getMetadataHelper<int64_t>("get_eos_id", 2);
  n_bos_ = getMetadataHelper<int64_t>("get_n_bos", 1);
  n_eos_ = getMetadataHelper<int64_t>("get_n_eos", 1);
  max_seq_len_ = getMetadataHelper<int64_t>("get_max_seq_len", 128);
  head_dim_ = getMetadataHelper<int64_t>("get_head_dim", 32);
  dim_ = getMetadataHelper<int64_t>("get_dim", 4096);

  // Load tokenizer
  tokenizer_ = std::make_unique<BPETokenizer>();
  tokenizer_->load(tokenizer_path_);
  if (tokenizer_->bos_tok() != bos_id_) {
    ET_LOG(
        Error,
        "Tokenizer's BOS id %lu does not match model's BOS id %ld, will override tokenizer's BOS.",
        tokenizer_->bos_tok(),
        bos_id_);
  }
  if (tokenizer_->eos_tok() != eos_id_) {
    ET_LOG(
        Error,
        "Tokenizer's EOS id %lu does not match model's EOS id %ld, will override tokenizer's EOS.",
        tokenizer_->eos_tok(),
        eos_id_);
  }
  // Create sampler
  sampler_ = std::make_unique<Sampler>(
      vocab_size_,
      temperature_,
      kTopp,
      static_cast<unsigned long long>(std::time(nullptr)));
  stats_.model_load_end_ms = time_in_ms();

  return Error::Ok;
}

template <typename T>
T Runner::getMetadataHelper(std::string method_name, T default_val) {
  T res = default_val;
  if (model_methods_.count(method_name)) {
    Result<std::vector<EValue>> outputs = module_->execute(method_name);
    if (outputs.ok()) {
      std::vector<EValue> outs = outputs.get();
      if (outs.size() > 0) {
        res = outs[0].to<T>();
      }
    }
  } else {
    ET_LOG(
        Info,
        "The model does not contain %s method, using default value %lld",
        method_name.c_str(),
        (long long)default_val);
  }
  ET_LOG(Info, "%s: %lld", method_name.c_str(), (long long)res);
  return res;
}

template <typename T>
int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
  T* logits = logits_tensor.mutable_data_ptr<T>();

  // Since the logits are for all tokens, get the last token probabilities
  T* logits_last = logits;
  return sampler_->sample(logits_last);
}

// Given an input token, set up the model inputs and execute a single step,
// returning the logits tensor.
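//
// A note on the KV-cache update performed below (derived from the code in
// this function; the exact cache layout is defined by IoMemMgr):
//   1. The new k values produced by this step are scattered back into the
//      k input cache at a stride of (max_seq_len_ - 1) elements, and the
//      k read pointer is then advanced by one element.
//   2. The v read/write pointers are advanced by one full cache slot, and
//      the v output tensors are re-bound via Module::set_output().
//   3. start_pos is incremented and one more attention-mask slot is opened.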
Result<Tensor> Runner::run_model_step(
    int64_t input_token,
    TensorPtr& token,
    TensorPtr& start_pos,
    TensorPtr& atten_mask,
    std::vector<TensorPtr>& kv_tensors,
    std::vector<TensorPtr>& kv_outputs) {
  token->mutable_data_ptr<int32_t>()[0] = input_token;

  // inputs: [tokens, start_pos, atten_mask, k_cache, v_cache]
  std::vector<executorch::runtime::EValue> inputs = {
      token, start_pos, atten_mask};
  inputs.insert(inputs.end(), kv_tensors.begin(), kv_tensors.end());
  auto outputs_res = module_->forward(inputs);
  ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());

  // TODO: need to handle batch size != 1
  size_t v_offset = kv_outputs[0]->nbytes();
  size_t el_size = kv_outputs[0]->element_size();
  size_t k_input_step = (max_seq_len_ - 1) * el_size;
  int k_tensors_end = kv_tensors.size() / 2;
  // update k caches
  for (int j = 0; j < k_tensors_end; ++j) {
    uint8_t* input_addr =
        static_cast<uint8_t*>(kv_tensors[j]->mutable_data_ptr());
    uint8_t* output_addr =
        static_cast<uint8_t*>(kv_outputs[j]->mutable_data_ptr());
    // fill the output k values back
    for (int src = 0, dst = k_input_step; src < kv_outputs[j]->nbytes();
         src += el_size, dst += k_input_step) {
      input_addr[dst] = output_addr[src];
    }
    char* new_inp_addr = io_mem_mgr_.update_k_caches_read(j, el_size);
    // inputs
    ET_CHECK_MSG(
        set_tensor_data(
            *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
        "Failed to set input tensor when updating k_cache");
  }
  // update v caches
  for (int j = k_tensors_end, v_idx = 0; j < kv_tensors.size(); ++j, ++v_idx) {
    // inputs
    char* new_inp_addr = io_mem_mgr_.update_v_caches_read(v_idx, v_offset);

    ET_CHECK_MSG(
        set_tensor_data(
            *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
        "Failed to set input tensor when updating v_cache");
    // outputs
    char* new_out_addr = io_mem_mgr_.update_v_caches_write(v_idx, v_offset);
    ET_CHECK_MSG(
        set_tensor_data(
            *kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok,
        "Failed to set output tensor when updating v_cache");
    ET_CHECK_MSG(
        module_->set_output(*kv_outputs[j], j + 1) == Error::Ok,
        "Failed to set llama output data pointer");
  }

  // Bump start_pos by 1
  start_pos->mutable_data_ptr<int32_t>()[0]++;

  // update atten_mask
  atten_mask->mutable_data_ptr<float>()
      [atten_mask->numel() - 1 - start_pos->const_data_ptr<int32_t>()[0]] = 0;
  return outputs_res.get()[0].toTensor();
}
// TODO: add overloaded method for on-device tokenize
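//
// generate() runs a single loop that covers both prefill and decode: prompt
// tokens are fed one at a time (with the sampled token discarded and the
// next prompt token forced instead), and once the prompt is consumed the
// sampled tokens are fed back in until EOS or seq_len is reached.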
Error Runner::generate(
    const std::string& prompt,
    int32_t seq_len,
    std::function<void(const std::string&)> token_callback,
    std::function<void(const Stats&)> stats_callback) {
  ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be empty");
  ET_CHECK_MSG(is_loaded(), "Please invoke load method first");

  // First token time only measures the time it takes to encode the prompt and
  // return a response token.
  stats_.inference_start_ms = time_in_ms();
  shouldStop_ = false;

  // Set the sequence length to the max seq length if not provided
  seq_len = (seq_len > 0 && seq_len <= max_seq_len_) ? seq_len : max_seq_len_;

  Result<std::vector<uint64_t>> encode_res =
      tokenizer_->encode(prompt, n_bos_, 0);

  ET_CHECK_OK_OR_RETURN_ERROR(
      encode_res.error(), "Failed to encode prompt %s", prompt.c_str());

  // encode the (string) prompt into a token sequence
  std::vector<uint64_t> prompt_tokens = encode_res.get();
  int num_prompt_tokens = prompt_tokens.size();

  ET_CHECK_MSG(
      num_prompt_tokens < max_seq_len_,
      "Max seq length exceeded - please increase max seq len value in static_llama.py");

  ET_CHECK_MSG(
      num_prompt_tokens < seq_len,
      "Sequence length exceeded - please increase the seq_len value passed to generate()");
  int32_t pos = 0, prev_token, cur_token = prompt_tokens[0];
  std::vector<SizesType> token_shape = {1, 1};

  io_mem_mgr_.get_input_token_ptr()[0] = 0;
  std::vector<SizesType> start_pos_shape = {1, 1};

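  // The attention mask is a single row of max_seq_len_ floats. Masked
  // positions hold -255 (presumably used as an additive bias on the attention
  // scores inside the exported graph); unmasked positions hold 0. Only the
  // last slot starts open, and run_model_step() opens one more slot per
  // generated token, right to left.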
  float* atten_mask_ptr =
      reinterpret_cast<float*>(io_mem_mgr_.get_atten_mask_ptr());
  std::fill(atten_mask_ptr, atten_mask_ptr + max_seq_len_, -255);
  atten_mask_ptr[max_seq_len_ - 1] = 0;

  std::vector<SizesType> atten_mask_shape = {1, max_seq_len_};

  std::vector<SizesType> logits_data_shape = {1, vocab_size_};

  std::vector<SizesType> hidden_states_data_shape = {1, 1, dim_};

  // initialize tensor wrappers
  auto token = from_blob(
      io_mem_mgr_.get_input_token_ptr(), token_shape, ScalarType::Int);
  auto start_pos = from_blob(
      io_mem_mgr_.get_pos_idx_ptr(), start_pos_shape, ScalarType::Int);
  auto atten_mask = from_blob(
      io_mem_mgr_.get_atten_mask_ptr(), atten_mask_shape, ScalarType::Float);

  std::vector<TensorPtr> kv_tensors, kv_outputs;

  Result<MethodMeta> method_meta = get_method_meta();
  size_t num_inputs = method_meta->num_inputs();
  int k_caches_num = (num_inputs - 3) / 2;

  // TODO: need to handle batch size != 1
  // k caches init
  for (int input_index = 3, i = 0; input_index < k_caches_num + 3;
       ++input_index, ++i) {
    // inputs
    Result<TensorInfo> tensor_meta =
        method_meta->input_tensor_meta(input_index);

    auto tensor_shape = tensor_meta->sizes();
    std::vector<SizesType> sizes(
        tensor_shape.data(), tensor_shape.data() + tensor_shape.size());
    kv_tensors.emplace_back(from_blob(
        io_mem_mgr_.get_k_caches_read_ptr(i),
        sizes,
        tensor_meta->scalar_type()));

    // outputs
    Result<TensorInfo> out_tensor_meta = method_meta->output_tensor_meta(i + 1);
    tensor_shape = out_tensor_meta->sizes();
    sizes = std::vector<SizesType>{
        tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};
    kv_outputs.emplace_back(from_blob(
        io_mem_mgr_.get_k_caches_write_ptr(i),
        sizes,
        kv_tensors.back()->scalar_type()));
    ET_CHECK_MSG(
        module_->set_output(kv_outputs.back(), i + 1) == Error::Ok,
        "Failed to set output tensor for kv cache");
  }

  // v caches init
  for (int i = 0, input_index = k_caches_num + 3; input_index < num_inputs;
       ++input_index, ++i) {
    int output_index = i + k_caches_num + 1;
    // inputs
    Result<TensorInfo> tensor_meta =
        method_meta->input_tensor_meta(input_index);
    auto tensor_shape = tensor_meta->sizes();
    std::vector<SizesType> sizes(
        tensor_shape.data(), tensor_shape.data() + tensor_shape.size());

    kv_tensors.emplace_back(from_blob(
        io_mem_mgr_.get_v_caches_read_ptr(i),
        sizes,
        tensor_meta->scalar_type()));

    // outputs
    Result<TensorInfo> out_tensor_meta =
        method_meta->output_tensor_meta(output_index);
    tensor_shape = out_tensor_meta->sizes();
    sizes = std::vector<SizesType>{
        tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};

    kv_outputs.push_back(from_blob(
        io_mem_mgr_.get_v_caches_write_ptr(i),
        sizes,
        kv_tensors.back()->scalar_type()));
    ET_CHECK_MSG(
        module_->set_output(kv_outputs.back(), output_index) == Error::Ok,
        "Failed to set output tensor for llama block");
  }

  auto affine_logits = from_blob(
      reinterpret_cast<float*>(io_mem_mgr_.get_logit_ptr()),
      logits_data_shape,
      ScalarType::Float);
  ET_CHECK_MSG(
      module_->set_output(affine_logits) == Error::Ok,
      "Failed to set output tensor for affine module - logits");

  // Start consuming user's prompts and generating new tokens
  std::string final_output;
  while (pos < seq_len - 1) {
    // Run the model
    auto logits_res = run_model_step(
        cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs);
    if (pos == num_prompt_tokens) {
      stats_.first_token_ms = time_in_ms();
    } else if (pos == num_prompt_tokens - 1) {
      stats_.prompt_eval_end_ms = time_in_ms();
    }

    ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
    Tensor& logits_tensor = logits_res.get();
    prev_token = cur_token;
    long sample_start_time_ms = time_in_ms();

    cur_token = logitsToToken<float>(logits_tensor);
    stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;

    // advance the state machine
    if (pos < num_prompt_tokens - 1) {
      // prefill, force the next token to be the next prompt token
      cur_token = prompt_tokens[pos + 1];
    }
    pos++;

    // print the token as string, decode it with the Tokenizer object
    auto piece_res = tokenizer_->decode(prev_token, cur_token);
    ET_CHECK(piece_res.ok());

    if (token_callback) {
      token_callback(piece_res.get());
    }

    if (shouldStop_) {
      break;
    }

    // data-dependent terminating condition: we have n_eos_ number of EOS
    if (pos >= num_prompt_tokens && cur_token == eos_id_) {
      ET_LOG(Info, "Reached the end of generation");
      break;
    }
  }
  stats_.inference_end_ms = time_in_ms();

  if (pos == seq_len) {
    ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
  }

  stats_.num_prompt_tokens = num_prompt_tokens;
  stats_.num_generated_tokens = pos - num_prompt_tokens;
  printReport(stats_);
  if (stats_callback) {
    stats_callback(stats_);
  }

  return Error::Ok;
}

namespace {
void printReport(const Runner::Stats& stats) {
  printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());

  ET_LOG(
      Info,
      "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
      stats.num_prompt_tokens,
      stats.num_generated_tokens);

  ET_LOG(
      Info,
      "\tModel Load Time:\t\t%f (seconds)",
      ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));
  double inference_time_ms =
      (double)(stats.inference_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_generated_tokens) /
          (double)(stats.inference_end_ms - stats.inference_start_ms) *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
  double prompt_eval_time =
      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
  ET_LOG(
      Info,
      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      (stats.num_prompt_tokens) / prompt_eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  double eval_time =
      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
  ET_LOG(
      Info,
      "\t\tGenerated %" PRIu64
      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
      stats.num_generated_tokens,
      eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
      stats.num_generated_tokens / eval_time *
          stats.SCALING_FACTOR_UNITS_PER_SECOND);

  // Time to first token is measured from the start of inference, excluding
  // model load time.
  ET_LOG(
      Info,
      "\tTime to first generated token:\t%f (seconds)",
      ((double)(stats.first_token_ms - stats.inference_start_ms) /
       stats.SCALING_FACTOR_UNITS_PER_SECOND));

  ET_LOG(
      Info,
      "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
      stats.num_prompt_tokens + stats.num_generated_tokens,
      (double)stats.aggregate_sampling_time_ms /
          stats.SCALING_FACTOR_UNITS_PER_SECOND);
}

std::string statsToJsonString(const Runner::Stats& stats) {
  std::stringstream ss;
  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
     << "\"generated_tokens\":" << stats.num_generated_tokens << ","
     << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
     << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
     << "\"inference_start_ms\":" << stats.inference_start_ms << ","
     << "\"inference_end_ms\":" << stats.inference_end_ms << ","
     << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
     << "\"first_token_ms\":" << stats.first_token_ms << ","
     << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
     << "," << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
     << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
  return ss.str();
}
} // namespace

IoMemMgr::IoMemMgr(MethodMeta method_meta) {
  method_meta_ = std::make_unique<MethodMeta>(method_meta);
  init_io_info();
  compute_total_nbytes();
}

void IoMemMgr::init_io_info() {
  set_tensor_meta();
  for (auto info : io_info_.tensor_info) {
    info->size = info->tensor_meta->nbytes();
    info->rank = info->tensor_meta->sizes().size();
    info->shape.resize(info->rank);
    for (int i = 0; i < info->rank; i++) {
      info->shape[i] =
          static_cast<uint32_t>(info->tensor_meta->sizes().data()[i]);
    }
    info->dtype = info->tensor_meta->scalar_type();
    info->element_size = scalar_type_to_size[info->tensor_meta->scalar_type()];
  }
}

void IoMemMgr::set_tensor_meta() {
  io_info_.input_token.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->input_tensor_meta(0).get());
  io_info_.pos_idx.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->input_tensor_meta(1).get());
  io_info_.atten_mask.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->input_tensor_meta(2).get());

  io_info_.k_caches_read.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->input_tensor_meta(3).get());
  io_info_.k_caches_write.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->output_tensor_meta(1).get());

  io_info_.v_caches_read.tensor_meta = std::make_unique<TensorInfo>(
      method_meta_->input_tensor_meta(method_meta_->num_inputs() - 1).get());
  io_info_.v_caches_write.tensor_meta = std::make_unique<TensorInfo>(
      method_meta_->output_tensor_meta(method_meta_->num_outputs() - 1).get());

  io_info_.logit.tensor_meta =
      std::make_unique<TensorInfo>(method_meta_->output_tensor_meta(0).get());
}

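// A rough sketch of the shared-buffer layout assumed by compute_total_nbytes()
// and init_tensors() below, derived from the pointer arithmetic (N is the
// number of k (or v) cache tensors, i.e. (num_inputs - 3) / 2):
//
//   [input_token][pos_idx][atten_mask][logits]
//   [k_cache_read x N][k shift pad][k_cache_write x N]
//   [v caches: N + 1 slots, where v_write(i) aliases v_read(i + 1)]
//
// The pad and the extra v slot let run_model_step() advance the cache
// pointers one step per generated token without running past the allocation.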
void IoMemMgr::compute_total_nbytes() {
  total_nbytes_ = io_info_.input_token.size + io_info_.pos_idx.size +
      io_info_.atten_mask.size + io_info_.logit.size;
  size_t num_heads = (method_meta_->num_inputs() - 3) / 2;

  // To update the v caches by shifting pointers, they need a buffer of size
  // (max_seq_len_ - 1) * head_dim_, which is equivalent to one more cache.
  size_t num_v_cache = num_heads + 1;
  // To update the k caches by shifting pointers, the k buffer needs an extra
  // (max_seq_len_ - 1) bytes.
  size_t k_buffer = io_info_.k_caches_read.size / io_info_.k_caches_write.size;

  // k_caches_read needs a buffer with the size of head_dim_
  total_nbytes_ += num_heads * io_info_.k_caches_read.size + k_buffer;
  total_nbytes_ += num_heads * io_info_.k_caches_write.size;
  total_nbytes_ += num_v_cache * io_info_.v_caches_read.size;
  // Add one head_dim_-sized slot for the convenience of shifting the pointer
  // from the last unused v cache write
  total_nbytes_ += io_info_.v_caches_write.size;
}

bool IoMemMgr::init_tensors() {
  size_t cur_pos = input_token_pos_;
  pos_idx_pos_ = cur_pos += io_info_.input_token.size;
  atten_mask_pos_ = cur_pos += io_info_.pos_idx.size;
  logit_pos_ = cur_pos += io_info_.atten_mask.size;
  set_input_token_ptr();
  set_pos_idx_ptr();
  set_atten_mask_ptr();
  set_logit_ptr();

  // set start point of kv caches
  cur_pos += io_info_.logit.size;

  size_t num_heads = (method_meta_->num_inputs() - 3) / 2;
  k_caches_read_pos_.resize(num_heads);
  k_caches_write_pos_.resize(num_heads);
  v_caches_read_pos_.resize(num_heads);
  v_caches_write_pos_.resize(num_heads);

  for (int i = 0; i < num_heads; i++) {
    set_k_caches_read(i, cur_pos);
    cur_pos += io_info_.k_caches_read.size;
  }
  // add the size of the k caches buffer
  cur_pos += io_info_.k_caches_read.size / io_info_.k_caches_write.size;
  for (int i = 0; i < num_heads; i++) {
    set_k_caches_write(i, cur_pos);
    cur_pos += io_info_.k_caches_write.size;
  }

  for (int i = 0; i < num_heads; i++) {
    set_v_caches_read(i, cur_pos);
    set_v_caches_write(i, cur_pos + io_info_.v_caches_read.size);
    cur_pos += io_info_.v_caches_read.size;
  }
  // add one extra cache slot as the v caches buffer
  cur_pos += io_info_.v_caches_read.size;
  return cur_pos <= total_nbytes_;
}

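// For every step s in [0, seq_len), pre-register the shifted address of each
// cache with the QNN backend: k read pointers advance by s elements, and v
// read/write pointers advance by s full cache slots. Presumably this lets the
// backend resolve the shared-memory handle when run_model_step() re-points
// the tensors during decode, without re-registering buffers at runtime.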
void IoMemMgr::set_all_shifted_ptrs(size_t seq_len) {
  auto iter_setter = [&](std::vector<size_t>& cache,
                         size_t shift_size,
                         InfoAttrs& tensor_info) {
    for (int i = 0; i < cache.size(); ++i) {
      size_t pos = cache[i] + shift_size;
      CustomMemTensorInfo info = {
          ptr_,
          ptr_ + pos,
          pos,
          tensor_info.size,
          tensor_info.shape.data(),
          tensor_info.rank,
          tensor_info.dtype};
      QnnExecuTorchAddCustomMemTensorInfo(info);
    }
  };
  for (int i = 0; i < seq_len; ++i) {
    iter_setter(
        k_caches_read_pos_,
        i * io_info_.k_caches_read.element_size,
        io_info_.k_caches_read);
    iter_setter(
        v_caches_read_pos_,
        i * io_info_.v_caches_write.size,
        io_info_.v_caches_read);
    iter_setter(
        v_caches_write_pos_,
        i * io_info_.v_caches_write.size,
        io_info_.v_caches_write);
  }
}

void Runner::stop() {
  shouldStop_ = true;
}

Result<MethodMeta> Runner::get_method_meta() {
  return module_->method_meta("forward");
}

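// mem_alloc() builds the IoMemMgr from the method metadata, allocates one
// aligned shared buffer, lays out all IO tensors inside it, and pre-registers
// every shifted cache pointer. The Module is then recreated so that the
// shared buffer (rpc_mem_handle) is registered when the method is reloaded.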
Error Runner::mem_alloc(size_t alignment, size_t seq_len) {
  Result<MethodMeta> method_meta_result = get_method_meta();
  io_mem_mgr_ = IoMemMgr(method_meta_result.get());
  ET_CHECK_MSG(
      io_mem_mgr_.allocate(alignment),
      "IoMemMgr failed to allocate custom memory");

  ET_CHECK_MSG(
      io_mem_mgr_.init_tensors(),
      "IoMemMgr required more bytes than allocated bytes");

  io_mem_mgr_.set_all_shifted_ptrs(seq_len);
  // To register rpc_mem_handle from SharedBuffer,
  // reset and re-init again to trigger the registered function
  module_.reset();
  module_ = std::make_unique<Module>(
      model_path_, Module::LoadMode::MmapUseMlockIgnoreErrors);
  ET_CHECK_MSG(load() == Error::Ok, "Runner failed to load method");

  return Error::Ok;
}

// explicit instantiation of template methods
template int64_t Runner::getMetadataHelper<int64_t>(
    std::string method_name,
    int64_t default_val);
template bool Runner::getMetadataHelper<bool>(
    std::string method_name,
    bool default_val);

} // namespace example