xref: /aosp_15_r20/external/executorch/examples/mediatek/executor_runner/llama_runner/MultiModelLoader.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #include <executorch/runtime/platform/assert.h>
10 #include <executorch/runtime/platform/log.h>
11 
12 #include "MultiModelLoader.h"
13 
14 #include <sstream>
15 #include <string>
16 #include <unordered_map>
17 #include <vector>
18 
19 namespace example {
20 
21 template <typename IdType>
LoadModels()22 void MultiModelLoader<IdType>::LoadModels() {
23   // Init empty model instance map
24   for (const auto& [id, _] : mModelPathMap) {
25     ET_CHECK_MSG(
26         !HasModel(id),
27         "Model is already initialized before calling LoadModels.");
28     mModelInstanceMap[id] = nullptr;
29   }
30   const size_t numModels = mModelPathMap.size();
31   if (!AllowModelsCoexist()) {
32     SelectModel(mDefaultModelId);
33     ET_CHECK_MSG(
34         GetModelInstance() == nullptr,
35         "Model is already initialized before calling LoadModels.");
36     void* instance = CreateModelInstance(mModelPathMap[mDefaultModelId]);
37     SetModelInstance(instance);
38     ET_LOG(
39         Debug,
40         "LoadModels(): Loaded single exclusive model (Total=%zu)",
41         numModels);
42     return;
43   }
44   for (const auto& [id, modelPath] : mModelPathMap) {
45     SelectModel(id);
46     ET_CHECK_MSG(
47         GetModelInstance() == nullptr,
48         "Model is already initialized before calling LoadModels.");
49     void* instance = CreateModelInstance(modelPath);
50     SetModelInstance(instance);
51   }
52   SelectModel(mDefaultModelId); // Select the default instance
53   ET_LOG(Debug, "LoadModels(): Loaded multiple models (Total=%zu)", numModels);
54 }
55 
56 template <typename IdType>
ReleaseModels()57 void MultiModelLoader<IdType>::ReleaseModels() {
58   if (!AllowModelsCoexist()) {
59     // Select the current instance
60     ReleaseModelInstance(GetModelInstance());
61     SetModelInstance(nullptr);
62     return;
63   }
64 
65   for (const auto& [id, _] : mModelInstanceMap) {
66     SelectModel(id);
67     ReleaseModelInstance(GetModelInstance());
68     SetModelInstance(nullptr);
69   }
70 }
71 
72 template <typename IdType>
GetModelInstance() const73 void* MultiModelLoader<IdType>::GetModelInstance() const {
74   ET_DCHECK_MSG(
75       HasModel(mCurrentModelId),
76       "Invalid id: %s",
77       GetIdString(mCurrentModelId).c_str());
78   return mModelInstanceMap.at(mCurrentModelId);
79 }
80 
81 template <typename IdType>
SetModelInstance(void * instance)82 void MultiModelLoader<IdType>::SetModelInstance(void* instance) {
83   ET_DCHECK_MSG(
84       HasModel(mCurrentModelId),
85       "Invalid id: %s",
86       GetIdString(mCurrentModelId).c_str());
87   mModelInstanceMap[mCurrentModelId] = instance;
88 }
89 
90 template <typename IdType>
SetDefaultModelId(const IdType & id)91 void MultiModelLoader<IdType>::SetDefaultModelId(const IdType& id) {
92   mDefaultModelId = id;
93 }
94 
95 template <typename IdType>
GetModelId() const96 IdType MultiModelLoader<IdType>::GetModelId() const {
97   return mCurrentModelId;
98 }
99 
100 template <typename IdType>
SelectModel(const IdType & id)101 void MultiModelLoader<IdType>::SelectModel(const IdType& id) {
102   ET_CHECK_MSG(HasModel(id), "Invalid id: %s", GetIdString(id).c_str());
103 
104   if (mCurrentModelId == id) {
105     return; // Do nothing
106   } else if (AllowModelsCoexist()) {
107     mCurrentModelId = id;
108     return;
109   }
110 
111   // Release current instance if already loaded
112   if (HasModel(mCurrentModelId) && GetModelInstance() != nullptr) {
113     ReleaseModelInstance(GetModelInstance());
114     SetModelInstance(nullptr);
115   }
116 
117   // Load new instance
118   mCurrentModelId = id;
119   void* newInstance = CreateModelInstance(mModelPathMap[id]);
120   SetModelInstance(newInstance);
121 }
122 
123 template <typename IdType>
GetNumModels() const124 size_t MultiModelLoader<IdType>::GetNumModels() const {
125   ET_CHECK_MSG(
126       mModelInstanceMap.size() == mModelPathMap.size(),
127       "Please ensure that LoadModels() is called first.");
128   return mModelInstanceMap.size();
129 }
130 
131 template <typename IdType>
GetModelPath() const132 const std::string& MultiModelLoader<IdType>::GetModelPath() const {
133   ET_CHECK_MSG(
134       HasModel(mCurrentModelId),
135       "Invalid id: %s",
136       GetIdString(mCurrentModelId).c_str());
137   return mModelPathMap.at(mCurrentModelId);
138 }
139 
140 template <typename IdType>
AddModel(const IdType & id,const std::string & modelPath)141 void MultiModelLoader<IdType>::AddModel(
142     const IdType& id,
143     const std::string& modelPath) {
144   if (HasModel(id)) {
145     ET_LOG(
146         Info,
147         "Overlapping model identifier detected. Replacing existing model instance.");
148     auto& oldInstance = mModelInstanceMap[id];
149     if (oldInstance != nullptr)
150       ReleaseModelInstance(oldInstance);
151     oldInstance = nullptr;
152   }
153   mModelPathMap[id] = modelPath;
154 
155   // Create runtime immediately if can coexist
156   mModelInstanceMap[id] = AllowModelsCoexist()
157       ? CreateModelInstance(mModelPathMap[mDefaultModelId])
158       : nullptr;
159 }
160 
161 template <typename IdType>
HasModel(const IdType & id) const162 bool MultiModelLoader<IdType>::HasModel(const IdType& id) const {
163   return mModelInstanceMap.find(id) != mModelInstanceMap.end();
164 }
165 
166 template <typename IdType>
GetIdString(const IdType & id)167 std::string MultiModelLoader<IdType>::GetIdString(const IdType& id) {
168   std::ostringstream ss;
169   ss << id;
170   return ss.str();
171 }
172 
// Explicit instantiation of MultiModelLoader for some integral Id types.
// The member definitions live in this .cpp file, so other translation units
// can presumably only link against these instantiations. This assumes the
// header does not also define the members; add another line here to support
// a new IdType.
template class MultiModelLoader<int>;
template class MultiModelLoader<size_t>;
176 
177 } // namespace example
178