xref: /aosp_15_r20/external/executorch/examples/mediatek/executor_runner/llama_runner/MultiModelLoader.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #pragma once
10 
11 #include <string>
12 #include <unordered_map>
13 #include <vector>
14 
15 namespace example {
16 
// Base class for loaders that manage one or more models, each identified by an
// IdType key and backed by a model file path. Subclasses supply the actual
// model creation/release logic; model instances are handled as opaque void*
// handles.
//
// NOTE(review): the destructor does not release live instances. Note that
// ReleaseModelInstance() is pure virtual, so it could not be dispatched to the
// subclass from the base destructor anyway. Presumably owners call
// ReleaseModels() before destruction — confirm.
// NOTE(review): copying a loader duplicates the raw instance handles held in
// mModelInstanceMap, which could lead to a double release. Consider deleting
// the copy operations once callers are confirmed not to copy loaders.
template <typename IdType = size_t>
class MultiModelLoader {
 public:
  // Maps each model id to the file path of that model.
  using ModelPathMap = std::unordered_map<IdType, std::string>;
  // Maps each model id to its model instance handle.
  using ModelInstanceMap = std::unordered_map<IdType, void*>;

  // Construct a loader for a set of models.
  //   modelPathMap:   id -> model path for every model this loader manages.
  //   defaultModelId: id used as both the default and the initially current
  //                   model; value-initialized (`{}`) when omitted.
  explicit MultiModelLoader(
      const ModelPathMap& modelPathMap,
      const IdType defaultModelId = {})
      : mModelPathMap(modelPathMap),
        mDefaultModelId(defaultModelId),
        mCurrentModelId(defaultModelId) {}

  // Construct a loader for a single model, registered under defaultModelId,
  // which also becomes the default and initially current model.
  explicit MultiModelLoader(
      const std::string& modelPath,
      const IdType defaultModelId = {})
      : mModelPathMap({{defaultModelId, modelPath}}),
        mDefaultModelId(defaultModelId),
        mCurrentModelId(defaultModelId) {}

  // Virtual so that subclasses can be destroyed through a base pointer.
  virtual ~MultiModelLoader() = default;

 protected:
  // Initialize all models if they can coexist, otherwise initialize the default
  // model.
  void LoadModels();

  // Release all active model instances.
  void ReleaseModels();

  // Get the current model instance.
  void* GetModelInstance() const;

  // Set the current model instance.
  void SetModelInstance(void* modelInstance);

  // Set the default active model after LoadModels() has been called.
  void SetDefaultModelId(const IdType& id);

  // Get the id of the current model instance.
  IdType GetModelId() const;

  // Select the model of given id to be active.
  void SelectModel(const IdType& id);

  // Get total number of models.
  size_t GetNumModels() const;

  // Get the model path of the current active model.
  const std::string& GetModelPath() const;

  // Add a new model (id -> modelPath) after initialization.
  void AddModel(const IdType& id, const std::string& modelPath);

  // Query whether a model with the given id is known to this loader.
  // (Defined out of line; presumably checks the model path map — confirm.)
  bool HasModel(const IdType& id) const;

  // Convert a model id to a string (defined out of line; presumably used for
  // logging/diagnostics — confirm).
  static std::string GetIdString(const IdType& id);

 private:
  // Create and returns a model instance given a model path. To be implemented
  // by subclass.
  virtual void* CreateModelInstance(const std::string& modelPath) = 0;

  // Release a model instance. To be implemented by subclass.
  virtual void ReleaseModelInstance(void* modelInstance) = 0;

  // Determine whether multiple models are allowed to be alive concurrently.
  // Defaults to false; subclasses may override.
  virtual bool AllowModelsCoexist() const {
    return false;
  }

  ModelPathMap mModelPathMap;
  ModelInstanceMap mModelInstanceMap;
  // Value-initialized (instead of `= 0`) so these defaults are valid for any
  // IdType, consistent with the `= {}` defaults of the constructors.
  IdType mDefaultModelId{};
  IdType mCurrentModelId{};
};
94 
95 } // namespace example
96