/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#include <string>
#include <thread>
#include <vector>

#include <executorch/runtime/platform/log.h>

#include "LlamaRuntime.h"
#include "Utils.h"

#include "llm_helper/include/rotary_embedding.h"
#include "llm_helper/include/token_embedding.h"

namespace example {

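// Illustrative call sequence (sketch only; see LlamaRuntime.h for the full
// interface). Each Run() call accepts at most GetTokenBatchSize() tokens:
//   LlamaRuntime runtime;
//   runtime.Initialize(modelOptions, modelPaths);
//   void* logits = runtime.Run(promptTokens, /*lastLogits=*/true); // prefill
//   runtime.SwapModel(1);  // switch to single-token generation
//   logits = runtime.Run({nextTokenId}, /*lastLogits=*/true);      // decode
//   runtime.Release();
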
void LlamaRuntime::Initialize(
    const LlamaModelOptions& modelOptions,
    const LlamaModelPaths& modelPaths) {
  mModelOptions = modelOptions;
  const size_t numChunk = modelPaths.gen_model_paths.size();
  ET_CHECK_MSG(numChunk > 0, "No model to initialize");
  // Two caches (key and value) per transformer layer, split across the chunks.
  const size_t numCache = 2 * modelOptions.num_layer / numChunk;

  // Initialize rotary embedding master lookup table
  const size_t rotEmbDim = modelOptions.hidden_size / modelOptions.num_head;
  mRotEmbMasterLut = std::make_unique<llm_helper::RotaryEmbeddingMasterLut>(
      modelOptions.rot_emb_type,
      modelOptions.max_token_length,
      rotEmbDim,
      modelOptions.rot_emb_base);
  mRotEmbMasterLut->generate();

  constexpr size_t numRotEmbInputs = 1;
  const bool usePromptModel = !modelPaths.prompt_model_paths.empty();
  const size_t initBatchSize =
      usePromptModel ? modelOptions.prompt_token_batch_size : 1;
  mTokenBatchSize = initBatchSize;

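  // Build one LlamaModelChunk per decoder partition. Each chunk maps token
  // batch size -> model path, so it can host both the prompt (prefill) model
  // and the single-token generation model when both are provided.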
  for (size_t chunkIdx = 0; chunkIdx < numChunk; chunkIdx++) {
    ModelPathMap modelPathMap;
    auto addModelPath = [&](const auto& modelPaths, const size_t batchSize) {
      if (modelPaths.empty())
        return;
      modelPathMap[batchSize] = modelPaths[chunkIdx];
    };
    addModelPath(
        modelPaths.prompt_model_paths, modelOptions.prompt_token_batch_size);
    addModelPath(modelPaths.gen_model_paths, 1);
    auto llamaChunk = std::make_unique<LlamaModelChunk>(
        modelPathMap,
        modelOptions,
        initBatchSize,
        numCache,
        numRotEmbInputs,
        mRotEmbMasterLut.get());
    mLlamaModelChunks.push_back(std::move(llamaChunk));
  }

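  // Chain the chunks: each chunk after the first reads its input directly
  // from the previous chunk's output buffer, then initialize every chunk.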
  for (size_t i = 0; i < numChunk; i++) {
    auto& modelChunk = mLlamaModelChunks[i];
    if (i > 0) {
      const auto& prevModelChunk = mLlamaModelChunks[i - 1];
      modelChunk->SetInputBuffer(prevModelChunk->GetOutputBuffer());
    }
    modelChunk->Initialize();
    // modelChunk->LogIoSummary();
  }

  // NOTE: Token embedding type here is assumed to follow the model input
  // embedding type.
  mTokenEmbLut = std::make_unique<llm_helper::TokenEmbeddingLut>(
      modelPaths.token_embedding_path,
      modelOptions.model_input_type,
      modelOptions.hidden_size);

  // Link first chunk emb input to token emb lut output
  const auto& tokenEmbInput = mLlamaModelChunks.front()->GetInputBuffer();
  mTokenEmbLut->setOutput(tokenEmbInput.data, tokenEmbInput.nbytes);
}

void LlamaRuntime::Release() {
  for (auto& llamaChunk : mLlamaModelChunks) {
    llamaChunk->Release();
  }
  mLlamaModelChunks.clear();
  mRotEmbMasterLut.reset();
  mTokenEmbLut.reset();
}

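// Swap every chunk to the model variant compiled for `batchSize` (e.g. from
// the prompt prefill batch size to 1 for token-by-token generation).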
void LlamaRuntime::SwapModel(const size_t batchSize) {
  auto hotSwapChunk = [&](const auto chunkIdx) {
    const auto status = mLlamaModelChunks[chunkIdx]->HotSwapModel(batchSize);
    if (!status)
      ET_LOG(Error, "Hot swapping failed on chunk %zu", chunkIdx);
  };

  // Use multi-threading to speed up model swapping
  std::vector<std::thread> threads;
  for (size_t i = 0; i < mLlamaModelChunks.size(); i++)
    threads.emplace_back(hotSwapChunk, i);
  for (size_t i = 0; i < mLlamaModelChunks.size(); i++)
    threads[i].join();

  mTokenBatchSize = batchSize;
}

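// Reset all chunks and restart token counting so a new sequence can begin.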
void LlamaRuntime::Reset() {
  for (auto& modelChunk : mLlamaModelChunks) {
    static_cast<LlamaModelChunk*>(modelChunk.get())->Reset();
  }
  mTokenIndex = 0;
}

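// Run one inference step over `inputTokens`, padding up to the current token
// batch size if needed. Returns a pointer into the output logits buffer; if
// `lastLogits` is true, the pointer is offset to the last valid token's row.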
void* LlamaRuntime::Run(
    const std::vector<uint64_t>& inputTokens,
    const bool lastLogits) {
  const auto& firstLlamaChunk = mLlamaModelChunks.front();
  const auto tokenIndex =
      static_cast<LlamaModelChunk*>(firstLlamaChunk.get())->GetTokenIndex();
  const auto numNewInputToken = inputTokens.size();

  ET_CHECK_MSG(
      numNewInputToken <= mTokenBatchSize,
      "Input token length (%zu) > model token batch size (%zu)",
      numNewInputToken,
      mTokenBatchSize);

  // Handle padding
  auto curInputTokens = inputTokens; // Make a copy
  const size_t padSize = mTokenBatchSize - numNewInputToken;
  constexpr uint64_t padToken = 0;

  // Use left-padding if possible as it has lower overhead than right-padding.
  // Right-padding involves cache shifting, which incurs additional overhead.
  const bool isLeftPadAllowed = (tokenIndex == 0);
  if (padSize > 0) {
    if (isLeftPadAllowed) {
      // Pad left since the cache is still empty.
      curInputTokens.insert(curInputTokens.begin(), padSize, padToken);
    } else {
      // Pad right since the left side of the cache is already occupied, either
      // by a loaded cache or by a previous inference pass.
      curInputTokens.insert(curInputTokens.end(), padSize, padToken);
    }
    ET_LOG(Debug, "Padding size = %zu", padSize);
  }

  // Begin inference flow

  // Lookup token embedding
  mTokenEmbLut->lookupEmbedding(curInputTokens);

  // Decoder chunks
  for (auto& modelChunk : mLlamaModelChunks) {
    auto llamaChunk = static_cast<LlamaModelChunk*>(modelChunk.get());

    // Set padding if needed.
    if (isLeftPadAllowed)
      llamaChunk->SetLeftPadding(padSize);
    else
      llamaChunk->SetRightPadding(padSize);

    // Run model chunk
    llamaChunk->Run();
  }

  // Advance the token index by the valid tokens only, ignoring padding.
  mTokenIndex += inputTokens.size();

178   const auto& finalChunk = mLlamaModelChunks.back();
179   const auto logitsBuffer = finalChunk->GetOutputBuffer();
180   const auto logitsData = reinterpret_cast<char*>(logitsBuffer.data);
181   const auto logitsSize = logitsBuffer.nbytesUsed;
182   size_t offset = 0;
183   const size_t rightPadSize = !isLeftPadAllowed * padSize;
184   if (lastLogits && mTokenBatchSize > 1) {
185     offset =
186         (logitsSize / mTokenBatchSize) * (mTokenBatchSize - 1 - rightPadSize);
187     ET_DCHECK(offset <= logitsSize);
188   }
189   return logitsData + offset;
190 }
191 
GetTokenBatchSize() const192 size_t LlamaRuntime::GetTokenBatchSize() const {
193   return mTokenBatchSize;
194 }
195 
GetTokenIndex() const196 size_t LlamaRuntime::GetTokenIndex() const {
197   return mTokenIndex;
198 }
199 
GetModelOptions() const200 const LlamaModelOptions& LlamaRuntime::GetModelOptions() const {
201   return mModelOptions;
202 }
203 
204 } // namespace example
205