1 /* 2 * Copyright (c) Qualcomm Innovation Center, Inc. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <cstdint> 12 #include <future> 13 #include <memory> 14 #include <queue> 15 #include <thread> 16 #include <vector> 17 18 #include <executorch/extension/module/module.h> 19 #include <executorch/runtime/executor/method_meta.h> 20 21 #if defined(QAIHUB_LLAMA3_RUNNER) 22 #define QAIHUB_LLAMA_NUM_HEADS 8 23 #define QAIHUB_LLAMA_LOGITS 128256 24 #else 25 #define QAIHUB_LLAMA_NUM_HEADS 32 26 #define QAIHUB_LLAMA_LOGITS 32000 27 #endif 28 29 namespace example { 30 31 class Memory { 32 public: 33 Memory( 34 const std::vector<std::string>& pos_embs_path, 35 std::vector<std::shared_ptr<executorch::extension::Module>>& modules); 36 virtual ~Memory(); 37 virtual void prepare_io( 38 const std::vector< 39 executorch::runtime::Result<executorch::runtime::MethodMeta>>& 40 methods_meta) = 0; 41 virtual void update_io( 42 int64_t cur_token, 43 int64_t pos, 44 std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0; 45 void* get_mutable_ptr(); 46 std::vector<executorch::aten::Tensor> get_input_tensors(int shard_index); 47 std::vector<executorch::aten::Tensor> get_output_tensors(int shard_index); 48 49 protected: 50 std::unique_ptr<void, void (*)(void*)> data_ptr_; 51 std::vector<std::vector<executorch::aten::TensorImpl*>> input_tensors_; 52 std::vector<std::vector<executorch::aten::TensorImpl*>> output_tensors_; 53 std::vector<std::string> pos_embs_path_; 54 std::vector<std::shared_ptr<executorch::extension::Module>> modules_; 55 std::vector<std::string> method_names_; 56 }; 57 58 class BertMemory : public Memory { 59 public: 60 BertMemory( 61 const std::vector<std::string>& pos_embs_path, 62 std::vector<std::shared_ptr<executorch::extension::Module>>& modules, 63 std::vector<int> shard_layers); 64 void prepare_io(const std::vector<executorch::runtime::Result< 65 executorch::runtime::MethodMeta>>& methods_meta) override; 66 void update_io( 67 int64_t cur_token, 68 int64_t pos, 69 std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) 70 override; 71 struct IO { 72 int32_t input_ids[1024 * 2]; 73 uint16_t hidden_state[1024 * 4096]; 74 uint16_t attention_mask[1024 * 1024]; 75 uint16_t position_ids_cos[1024 * 64]; 76 uint16_t position_ids_sin[1024 * 64]; 77 uint8_t k_cache[32][QAIHUB_LLAMA_NUM_HEADS][128 * 1024]; 78 uint8_t v_cache[32][QAIHUB_LLAMA_NUM_HEADS][1024 * 128]; 79 uint16_t logits[QAIHUB_LLAMA_LOGITS]; 80 }; 81 82 private: 83 std::unique_ptr<executorch::aten::TensorImpl> input_ids_; 84 std::unique_ptr<executorch::aten::TensorImpl> hidden_state_; 85 std::unique_ptr<executorch::aten::TensorImpl> attention_mask_; 86 std::unique_ptr<executorch::aten::TensorImpl> position_ids_cos_; 87 std::unique_ptr<executorch::aten::TensorImpl> position_ids_sin_; 88 std::vector<std::unique_ptr<executorch::aten::TensorImpl>> k_cache_; 89 std::vector<std::unique_ptr<executorch::aten::TensorImpl>> v_cache_; 90 std::unique_ptr<executorch::aten::TensorImpl> logits_; 91 std::vector<int> shard_layers_; 92 int num_heads_; 93 }; 94 95 class ThreadPool { 96 public: 97 ThreadPool(); 98 ~ThreadPool(); 99 100 std::future<void> issue(std::function<void(void*)> func, void* arg); 101 size_t num_workers(); 102 103 private: 104 struct JobInfo { JobInfoJobInfo105 explicit JobInfo(std::packaged_task<void(void*)>&& func, void* arg) 106 : func(std::move(func)), arg(arg) {} JobInfoJobInfo107 explicit JobInfo(JobInfo&& job_info) 108 : func(std::move(job_info.func)), arg(job_info.arg) {} 109 std::packaged_task<void(void*)> func; 110 void* arg; 111 }; 112 size_t num_workers_; 113 std::vector<std::thread> threads_; 114 std::queue<JobInfo> jobs_; 115 std::mutex mutex_; 116 std::condition_variable cv_; 117 bool stop_; 118 }; 119 120 class KVCachedMemory : public Memory { 121 public: 122 KVCachedMemory( 123 const std::vector<std::string>& pos_embs_path, 124 std::vector<std::shared_ptr<executorch::extension::Module>>& modules, 125 std::vector<int> shard_layers); 126 void prepare_io(const std::vector<executorch::runtime::Result< 127 executorch::runtime::MethodMeta>>& methods_meta) override; 128 void update_io( 129 int64_t cur_token, 130 int64_t pos, 131 std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) 132 override; 133 struct IO { 134 int32_t input_ids; 135 uint16_t hidden_state[4096]; 136 uint16_t attention_mask[1024]; 137 uint16_t position_ids_cos[1024 * 64]; 138 uint16_t position_ids_sin[1024 * 64]; 139 uint8_t k_cache[32][QAIHUB_LLAMA_NUM_HEADS][129 * 1023]; 140 uint8_t v_cache[32][(QAIHUB_LLAMA_NUM_HEADS + 1) * 1023 * 128]; 141 uint8_t k_cache_out[32][QAIHUB_LLAMA_NUM_HEADS][128]; 142 uint16_t logits[QAIHUB_LLAMA_LOGITS]; 143 }; 144 struct LoopRange { 145 int32_t start; 146 int32_t end; 147 int32_t step; 148 }; 149 150 private: 151 std::unique_ptr<executorch::aten::TensorImpl> input_ids_; 152 std::unique_ptr<executorch::aten::TensorImpl> hidden_state_; 153 std::unique_ptr<executorch::aten::TensorImpl> attention_mask_; 154 std::unique_ptr<executorch::aten::TensorImpl> position_ids_cos_; 155 std::unique_ptr<executorch::aten::TensorImpl> position_ids_sin_; 156 std::vector<std::unique_ptr<executorch::aten::TensorImpl>> k_cache_in_; 157 std::vector<std::unique_ptr<executorch::aten::TensorImpl>> v_cache_in_; 158 std::vector<std::unique_ptr<executorch::aten::TensorImpl>> k_cache_out_; 159 std::vector<std::unique_ptr<executorch::aten::TensorImpl>> v_cache_out_; 160 std::unique_ptr<executorch::aten::TensorImpl> logits_; 161 std::vector<LoopRange> lr_update_kv_; 162 std::vector<std::future<void>> futures_; 163 ThreadPool thread_pool_; 164 std::vector<int> shard_layers_; 165 int num_heads_; 166 }; 167 168 } // namespace example 169