1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 // Given inputs, run a text decoder in LLM and return the output. 10 11 #pragma once 12 13 #include <executorch/extension/llm/sampler/sampler.h> 14 #include <executorch/extension/module/module.h> 15 #include <executorch/extension/tensor/tensor.h> 16 #include <executorch/runtime/platform/compiler.h> 17 #include <functional> 18 19 namespace executorch { 20 namespace extension { 21 namespace llm { 22 23 class ET_EXPERIMENTAL TextDecoderRunner { 24 public: 25 TextDecoderRunner( 26 Module* module, 27 bool use_kv_cache, 28 int32_t vocab_size, 29 float temperature); 30 31 virtual ~TextDecoderRunner() = default; 32 33 /** 34 * Run LLM text decoder with inputs to generate next token. 35 * @param input The input to the LLM Module. 36 * @param start_pos The starting position in KV cache of the input in the LLM 37 * Module. 38 * @return The output of the LLM Module. This will be a tensor of logits. 39 */ 40 virtual ::executorch::runtime::Result<executorch::aten::Tensor> step( 41 TensorPtr& input, 42 TensorPtr& start_pos); 43 44 /** 45 * Load the Module for text decode purpose. 46 * @return The error code. 47 */ load()48 virtual ::executorch::runtime::Error load() { 49 return module_->load_method("forward"); 50 } 51 52 /** 53 * Check if the required methods in the Module is loaded. 54 * @return True if the Module is loaded, false otherwise. 55 */ is_method_loaded()56 virtual bool is_method_loaded() { 57 return module_->is_method_loaded("forward"); 58 } 59 stop()60 inline void stop() { 61 should_stop_ = true; 62 } 63 64 /** 65 * Sample the next token from the logits tensor. 66 * @param logits_tensor The logits tensor. 67 * @return The next token. 68 */ logits_to_token(const executorch::aten::Tensor & logits_tensor)69 inline int32_t logits_to_token( 70 const executorch::aten::Tensor& logits_tensor) { 71 int32_t result = 0; 72 ET_SWITCH_THREE_TYPES( 73 Float, 74 Half, 75 BFloat16, 76 logits_tensor.scalar_type(), 77 unused, 78 "logits_to_token", 79 CTYPE, 80 [&]() { 81 // If the logit_tensor rank is 3, the shape is [batch, seq_length, 82 // vocab_size], get the last logits, sample and return. Else the model 83 // outputs the last logit, directly sample and return. 84 auto* logits = logits_tensor.mutable_data_ptr<CTYPE>(); 85 if (logits_tensor.dim() == 3) { 86 auto num_tokens = logits_tensor.size(1); 87 auto vocab_size = logits_tensor.size(2); 88 auto* logits_last = logits; 89 logits_last += (num_tokens - 1) * vocab_size; 90 result = sampler_->sample(logits_last); 91 } else { 92 result = sampler_->sample(logits); 93 } 94 }); 95 return result; 96 } 97 98 protected: 99 // TODO: use shared_ptr for module 100 Module* module_; 101 std::unique_ptr<Sampler> sampler_; 102 bool use_kv_cache_; 103 bool should_stop_{false}; 104 }; 105 106 } // namespace llm 107 } // namespace extension 108 } // namespace executorch 109 110 namespace torch { 111 namespace executor { 112 // TODO(T197294990): Remove these deprecated aliases once all users have moved 113 // to the new `::executorch` namespaces. 114 using ::executorch::extension::llm::TextDecoderRunner; 115 } // namespace executor 116 } // namespace torch 117