xref: /aosp_15_r20/external/executorch/extension/llm/runner/text_decoder_runner.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 // Given inputs, run a text decoder and return logits.
10 
11 #include <executorch/extension/llm/runner/text_decoder_runner.h>
12 
13 #include <ctime>
14 
15 #include <executorch/extension/llm/runner/stats.h>
16 
17 namespace executorch {
18 namespace extension {
19 namespace llm {
20 
21 // NOTE: we observed ~2x loading performance increase on iPhone 15
22 // and a ~5% improvement on Galaxy S22 by switching to
23 // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
TextDecoderRunner(Module * module,bool use_kv_cache,int32_t vocab_size,float temperature)24 TextDecoderRunner::TextDecoderRunner(
25     Module* module,
26     bool use_kv_cache,
27     int32_t vocab_size,
28     float temperature)
29     : module_(module),
30       sampler_(std::make_unique<Sampler>(
31           vocab_size,
32           temperature,
33           kTopp,
34           static_cast<unsigned long long>(std::time(nullptr)))),
35       use_kv_cache_(use_kv_cache) {}
36 
37 // This function is functional, meaning it shouldn't modify any state of the
38 // input. It should be safe to call multiple times with the same inputs. The
39 // outer loop (call site) is responsible for managing state.
step(TensorPtr & tokens,TensorPtr & start_pos)40 ::executorch::runtime::Result<exec_aten::Tensor> TextDecoderRunner::step(
41     TensorPtr& tokens,
42     TensorPtr& start_pos) {
43   // ET_LOG(Info, "Input token %" PRIu64, input_token);
44   if (use_kv_cache_) {
45     auto outputs_res = module_->forward({tokens, start_pos});
46     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
47     ET_CHECK_MSG(
48         outputs_res.get().size() == 1,
49         "More then one output returned from executing LLM.");
50     ET_CHECK_MSG(
51         outputs_res.get()[0].isTensor(),
52         "Non Tensor Output returned from executing LLM");
53 
54     // Return the logits tensor
55     return outputs_res.get()[0].toTensor();
56   } else { // no kv cache
57     (void)start_pos; // unused
58 
59     auto outputs_res = module_->forward(tokens);
60     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
61     ET_CHECK_MSG(
62         outputs_res.get().size() == 1,
63         "More then one output returned from executing LLM.");
64     ET_CHECK_MSG(
65         outputs_res.get()[0].isTensor(),
66         "Non Tensor Output returned from executing LLM");
67 
68     // Return the logits tensor
69     return outputs_res.get()[0].toTensor();
70   }
71 }
72 
73 } // namespace llm
74 } // namespace extension
75 } // namespace executorch
76