xref: /aosp_15_r20/external/executorch/extension/llm/runner/text_decoder_runner.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Given inputs, run a text decoder in LLM and return the output.
11 #pragma once
12 
13 #include <executorch/extension/llm/sampler/sampler.h>
14 #include <executorch/extension/module/module.h>
15 #include <executorch/extension/tensor/tensor.h>
16 #include <executorch/runtime/platform/compiler.h>
17 #include <functional>
18 
19 namespace executorch {
20 namespace extension {
21 namespace llm {
22 
23 class ET_EXPERIMENTAL TextDecoderRunner {
24  public:
25   TextDecoderRunner(
26       Module* module,
27       bool use_kv_cache,
28       int32_t vocab_size,
29       float temperature);
30 
31   virtual ~TextDecoderRunner() = default;
32 
33   /**
34    * Run LLM text decoder with inputs to generate next token.
35    * @param input The input to the LLM Module.
36    * @param start_pos The starting position in KV cache of the input in the LLM
37    * Module.
38    * @return The output of the LLM Module. This will be a tensor of logits.
39    */
40   virtual ::executorch::runtime::Result<executorch::aten::Tensor> step(
41       TensorPtr& input,
42       TensorPtr& start_pos);
43 
44   /**
45    * Load the Module for text decode purpose.
46    * @return The error code.
47    */
load()48   virtual ::executorch::runtime::Error load() {
49     return module_->load_method("forward");
50   }
51 
52   /**
53    * Check if the required methods in the Module is loaded.
54    * @return True if the Module is loaded, false otherwise.
55    */
is_method_loaded()56   virtual bool is_method_loaded() {
57     return module_->is_method_loaded("forward");
58   }
59 
stop()60   inline void stop() {
61     should_stop_ = true;
62   }
63 
64   /**
65    * Sample the next token from the logits tensor.
66    * @param logits_tensor The logits tensor.
67    * @return The next token.
68    */
logits_to_token(const executorch::aten::Tensor & logits_tensor)69   inline int32_t logits_to_token(
70       const executorch::aten::Tensor& logits_tensor) {
71     int32_t result = 0;
72     ET_SWITCH_THREE_TYPES(
73         Float,
74         Half,
75         BFloat16,
76         logits_tensor.scalar_type(),
77         unused,
78         "logits_to_token",
79         CTYPE,
80         [&]() {
81           // If the logit_tensor rank is 3, the shape is [batch, seq_length,
82           // vocab_size], get the last logits, sample and return. Else the model
83           // outputs the last logit, directly sample and return.
84           auto* logits = logits_tensor.mutable_data_ptr<CTYPE>();
85           if (logits_tensor.dim() == 3) {
86             auto num_tokens = logits_tensor.size(1);
87             auto vocab_size = logits_tensor.size(2);
88             auto* logits_last = logits;
89             logits_last += (num_tokens - 1) * vocab_size;
90             result = sampler_->sample(logits_last);
91           } else {
92             result = sampler_->sample(logits);
93           }
94         });
95     return result;
96   }
97 
98  protected:
99   // TODO: use shared_ptr for module
100   Module* module_;
101   std::unique_ptr<Sampler> sampler_;
102   bool use_kv_cache_;
103   bool should_stop_{false};
104 };
105 
106 } // namespace llm
107 } // namespace extension
108 } // namespace executorch
109 
110 namespace torch {
111 namespace executor {
112 // TODO(T197294990): Remove these deprecated aliases once all users have moved
113 // to the new `::executorch` namespaces.
114 using ::executorch::extension::llm::TextDecoderRunner;
115 } // namespace executor
116 } // namespace torch
117