/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Given a text prompt, encode it using the tokenizer and prefill the KV cache
// of an LLM.

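// A minimal usage sketch (illustrative only; `decoder_runner` and the
// tokenized prompt are assumptions, not part of this file):
//
//   TextPrefiller prefiller(
//       &decoder_runner,
//       /*use_kv_cache=*/true,
//       /*enable_parallel_prefill=*/true);
//   std::vector<uint64_t> prompt_tokens = /* tokenizer output */;
//   int64_t start_pos = 0;
//   auto first_token = prefiller.prefill(prompt_tokens, start_pos);
//   // On success, start_pos has advanced by the number of prompt tokens.
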
#include <executorch/extension/llm/runner/text_prefiller.h>

namespace executorch {
namespace extension {
namespace llm {

TextPrefiller::TextPrefiller(
    TextDecoderRunner* text_decoder_runner,
    bool use_kv_cache,
    bool enable_parallel_prefill)
    : text_decoder_runner_(text_decoder_runner),
      use_kv_cache_(use_kv_cache),
      enable_parallel_prefill_(enable_parallel_prefill) {}

::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
    std::vector<uint64_t>& prompt_tokens,
    int64_t& start_pos) {
  ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be empty");
  if (!text_decoder_runner_->is_method_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
  }
  // enable_parallel_prefill_ may be set even when the KV cache is not used.
  // When the KV cache is not used, start_pos is ignored.
  int32_t num_prompt_tokens = prompt_tokens.size();

  // Holds the token being processed; on return, the first generated token.
  uint64_t cur_token;
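
  // Two prefill strategies follow: the parallel path runs the whole prompt
  // through the model in a single step() call; the sequential path feeds one
  // token at a time, filling the KV cache position by position.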
  if (enable_parallel_prefill_ || !use_kv_cache_) {
    // Initialize tensor wrappers over the full prompt so that one step() call
    // processes all prompt tokens at once.
    auto tokens = from_blob(
        prompt_tokens.data(),
        {1, num_prompt_tokens},
        exec_aten::ScalarType::Long);

    auto start_pos_tensor =
        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);

    auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);

    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
    ET_LOG(
        Info, "Prefill token result numel(): %zu", outputs_res.get().numel());

    start_pos += num_prompt_tokens;
    cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
  } else { // sequential prefill
    int64_t pos = 0; // position in the sequence
    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
    cur_token = prompt_tokens[0];

    // Initialize tensor wrappers. Both wrap the addresses of local variables,
    // so writes to cur_token and start_pos below are visible to every
    // subsequent step() call without re-creating the tensors.
    auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);

    auto start_pos_tensor =
        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);

    // Run the first token and get back the logits tensor. The first token is
    // assumed to be BOS, so no callback is issued for it.
    auto logits_tensor =
        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));

    pos += 1; // start the loop from index 1
    start_pos += 1;

    while (pos < num_prompt_tokens) {
      // Run the model on the next prompt token. Updating cur_token in place
      // is enough: `tokens` wraps its address, so step() sees the new value.
      // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
      cur_token = prompt_tokens[pos];

      logits_tensor =
          ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));

      pos++;
      start_pos++;
    }

    cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
  }
  return cur_token;
}
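
// Aside: the sequential path relies on from_blob() creating a non-owning view
// over existing memory. A tiny illustrative sketch of that behavior:
//
//   uint64_t t = 5;
//   auto view = from_blob(&t, {1, 1}, exec_aten::ScalarType::Long);
//   t = 7; // `view` now reads 7; no copy was made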

} // namespace llm
} // namespace extension
} // namespace executorch