xref: /aosp_15_r20/external/executorch/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Qualcomm Innovation Center, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

#include <executorch/extension/module/module.h>
#include <executorch/runtime/executor/method_meta.h>
20 
21 #if defined(QAIHUB_LLAMA3_RUNNER)
22 #define QAIHUB_LLAMA_NUM_HEADS 8
23 #define QAIHUB_LLAMA_LOGITS 128256
24 #else
25 #define QAIHUB_LLAMA_NUM_HEADS 32
26 #define QAIHUB_LLAMA_LOGITS 32000
27 #endif
28 
29 namespace example {
30 
31 class Memory {
32  public:
33   Memory(
34       const std::vector<std::string>& pos_embs_path,
35       std::vector<std::shared_ptr<executorch::extension::Module>>& modules);
36   virtual ~Memory();
37   virtual void prepare_io(
38       const std::vector<
39           executorch::runtime::Result<executorch::runtime::MethodMeta>>&
40           methods_meta) = 0;
41   virtual void update_io(
42       int64_t cur_token,
43       int64_t pos,
44       std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) = 0;
45   void* get_mutable_ptr();
46   std::vector<executorch::aten::Tensor> get_input_tensors(int shard_index);
47   std::vector<executorch::aten::Tensor> get_output_tensors(int shard_index);
48 
49  protected:
50   std::unique_ptr<void, void (*)(void*)> data_ptr_;
51   std::vector<std::vector<executorch::aten::TensorImpl*>> input_tensors_;
52   std::vector<std::vector<executorch::aten::TensorImpl*>> output_tensors_;
53   std::vector<std::string> pos_embs_path_;
54   std::vector<std::shared_ptr<executorch::extension::Module>> modules_;
55   std::vector<std::string> method_names_;
56 };
57 
58 class BertMemory : public Memory {
59  public:
60   BertMemory(
61       const std::vector<std::string>& pos_embs_path,
62       std::vector<std::shared_ptr<executorch::extension::Module>>& modules,
63       std::vector<int> shard_layers);
64   void prepare_io(const std::vector<executorch::runtime::Result<
65                       executorch::runtime::MethodMeta>>& methods_meta) override;
66   void update_io(
67       int64_t cur_token,
68       int64_t pos,
69       std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
70       override;
71   struct IO {
72     int32_t input_ids[1024 * 2];
73     uint16_t hidden_state[1024 * 4096];
74     uint16_t attention_mask[1024 * 1024];
75     uint16_t position_ids_cos[1024 * 64];
76     uint16_t position_ids_sin[1024 * 64];
77     uint8_t k_cache[32][QAIHUB_LLAMA_NUM_HEADS][128 * 1024];
78     uint8_t v_cache[32][QAIHUB_LLAMA_NUM_HEADS][1024 * 128];
79     uint16_t logits[QAIHUB_LLAMA_LOGITS];
80   };
81 
82  private:
83   std::unique_ptr<executorch::aten::TensorImpl> input_ids_;
84   std::unique_ptr<executorch::aten::TensorImpl> hidden_state_;
85   std::unique_ptr<executorch::aten::TensorImpl> attention_mask_;
86   std::unique_ptr<executorch::aten::TensorImpl> position_ids_cos_;
87   std::unique_ptr<executorch::aten::TensorImpl> position_ids_sin_;
88   std::vector<std::unique_ptr<executorch::aten::TensorImpl>> k_cache_;
89   std::vector<std::unique_ptr<executorch::aten::TensorImpl>> v_cache_;
90   std::unique_ptr<executorch::aten::TensorImpl> logits_;
91   std::vector<int> shard_layers_;
92   int num_heads_;
93 };
94 
95 class ThreadPool {
96  public:
97   ThreadPool();
98   ~ThreadPool();
99 
100   std::future<void> issue(std::function<void(void*)> func, void* arg);
101   size_t num_workers();
102 
103  private:
104   struct JobInfo {
JobInfoJobInfo105     explicit JobInfo(std::packaged_task<void(void*)>&& func, void* arg)
106         : func(std::move(func)), arg(arg) {}
JobInfoJobInfo107     explicit JobInfo(JobInfo&& job_info)
108         : func(std::move(job_info.func)), arg(job_info.arg) {}
109     std::packaged_task<void(void*)> func;
110     void* arg;
111   };
112   size_t num_workers_;
113   std::vector<std::thread> threads_;
114   std::queue<JobInfo> jobs_;
115   std::mutex mutex_;
116   std::condition_variable cv_;
117   bool stop_;
118 };
119 
120 class KVCachedMemory : public Memory {
121  public:
122   KVCachedMemory(
123       const std::vector<std::string>& pos_embs_path,
124       std::vector<std::shared_ptr<executorch::extension::Module>>& modules,
125       std::vector<int> shard_layers);
126   void prepare_io(const std::vector<executorch::runtime::Result<
127                       executorch::runtime::MethodMeta>>& methods_meta) override;
128   void update_io(
129       int64_t cur_token,
130       int64_t pos,
131       std::vector<std::vector<executorch::aten::Tensor>>& output_tensors)
132       override;
133   struct IO {
134     int32_t input_ids;
135     uint16_t hidden_state[4096];
136     uint16_t attention_mask[1024];
137     uint16_t position_ids_cos[1024 * 64];
138     uint16_t position_ids_sin[1024 * 64];
139     uint8_t k_cache[32][QAIHUB_LLAMA_NUM_HEADS][129 * 1023];
140     uint8_t v_cache[32][(QAIHUB_LLAMA_NUM_HEADS + 1) * 1023 * 128];
141     uint8_t k_cache_out[32][QAIHUB_LLAMA_NUM_HEADS][128];
142     uint16_t logits[QAIHUB_LLAMA_LOGITS];
143   };
144   struct LoopRange {
145     int32_t start;
146     int32_t end;
147     int32_t step;
148   };
149 
150  private:
151   std::unique_ptr<executorch::aten::TensorImpl> input_ids_;
152   std::unique_ptr<executorch::aten::TensorImpl> hidden_state_;
153   std::unique_ptr<executorch::aten::TensorImpl> attention_mask_;
154   std::unique_ptr<executorch::aten::TensorImpl> position_ids_cos_;
155   std::unique_ptr<executorch::aten::TensorImpl> position_ids_sin_;
156   std::vector<std::unique_ptr<executorch::aten::TensorImpl>> k_cache_in_;
157   std::vector<std::unique_ptr<executorch::aten::TensorImpl>> v_cache_in_;
158   std::vector<std::unique_ptr<executorch::aten::TensorImpl>> k_cache_out_;
159   std::vector<std::unique_ptr<executorch::aten::TensorImpl>> v_cache_out_;
160   std::unique_ptr<executorch::aten::TensorImpl> logits_;
161   std::vector<LoopRange> lr_update_kv_;
162   std::vector<std::future<void>> futures_;
163   ThreadPool thread_pool_;
164   std::vector<int> shard_layers_;
165   int num_heads_;
166 };
167 
168 } // namespace example
169