1 /* 2 * Copyright (c) 2024 MediaTek Inc. 3 * 4 * Licensed under the BSD License (the "License"); you may not use this file 5 * except in compliance with the License. See the license file in the root 6 * directory of this source tree for more details. 7 */ 8 9 #pragma once 10 11 #include <optional> 12 #include <string> 13 #include <vector> 14 15 #include <executorch/runtime/executor/method.h> 16 17 #include "MultiModelLoader.h" 18 19 namespace example { 20 21 struct BufferInfo { 22 void* data = nullptr; 23 size_t nbytes = 0; 24 size_t nbytesUsed = 0; 25 }; 26 27 using MultiTokenSizeModelLoader = MultiModelLoader<size_t>; 28 using ModelPathMap = MultiTokenSizeModelLoader::ModelPathMap; 29 30 class ModelChunk : protected MultiTokenSizeModelLoader { 31 public: 32 explicit ModelChunk( 33 const ModelPathMap& modelPathMap, 34 const size_t initBatchSize = 1) MultiTokenSizeModelLoader(modelPathMap,initBatchSize)35 : MultiTokenSizeModelLoader(modelPathMap, initBatchSize), 36 mTokenBatchSize(initBatchSize) {} 37 38 explicit ModelChunk( 39 const std::string& modelPath, 40 const size_t initBatchSize = 1) MultiTokenSizeModelLoader(modelPath,initBatchSize)41 : MultiTokenSizeModelLoader(modelPath, initBatchSize), 42 mTokenBatchSize(initBatchSize) {} 43 ~ModelChunk()44 ~ModelChunk() {} 45 46 virtual void Initialize(); 47 48 virtual void Release(); 49 50 virtual void Run(); 51 52 virtual bool HotSwapModel(const size_t tokenBatchSize); 53 54 void 55 SetInputBuffer(const void* data, const size_t size, const size_t index = 0); 56 57 void SetInputBuffer(const BufferInfo& bufferInfo, const size_t index = 0); 58 59 BufferInfo GetInputBuffer(const size_t index = 0); 60 61 BufferInfo GetOutputBuffer(const size_t index = 0); 62 63 void LogIoSummary(); 64 65 protected: 66 // Check if model chunk has been initialized 67 bool Initialized(); 68 69 // Get model IO info after model has been loaded 70 void GetModelIoInfo(); 71 72 // Update IO sizes actually used by the model 73 void UpdateModelIoInfo(); 74 75 // Model IO linkage to share the same buffer among a pair of linked input and 76 // output 77 void LinkModelIO(const size_t inputIndex, const size_t outputIndex); 78 79 // Return the input index that the given output should share the same buffer 80 std::optional<size_t> GetLinkedInputIndex(const size_t outputIndex) const; 81 82 // Assign input buffers to model inputs using backend APIs 83 void SetBackendInputs(); 84 85 // Assign output buffers to model outputs using backend APIs 86 void SetBackendOutputs(); 87 88 // Allocate buffers for model IOs 89 void AllocateIoBuffers(); 90 91 // Release allocated buffers for model IOs 92 void ReleaseIoBuffers(); 93 94 executorch::runtime::Method& GetModelMethod(); 95 96 private: 97 // Override the virtual functions 98 void* CreateModelInstance(const std::string& modelPath) override; 99 100 void ReleaseModelInstance(void* modelInstance) override; 101 102 private: AllowModelsCoexist()103 bool AllowModelsCoexist() const override { 104 return false; 105 } 106 107 protected: 108 // State of initialization 109 bool mIsInitialized = false; 110 111 // The number of input tokens the the fixed-shape model takes 112 size_t mTokenBatchSize = 1; 113 114 // Input/Output buffers info 115 std::vector<BufferInfo> mInputBufferInfos; 116 std::vector<BufferInfo> mOutputBufferInfos; 117 118 // Model IO linkage, where linked IO will share the same buffer 119 std::unordered_map<size_t, size_t> mModelOutToInIndexLinks; 120 }; 121 122 } // namespace example 123