1 /*
2 * Copyright (c) 2024 MediaTek Inc.
3 *
4 * Licensed under the BSD License (the "License"); you may not use this file
5 * except in compliance with the License. See the license file in the root
6 * directory of this source tree for more details.
7 */
8
9 #include <executorch/runtime/platform/assert.h>
10 #include <executorch/runtime/platform/log.h>
11
12 #include "MultiModelLoader.h"
13
14 #include <sstream>
15 #include <string>
16 #include <unordered_map>
17 #include <vector>
18
19 namespace example {
20
21 template <typename IdType>
LoadModels()22 void MultiModelLoader<IdType>::LoadModels() {
23 // Init empty model instance map
24 for (const auto& [id, _] : mModelPathMap) {
25 ET_CHECK_MSG(
26 !HasModel(id),
27 "Model is already initialized before calling LoadModels.");
28 mModelInstanceMap[id] = nullptr;
29 }
30 const size_t numModels = mModelPathMap.size();
31 if (!AllowModelsCoexist()) {
32 SelectModel(mDefaultModelId);
33 ET_CHECK_MSG(
34 GetModelInstance() == nullptr,
35 "Model is already initialized before calling LoadModels.");
36 void* instance = CreateModelInstance(mModelPathMap[mDefaultModelId]);
37 SetModelInstance(instance);
38 ET_LOG(
39 Debug,
40 "LoadModels(): Loaded single exclusive model (Total=%zu)",
41 numModels);
42 return;
43 }
44 for (const auto& [id, modelPath] : mModelPathMap) {
45 SelectModel(id);
46 ET_CHECK_MSG(
47 GetModelInstance() == nullptr,
48 "Model is already initialized before calling LoadModels.");
49 void* instance = CreateModelInstance(modelPath);
50 SetModelInstance(instance);
51 }
52 SelectModel(mDefaultModelId); // Select the default instance
53 ET_LOG(Debug, "LoadModels(): Loaded multiple models (Total=%zu)", numModels);
54 }
55
56 template <typename IdType>
ReleaseModels()57 void MultiModelLoader<IdType>::ReleaseModels() {
58 if (!AllowModelsCoexist()) {
59 // Select the current instance
60 ReleaseModelInstance(GetModelInstance());
61 SetModelInstance(nullptr);
62 return;
63 }
64
65 for (const auto& [id, _] : mModelInstanceMap) {
66 SelectModel(id);
67 ReleaseModelInstance(GetModelInstance());
68 SetModelInstance(nullptr);
69 }
70 }
71
72 template <typename IdType>
GetModelInstance() const73 void* MultiModelLoader<IdType>::GetModelInstance() const {
74 ET_DCHECK_MSG(
75 HasModel(mCurrentModelId),
76 "Invalid id: %s",
77 GetIdString(mCurrentModelId).c_str());
78 return mModelInstanceMap.at(mCurrentModelId);
79 }
80
81 template <typename IdType>
SetModelInstance(void * instance)82 void MultiModelLoader<IdType>::SetModelInstance(void* instance) {
83 ET_DCHECK_MSG(
84 HasModel(mCurrentModelId),
85 "Invalid id: %s",
86 GetIdString(mCurrentModelId).c_str());
87 mModelInstanceMap[mCurrentModelId] = instance;
88 }
89
90 template <typename IdType>
SetDefaultModelId(const IdType & id)91 void MultiModelLoader<IdType>::SetDefaultModelId(const IdType& id) {
92 mDefaultModelId = id;
93 }
94
95 template <typename IdType>
GetModelId() const96 IdType MultiModelLoader<IdType>::GetModelId() const {
97 return mCurrentModelId;
98 }
99
100 template <typename IdType>
SelectModel(const IdType & id)101 void MultiModelLoader<IdType>::SelectModel(const IdType& id) {
102 ET_CHECK_MSG(HasModel(id), "Invalid id: %s", GetIdString(id).c_str());
103
104 if (mCurrentModelId == id) {
105 return; // Do nothing
106 } else if (AllowModelsCoexist()) {
107 mCurrentModelId = id;
108 return;
109 }
110
111 // Release current instance if already loaded
112 if (HasModel(mCurrentModelId) && GetModelInstance() != nullptr) {
113 ReleaseModelInstance(GetModelInstance());
114 SetModelInstance(nullptr);
115 }
116
117 // Load new instance
118 mCurrentModelId = id;
119 void* newInstance = CreateModelInstance(mModelPathMap[id]);
120 SetModelInstance(newInstance);
121 }
122
123 template <typename IdType>
GetNumModels() const124 size_t MultiModelLoader<IdType>::GetNumModels() const {
125 ET_CHECK_MSG(
126 mModelInstanceMap.size() == mModelPathMap.size(),
127 "Please ensure that LoadModels() is called first.");
128 return mModelInstanceMap.size();
129 }
130
131 template <typename IdType>
GetModelPath() const132 const std::string& MultiModelLoader<IdType>::GetModelPath() const {
133 ET_CHECK_MSG(
134 HasModel(mCurrentModelId),
135 "Invalid id: %s",
136 GetIdString(mCurrentModelId).c_str());
137 return mModelPathMap.at(mCurrentModelId);
138 }
139
140 template <typename IdType>
AddModel(const IdType & id,const std::string & modelPath)141 void MultiModelLoader<IdType>::AddModel(
142 const IdType& id,
143 const std::string& modelPath) {
144 if (HasModel(id)) {
145 ET_LOG(
146 Info,
147 "Overlapping model identifier detected. Replacing existing model instance.");
148 auto& oldInstance = mModelInstanceMap[id];
149 if (oldInstance != nullptr)
150 ReleaseModelInstance(oldInstance);
151 oldInstance = nullptr;
152 }
153 mModelPathMap[id] = modelPath;
154
155 // Create runtime immediately if can coexist
156 mModelInstanceMap[id] = AllowModelsCoexist()
157 ? CreateModelInstance(mModelPathMap[mDefaultModelId])
158 : nullptr;
159 }
160
161 template <typename IdType>
HasModel(const IdType & id) const162 bool MultiModelLoader<IdType>::HasModel(const IdType& id) const {
163 return mModelInstanceMap.find(id) != mModelInstanceMap.end();
164 }
165
166 template <typename IdType>
GetIdString(const IdType & id)167 std::string MultiModelLoader<IdType>::GetIdString(const IdType& id) {
168 std::ostringstream ss;
169 ss << id;
170 return ss.str();
171 }
172
173 // Explicit instantiation of MultiModelLoader for some integral Id types
174 template class MultiModelLoader<int>;
175 template class MultiModelLoader<size_t>;
176
177 } // namespace example
178