/*
 * Copyright (c) 2024 MediaTek Inc.
 *
 * Licensed under the BSD License (the "License"); you may not use this file
 * except in compliance with the License. See the license file in the root
 * directory of this source tree for more details.
 */

#include <string>
#include <thread>
#include <vector>

#include <executorch/runtime/platform/log.h>

#include "LlamaRuntime.h"
#include "Utils.h"

#include "llm_helper/include/rotary_embedding.h"
#include "llm_helper/include/token_embedding.h"

namespace example {

void LlamaRuntime::Initialize(
    const LlamaModelOptions& modelOptions,
    const LlamaModelPaths& modelPaths) {
  mModelOptions = modelOptions;
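  // One chunk per generation-model path; each chunk carries the K and V
  // caches (2 per layer) for its share of the layers.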
  const size_t numChunk = modelPaths.gen_model_paths.size();
  ET_CHECK_MSG(numChunk > 0, "No model to initialize");
  const size_t numCache = 2 * modelOptions.num_layer / numChunk;

  // Initialize rotary embedding master lookup table
  const size_t rotEmbDim = modelOptions.hidden_size / modelOptions.num_head;
  mRotEmbMasterLut = std::make_unique<llm_helper::RotaryEmbeddingMasterLut>(
      modelOptions.rot_emb_type,
      modelOptions.max_token_length,
      rotEmbDim,
      modelOptions.rot_emb_base);
  mRotEmbMasterLut->generate();

  constexpr size_t numRotEmbInputs = 1;
  const bool usePromptModel = !modelPaths.prompt_model_paths.empty();
  const size_t initBatchSize =
      usePromptModel ? modelOptions.prompt_token_batch_size : 1;
  mTokenBatchSize = initBatchSize;

  for (size_t chunkIdx = 0; chunkIdx < numChunk; chunkIdx++) {
    ModelPathMap modelPathMap;
    auto addModelPath = [&](const auto& modelPaths, const size_t batchSize) {
      if (modelPaths.empty())
        return;
      modelPathMap[batchSize] = modelPaths[chunkIdx];
    };
    addModelPath(
        modelPaths.prompt_model_paths, modelOptions.prompt_token_batch_size);
    addModelPath(modelPaths.gen_model_paths, 1);
    auto llamaChunk = std::make_unique<LlamaModelChunk>(
        modelPathMap,
        modelOptions,
        initBatchSize,
        numCache,
        numRotEmbInputs,
        mRotEmbMasterLut.get());
    mLlamaModelChunks.push_back(std::move(llamaChunk));
  }

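  // Chain the chunks so that each chunk's input buffer is the previous
  // chunk's output buffer, then initialize each chunk.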
  for (size_t i = 0; i < numChunk; i++) {
    auto& modelChunk = mLlamaModelChunks[i];
    if (i > 0) {
      const auto& prevModelChunk = mLlamaModelChunks[i - 1];
      modelChunk->SetInputBuffer(prevModelChunk->GetOutputBuffer());
    }
    modelChunk->Initialize();
    // modelChunk->LogIoSummary();
  }

  // NOTE: Token embedding type here is assumed to follow the model input
  // embedding type.
  mTokenEmbLut = std::make_unique<llm_helper::TokenEmbeddingLut>(
      modelPaths.token_embedding_path,
      modelOptions.model_input_type,
      modelOptions.hidden_size);

  // Link first chunk emb input to token emb lut output
  const auto& tokenEmbInput = mLlamaModelChunks.front()->GetInputBuffer();
  mTokenEmbLut->setOutput(tokenEmbInput.data, tokenEmbInput.nbytes);
}

void LlamaRuntime::Release() {
  for (auto& llamaChunk : mLlamaModelChunks) {
    llamaChunk->Release();
  }
  mLlamaModelChunks.clear();
  mRotEmbMasterLut.reset();
  mTokenEmbLut.reset();
}

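// Swap every chunk to the model compiled for the requested token batch size,
// e.g. the prompt batch size or 1 for per-token generation.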
void LlamaRuntime::SwapModel(const size_t batchSize) {
  auto hotSwapChunk = [&](const auto chunkIdx) {
    const auto status = mLlamaModelChunks[chunkIdx]->HotSwapModel(batchSize);
    if (!status)
      ET_LOG(Error, "Hot swapping failed on chunk %zu", chunkIdx);
  };

  // Use multi-threading to speed up model swapping
  std::vector<std::thread> threads;
  for (size_t i = 0; i < mLlamaModelChunks.size(); i++)
    threads.emplace_back(hotSwapChunk, i);
  for (size_t i = 0; i < mLlamaModelChunks.size(); i++)
    threads[i].join();

  mTokenBatchSize = batchSize;
}

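// Reset every chunk and restart token counting from zero.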
void LlamaRuntime::Reset() {
  for (auto& modelChunk : mLlamaModelChunks) {
    static_cast<LlamaModelChunk*>(modelChunk.get())->Reset();
  }
  mTokenIndex = 0;
}

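// Run one inference pass over inputTokens and return a pointer into the final
// chunk's logits buffer. If lastLogits is true, the pointer is offset to the
// logits of the last valid (non-padded) token.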
void* LlamaRuntime::Run(
    const std::vector<uint64_t>& inputTokens,
    const bool lastLogits) {
  const auto& firstLlamaChunk = mLlamaModelChunks.front();
  const auto tokenIndex =
      static_cast<LlamaModelChunk*>(firstLlamaChunk.get())->GetTokenIndex();
  const auto numNewInputToken = inputTokens.size();

  ET_CHECK_MSG(
      numNewInputToken <= mTokenBatchSize,
      "Input token length (%zu) > model token batch size (%zu)",
      numNewInputToken,
      mTokenBatchSize);

  // Handle padding
  auto curInputTokens = inputTokens; // Make a copy
  const size_t padSize = mTokenBatchSize - numNewInputToken;
  constexpr uint64_t padToken = 0;

  // Use left-padding if possible as it has lower overhead than right-padding.
  // Right-padding involves cache shifting, which incurs additional overhead.
  const bool isLeftPadAllowed = (tokenIndex == 0);
  if (padSize > 0) {
    if (isLeftPadAllowed) {
      // Pad left since the cache is still empty.
      curInputTokens.insert(curInputTokens.begin(), padSize, padToken);
    } else {
      // Pad right since the left side of the cache is already occupied, either
      // by a loaded cache or by a previous inference pass.
      curInputTokens.insert(curInputTokens.end(), padSize, padToken);
    }
    ET_LOG(Debug, "Padding size = %zu", padSize);
  }

  // Begin inference flow

  // Look up token embeddings
  mTokenEmbLut->lookupEmbedding(curInputTokens);

  // Run decoder chunks
  for (auto& modelChunk : mLlamaModelChunks) {
    auto llamaChunk = static_cast<LlamaModelChunk*>(modelChunk.get());

    // Set padding if needed.
    if (isLeftPadAllowed)
      llamaChunk->SetLeftPadding(padSize);
    else
      llamaChunk->SetRightPadding(padSize);

    // Run model chunk
    llamaChunk->Run();
  }

  // Only count valid tokens, ignoring padding
  mTokenIndex += inputTokens.size();

  // Return logits
  const auto& finalChunk = mLlamaModelChunks.back();
  const auto logitsBuffer = finalChunk->GetOutputBuffer();
  const auto logitsData = reinterpret_cast<char*>(logitsBuffer.data);
  const auto logitsSize = logitsBuffer.nbytesUsed;
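  // When only the last token's logits are needed, skip the earlier rows of
  // the logits buffer, accounting for any right-padding in this batch.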
  size_t offset = 0;
  const size_t rightPadSize = isLeftPadAllowed ? 0 : padSize;
  if (lastLogits && mTokenBatchSize > 1) {
    offset =
        (logitsSize / mTokenBatchSize) * (mTokenBatchSize - 1 - rightPadSize);
    ET_DCHECK(offset <= logitsSize);
  }
  return logitsData + offset;
}

size_t LlamaRuntime::GetTokenBatchSize() const {
  return mTokenBatchSize;
}

size_t LlamaRuntime::GetTokenIndex() const {
  return mTokenIndex;
}

const LlamaModelOptions& LlamaRuntime::GetModelOptions() const {
  return mModelOptions;
}

} // namespace example