1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 // Given inputs, run a text decoder and return logits.
10
11 #include <executorch/extension/llm/runner/text_decoder_runner.h>
12
13 #include <ctime>
14
15 #include <executorch/extension/llm/runner/stats.h>
16
17 namespace executorch {
18 namespace extension {
19 namespace llm {
20
21 // NOTE: we observed ~2x loading performance increase on iPhone 15
22 // and a ~5% improvement on Galaxy S22 by switching to
23 // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
TextDecoderRunner(Module * module,bool use_kv_cache,int32_t vocab_size,float temperature)24 TextDecoderRunner::TextDecoderRunner(
25 Module* module,
26 bool use_kv_cache,
27 int32_t vocab_size,
28 float temperature)
29 : module_(module),
30 sampler_(std::make_unique<Sampler>(
31 vocab_size,
32 temperature,
33 kTopp,
34 static_cast<unsigned long long>(std::time(nullptr)))),
35 use_kv_cache_(use_kv_cache) {}
36
37 // This function is functional, meaning it shouldn't modify any state of the
38 // input. It should be safe to call multiple times with the same inputs. The
39 // outer loop (call site) is responsible for managing state.
step(TensorPtr & tokens,TensorPtr & start_pos)40 ::executorch::runtime::Result<exec_aten::Tensor> TextDecoderRunner::step(
41 TensorPtr& tokens,
42 TensorPtr& start_pos) {
43 // ET_LOG(Info, "Input token %" PRIu64, input_token);
44 if (use_kv_cache_) {
45 auto outputs_res = module_->forward({tokens, start_pos});
46 ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
47 ET_CHECK_MSG(
48 outputs_res.get().size() == 1,
49 "More then one output returned from executing LLM.");
50 ET_CHECK_MSG(
51 outputs_res.get()[0].isTensor(),
52 "Non Tensor Output returned from executing LLM");
53
54 // Return the logits tensor
55 return outputs_res.get()[0].toTensor();
56 } else { // no kv cache
57 (void)start_pos; // unused
58
59 auto outputs_res = module_->forward(tokens);
60 ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
61 ET_CHECK_MSG(
62 outputs_res.get().size() == 1,
63 "More then one output returned from executing LLM.");
64 ET_CHECK_MSG(
65 outputs_res.get()[0].isTensor(),
66 "Non Tensor Output returned from executing LLM");
67
68 // Return the logits tensor
69 return outputs_res.get()[0].toTensor();
70 }
71 }
72
73 } // namespace llm
74 } // namespace extension
75 } // namespace executorch
76