/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Given a text prompt, encode it using the tokenizer and prefill the KV cache
// of an LLM.
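//
// A minimal usage sketch (hypothetical caller; `runner` and `prompt_tokens`
// are placeholders for illustration, not defined in this file):
//
//   TextPrefiller prefiller(
//       &runner, /*use_kv_cache=*/true, /*enable_parallel_prefill=*/true);
//   int64_t start_pos = 0;
//   auto first_token_res = prefiller.prefill(prompt_tokens, start_pos);
//   // On success, start_pos points past the prompt and
//   // first_token_res.get() is the first generated token.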

#include <executorch/extension/llm/runner/text_prefiller.h>

namespace executorch {
namespace extension {
namespace llm {

TextPrefiller::TextPrefiller(
    TextDecoderRunner* text_decoder_runner,
    bool use_kv_cache,
    bool enable_parallel_prefill)
    : text_decoder_runner_(text_decoder_runner),
      use_kv_cache_(use_kv_cache),
      enable_parallel_prefill_(enable_parallel_prefill) {}

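// Prefill runs in one of two modes: parallel prefill feeds the whole prompt
// to the model in a single forward pass, while sequential prefill feeds one
// token at a time, advancing start_pos after each step so the KV cache is
// filled incrementally. Parallel prefill is typically much faster when the
// backend supports inputs with sequence length greater than 1.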
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
    std::vector<uint64_t>& prompt_tokens,
    int64_t& start_pos) {
  ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be empty");
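  // Lazily load the decoder's method the first time prefill is called.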
  if (!text_decoder_runner_->is_method_loaded()) {
    ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
  }
  // enable_parallel_prefill_ may be set even when the KV cache is not used;
  // when the KV cache is not used, start_pos is ignored.
  int32_t num_prompt_tokens = prompt_tokens.size();

  // On return, holds the first generated token, sampled from the logits of
  // the last prefill step. The sequential path also reuses it to stage each
  // prompt token.
  uint64_t cur_token;
  if (enable_parallel_prefill_ || !use_kv_cache_) {
    // initialize tensor wrappers
    auto tokens = from_blob(
        prompt_tokens.data(),
        {1, num_prompt_tokens},
        exec_aten::ScalarType::Long);

    auto start_pos_tensor =
        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);

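    // Single forward pass over the whole prompt: the model attends over all
    // positions at once and, when the KV cache is enabled, populates the
    // cache for positions [start_pos, start_pos + num_prompt_tokens).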
    auto outputs_res = text_decoder_runner_->step(tokens, start_pos_tensor);

    ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
    ET_LOG(
        Info, "Prefill token result numel(): %zu", outputs_res.get().numel());

    start_pos += num_prompt_tokens;
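    // Sample the first generated token from the logits of the last prompt
    // position.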
    cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
  } else { // sequential prefill
    int64_t pos = 0; // position in the sequence
    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
    cur_token = prompt_tokens[0];

    // initialize tensor wrappers
    auto tokens = from_blob(&cur_token, {1, 1}, exec_aten::ScalarType::Long);

    auto start_pos_tensor =
        from_blob(&start_pos, {1}, exec_aten::ScalarType::Long);
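    // Both tensors are zero-copy views over cur_token and start_pos, so
    // updating those variables below is enough to feed new data into the
    // next step() call; the wrappers never need to be rebuilt.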

    // Run the first token and get back the logits tensor. Assume the first
    // token is BOS, so no token callback is issued for it.
    auto logits_tensor =
        ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));

    pos += 1; // start the loop from index 1
    start_pos += 1;

    while (pos < num_prompt_tokens) {
      // Feed the next prompt token to the model.
      // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
      cur_token = prompt_tokens[pos];

      logits_tensor =
          ET_UNWRAP(text_decoder_runner_->step(tokens, start_pos_tensor));

      pos++;
      start_pos++;
    }

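    // All prompt tokens have been consumed; sample the first generated token
    // from the logits of the final step.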
    cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
  }
  return cur_token;
}

} // namespace llm
} // namespace extension
} // namespace executorch