xref: /aosp_15_r20/external/executorch/examples/mediatek/executor_runner/llama_runner/ModelChunk.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) 2024 MediaTek Inc.
3  *
4  * Licensed under the BSD License (the "License"); you may not use this file
5  * except in compliance with the License. See the license file in the root
6  * directory of this source tree for more details.
7  */
8 
9 #pragma once
10 
#include <cstddef>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

#include <executorch/runtime/executor/method.h>

#include "MultiModelLoader.h"
18 
19 namespace example {
20 
// Describes one model IO buffer: its address, allocated capacity, and the
// number of bytes the model actually consumes/produces.
struct BufferInfo {
  void* data{nullptr};    // start of the buffer; not owned by this struct
  size_t nbytes{0};       // allocated capacity in bytes
  size_t nbytesUsed{0};   // bytes actually used by the model (<= nbytes)
};
26 
// Loader keyed by token batch size: each key selects the compiled model
// variant built for that many input tokens.
using MultiTokenSizeModelLoader = MultiModelLoader<size_t>;
// Map of {token batch size -> model file path}.
using ModelPathMap = MultiTokenSizeModelLoader::ModelPathMap;
29 
30 class ModelChunk : protected MultiTokenSizeModelLoader {
31  public:
32   explicit ModelChunk(
33       const ModelPathMap& modelPathMap,
34       const size_t initBatchSize = 1)
MultiTokenSizeModelLoader(modelPathMap,initBatchSize)35       : MultiTokenSizeModelLoader(modelPathMap, initBatchSize),
36         mTokenBatchSize(initBatchSize) {}
37 
38   explicit ModelChunk(
39       const std::string& modelPath,
40       const size_t initBatchSize = 1)
MultiTokenSizeModelLoader(modelPath,initBatchSize)41       : MultiTokenSizeModelLoader(modelPath, initBatchSize),
42         mTokenBatchSize(initBatchSize) {}
43 
~ModelChunk()44   ~ModelChunk() {}
45 
46   virtual void Initialize();
47 
48   virtual void Release();
49 
50   virtual void Run();
51 
52   virtual bool HotSwapModel(const size_t tokenBatchSize);
53 
54   void
55   SetInputBuffer(const void* data, const size_t size, const size_t index = 0);
56 
57   void SetInputBuffer(const BufferInfo& bufferInfo, const size_t index = 0);
58 
59   BufferInfo GetInputBuffer(const size_t index = 0);
60 
61   BufferInfo GetOutputBuffer(const size_t index = 0);
62 
63   void LogIoSummary();
64 
65  protected:
66   // Check if model chunk has been initialized
67   bool Initialized();
68 
69   // Get model IO info after model has been loaded
70   void GetModelIoInfo();
71 
72   // Update IO sizes actually used by the model
73   void UpdateModelIoInfo();
74 
75   // Model IO linkage to share the same buffer among a pair of linked input and
76   // output
77   void LinkModelIO(const size_t inputIndex, const size_t outputIndex);
78 
79   // Return the input index that the given output should share the same buffer
80   std::optional<size_t> GetLinkedInputIndex(const size_t outputIndex) const;
81 
82   // Assign input buffers to model inputs using backend APIs
83   void SetBackendInputs();
84 
85   // Assign output buffers to model outputs using backend APIs
86   void SetBackendOutputs();
87 
88   // Allocate buffers for model IOs
89   void AllocateIoBuffers();
90 
91   // Release allocated buffers for model IOs
92   void ReleaseIoBuffers();
93 
94   executorch::runtime::Method& GetModelMethod();
95 
96  private:
97   // Override the virtual functions
98   void* CreateModelInstance(const std::string& modelPath) override;
99 
100   void ReleaseModelInstance(void* modelInstance) override;
101 
102  private:
AllowModelsCoexist()103   bool AllowModelsCoexist() const override {
104     return false;
105   }
106 
107  protected:
108   // State of initialization
109   bool mIsInitialized = false;
110 
111   // The number of input tokens the the fixed-shape model takes
112   size_t mTokenBatchSize = 1;
113 
114   // Input/Output buffers info
115   std::vector<BufferInfo> mInputBufferInfos;
116   std::vector<BufferInfo> mOutputBufferInfos;
117 
118   // Model IO linkage, where linked IO will share the same buffer
119   std::unordered_map<size_t, size_t> mModelOutToInIndexLinks;
120 };
121 
122 } // namespace example
123