/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <algorithm>
#include <cstring>
#include <fstream>

#include <executorch/examples/qualcomm/qaihub_scripts/llama/runner/io_memory.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>

using executorch::aten::Tensor;
using executorch::aten::TensorImpl;
using executorch::extension::Module;
using executorch::runtime::Error;
using executorch::runtime::MethodMeta;
using executorch::runtime::Result;
using executorch::runtime::TensorInfo;

namespace example {

Memory::Memory(
    const std::vector<std::string>& pos_embs_path,
    std::vector<std::shared_ptr<Module>>& modules)
    : data_ptr_(nullptr, [](void*) {}),
      input_tensors_(modules.size()),
      output_tensors_(modules.size()),
      pos_embs_path_(pos_embs_path),
      modules_(modules) {
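  // Record one method name per shard; each lowered module is assumed to
  // expose a single callable method here.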
  for (std::shared_ptr<Module>& module : modules_) {
    method_names_.emplace_back(*module->method_names()->begin());
  }
}

Memory::~Memory() {}

void* Memory::get_mutable_ptr() {
  return data_ptr_.get();
}

std::vector<Tensor> Memory::get_input_tensors(int shard_index) {
  std::vector<Tensor> ret;
  ret.reserve(input_tensors_[shard_index].size());
  for (TensorImpl* impl : input_tensors_[shard_index]) {
    ret.emplace_back(Tensor(impl));
  }
  return ret;
}

std::vector<Tensor> Memory::get_output_tensors(int shard_index) {
  std::vector<Tensor> ret;
  ret.reserve(output_tensors_[shard_index].size());
  for (TensorImpl* impl : output_tensors_[shard_index]) {
    ret.emplace_back(Tensor(impl));
  }
  return ret;
}

BertMemory::BertMemory(
    const std::vector<std::string>& pos_embs_path,
    std::vector<std::shared_ptr<Module>>& modules,
    std::vector<int> shard_layers)
    : Memory(pos_embs_path, modules),
      shard_layers_(shard_layers),
      num_heads_(QAIHUB_LLAMA_NUM_HEADS) {
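  // Allocate the flat IO struct; the custom deleter casts the type-erased
  // pointer back to IO* so the buffer is destroyed with the correct type.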
  data_ptr_ = std::unique_ptr<void, void (*)(void*)>(
      new IO, [](void* ptr) { delete static_cast<IO*>(ptr); });
}

void BertMemory::prepare_io(
    const std::vector<Result<MethodMeta>>& methods_meta) {
  IO* ptr = static_cast<IO*>(data_ptr_.get());
  std::memset(ptr, 0, sizeof(IO));

  for (int i = 0; i < modules_.size(); ++i) {
    ET_CHECK_MSG(
        methods_meta[i].ok(),
        "Failed to get method_meta 0x%x",
        static_cast<uint32_t>(methods_meta[i].error()));
  }
  // [I] position embedding initialization
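  // Each precomputed table is read as 1024 positions x 64 entries x 2 bytes
  // (16-bit values), matching the 1024 * 64 * 2 read size below.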
  for (size_t i = 0; i < pos_embs_path_.size(); ++i) {
    std::ifstream fin(pos_embs_path_[i], std::ios::binary);
    fin.read(
        reinterpret_cast<char*>(
            i == 0 ? ptr->position_ids_cos : ptr->position_ids_sin),
        1024 * 64 * 2);
    fin.close();
  }
  // [I]: all shards (4 shards for llama2, 5 shards for llama3)
  {
    // [I]: input_ids
    Result<TensorInfo> input_ids = methods_meta[0]->input_tensor_meta(0);
    input_ids_ = std::make_unique<TensorImpl>(
        input_ids->scalar_type(),
        input_ids->sizes().size(),
        const_cast<TensorImpl::SizesType*>(input_ids->sizes().data()),
        ptr->input_ids,
        const_cast<TensorImpl::DimOrderType*>(input_ids->dim_order().data()));
    input_tensors_[0].push_back(input_ids_.get());
    // [I]: atten_mask
    Result<TensorInfo> atten_mask = methods_meta[0]->input_tensor_meta(1);
    attention_mask_ = std::make_unique<TensorImpl>(
        atten_mask->scalar_type(),
        atten_mask->sizes().size(),
        const_cast<TensorImpl::SizesType*>(atten_mask->sizes().data()),
        ptr->attention_mask,
        const_cast<TensorImpl::DimOrderType*>(atten_mask->dim_order().data()));
    input_tensors_[0].push_back(attention_mask_.get());
    // [I]: pos_ids_cos
    Result<TensorInfo> pos_ids_cos = methods_meta[0]->input_tensor_meta(2);
    position_ids_cos_ = std::make_unique<TensorImpl>(
        pos_ids_cos->scalar_type(),
        pos_ids_cos->sizes().size(),
        const_cast<TensorImpl::SizesType*>(pos_ids_cos->sizes().data()),
        ptr->position_ids_cos,
        const_cast<TensorImpl::DimOrderType*>(pos_ids_cos->dim_order().data()));
    input_tensors_[0].push_back(position_ids_cos_.get());
    // [I]: pos_ids_sin
    Result<TensorInfo> pos_ids_sin = methods_meta[0]->input_tensor_meta(3);
    position_ids_sin_ = std::make_unique<TensorImpl>(
        pos_ids_sin->scalar_type(),
        pos_ids_sin->sizes().size(),
        const_cast<TensorImpl::SizesType*>(pos_ids_sin->sizes().data()),
        ptr->position_ids_sin,
        const_cast<TensorImpl::DimOrderType*>(pos_ids_sin->dim_order().data()));
    input_tensors_[0].push_back(position_ids_sin_.get());
    // [IO]: hidden_state => reused as [I] for shards 1 to n-1
    int output_index =
        shard_layers_[0] * 2 * num_heads_; // layers * (k + v caches) * heads
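    // The first shard emits its per-layer KV caches first, so hidden_state is
    // the output immediately after them.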
    Result<TensorInfo> hidden_state =
        methods_meta[0]->output_tensor_meta(output_index);
    hidden_state_ = std::make_unique<TensorImpl>(
        hidden_state->scalar_type(),
        hidden_state->sizes().size(),
        const_cast<TensorImpl::SizesType*>(hidden_state->sizes().data()),
        ptr->hidden_state,
        const_cast<TensorImpl::DimOrderType*>(
            hidden_state->dim_order().data()));
    // reuse inputs for the following shards
    for (int shard_index = 1; shard_index < modules_.size(); ++shard_index) {
      // inputs of shards 1 to n-1: hidden_state, atten_mask, pos_ids_cos,
      // pos_ids_sin
      input_tensors_[shard_index].push_back(hidden_state_.get());
      input_tensors_[shard_index].push_back(attention_mask_.get());
      input_tensors_[shard_index].push_back(position_ids_cos_.get());
      input_tensors_[shard_index].push_back(position_ids_sin_.get());
    }
  }
  // [O] kv_cache for all shards (4 shards for llama2 and 5 shards for llama3)
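  // Per the index computation below, each shard's outputs are laid out as
  // [v_cache x heads, k_cache x heads] for every layer, followed by
  // hidden_state (or logits on the last shard).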
  for (int offset = 0, shard_index = 0; shard_index < modules_.size();
       offset += shard_layers_[shard_index], shard_index++) {
    for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) {
      for (int cache_group = 0; cache_group < 2; ++cache_group) {
        for (int head = 0; head < num_heads_; ++head) {
          int index = num_heads_ * 2 * layer + cache_group * num_heads_ + head;
          Result<TensorInfo> kv_cache =
              methods_meta[shard_index]->output_tensor_meta(index);
          std::vector<std::unique_ptr<TensorImpl>>& cache =
              (cache_group == 0 ? v_cache_ : k_cache_);
          cache.emplace_back(std::make_unique<TensorImpl>(
              kv_cache->scalar_type(),
              kv_cache->sizes().size(),
              const_cast<TensorImpl::SizesType*>(kv_cache->sizes().data()),
              cache_group == 0 ? ptr->v_cache[layer + offset][head]
                               : ptr->k_cache[layer + offset][head],
              const_cast<TensorImpl::DimOrderType*>(
                  kv_cache->dim_order().data())));
          output_tensors_[shard_index].push_back(cache.back().get());
        }
      }
    }
  }
  // [O]: hidden_state for shards 0 to n-2 (the last shard outputs logits)
  for (int shard_index = 0; shard_index < modules_.size() - 1; ++shard_index) {
    output_tensors_[shard_index].push_back(hidden_state_.get());
  }
  // [O]: logits
  {
    int output_index = shard_layers_[modules_.size() - 1] * 2 *
        num_heads_; // layers * (k + v caches) * heads
    Result<TensorInfo> logits =
        methods_meta[modules_.size() - 1]->output_tensor_meta(output_index);
    logits_ = std::make_unique<TensorImpl>(
        logits->scalar_type(),
        logits->sizes().size(),
        const_cast<TensorImpl::SizesType*>(logits->sizes().data()),
        ptr->logits,
        const_cast<TensorImpl::DimOrderType*>(logits->dim_order().data()));
    output_tensors_[modules_.size() - 1].push_back(logits_.get());
  }
}

void BertMemory::update_io(
    int64_t cur_token,
    int64_t pos,
    std::vector<std::vector<Tensor>>& output_tensors) {
  (void)output_tensors;
  IO* ptr = static_cast<IO*>(data_ptr_.get());
  static int num_tokens_generated = 0;
  int seq_len = 1024, last_index = seq_len - 1;
  // refill past token ids by bumping the data pointer, which is equivalent to
  // the following snippet:
  // --->
  // for (int i = 0; i < last_index; ++i) {
  //   ptr->input_ids[i] = ptr->input_ids[i + 1];
  // }
  // ptr->input_ids[last_index] = static_cast<int32_t>(cur_token);
  // <---
  int32_t* new_addr = ptr->input_ids + (++num_tokens_generated);
  new_addr[last_index] = static_cast<int32_t>(cur_token);
  input_ids_->set_data(new_addr);
  // update causal mask for next token
  int tokens = pos + 1, start = last_index - tokens;
  for (int i = last_index; tokens >= 0; --i, --tokens) {
    ptr->attention_mask[i * seq_len + start] = 65535;
  }
}

KVCachedMemory::KVCachedMemory(
    const std::vector<std::string>& pos_embs_path,
    std::vector<std::shared_ptr<Module>>& modules,
    std::vector<int> shard_layers)
    : Memory(pos_embs_path, modules),
      shard_layers_(shard_layers),
      num_heads_(QAIHUB_LLAMA_NUM_HEADS) {
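  // num_heads_ == 32 identifies the QAIHub Llama2 variant, which has roughly
  // 4x the IO tensors of Llama3 and therefore uses the thread pool for
  // IO updates (see update_io).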
  data_ptr_ = std::unique_ptr<void, void (*)(void*)>(
      new IO, [](void* ptr) { delete static_cast<IO*>(ptr); });
  if (num_heads_ == 32) {
    futures_ = std::vector<std::future<void>>(thread_pool_.num_workers());
  }
}

void KVCachedMemory::prepare_io(
    const std::vector<Result<MethodMeta>>& methods_meta) {
  IO* ptr = static_cast<IO*>(data_ptr_.get());
  std::memset(ptr, 0, sizeof(IO));
  for (int i = 0; i < modules_.size(); ++i) {
    ET_CHECK_MSG(
        methods_meta[i].ok(),
        "Failed to get method_meta 0x%x",
        static_cast<uint32_t>(methods_meta[i].error()));
  }
  // [I] position embedding initialization
  for (size_t i = 0; i < pos_embs_path_.size(); ++i) {
    std::ifstream fin(pos_embs_path_[i], std::ios::binary);
    fin.read(
        reinterpret_cast<char*>(
            i == 0 ? ptr->position_ids_cos : ptr->position_ids_sin),
        1024 * 64 * 2);
    fin.close();
  }
  // [I]: all shards (4 shards for llama2, 5 shards for llama3)
  {
    // [I]: input_ids
    Result<TensorInfo> input_ids = methods_meta[0]->input_tensor_meta(0);
    input_ids_ = std::make_unique<TensorImpl>(
        input_ids->scalar_type(),
        input_ids->sizes().size(),
        const_cast<TensorImpl::SizesType*>(input_ids->sizes().data()),
        &ptr->input_ids,
        const_cast<TensorImpl::DimOrderType*>(input_ids->dim_order().data()));
    input_tensors_[0].push_back(input_ids_.get());
    // [I]: atten_mask
    Result<TensorInfo> atten_mask = methods_meta[0]->input_tensor_meta(1);
    attention_mask_ = std::make_unique<TensorImpl>(
        atten_mask->scalar_type(),
        atten_mask->sizes().size(),
        const_cast<TensorImpl::SizesType*>(atten_mask->sizes().data()),
        ptr->attention_mask,
        const_cast<TensorImpl::DimOrderType*>(atten_mask->dim_order().data()));
    input_tensors_[0].push_back(attention_mask_.get());
    // [I]: pos_ids_cos
    Result<TensorInfo> pos_ids_cos = methods_meta[0]->input_tensor_meta(2);
    position_ids_cos_ = std::make_unique<TensorImpl>(
        pos_ids_cos->scalar_type(),
        pos_ids_cos->sizes().size(),
        const_cast<TensorImpl::SizesType*>(pos_ids_cos->sizes().data()),
        ptr->position_ids_cos,
        const_cast<TensorImpl::DimOrderType*>(pos_ids_cos->dim_order().data()));
    input_tensors_[0].push_back(position_ids_cos_.get());
    // [I]: pos_ids_sin
    Result<TensorInfo> pos_ids_sin = methods_meta[0]->input_tensor_meta(3);
    position_ids_sin_ = std::make_unique<TensorImpl>(
        pos_ids_sin->scalar_type(),
        pos_ids_sin->sizes().size(),
        const_cast<TensorImpl::SizesType*>(pos_ids_sin->sizes().data()),
        ptr->position_ids_sin,
        const_cast<TensorImpl::DimOrderType*>(pos_ids_sin->dim_order().data()));
    input_tensors_[0].push_back(position_ids_sin_.get());
    // [IO]: hidden_state => reused as [I] for shards 1 to n-1
    int output_index =
        shard_layers_[0] * 2 * num_heads_; // layers * (k + v caches) * heads
    Result<TensorInfo> hidden_state =
        methods_meta[0]->output_tensor_meta(output_index);
    hidden_state_ = std::make_unique<TensorImpl>(
        hidden_state->scalar_type(),
        hidden_state->sizes().size(),
        const_cast<TensorImpl::SizesType*>(hidden_state->sizes().data()),
        ptr->hidden_state,
        const_cast<TensorImpl::DimOrderType*>(
            hidden_state->dim_order().data()));
    // reuse inputs for the following shards
    for (int shard_index = 1; shard_index < modules_.size(); ++shard_index) {
      // inputs of shards 1 to n-1: hidden_state, atten_mask, pos_ids_cos,
      // pos_ids_sin
      input_tensors_[shard_index].push_back(hidden_state_.get());
      input_tensors_[shard_index].push_back(attention_mask_.get());
      input_tensors_[shard_index].push_back(position_ids_cos_.get());
      input_tensors_[shard_index].push_back(position_ids_sin_.get());
    }
  }
  // [I] kv_cache for all shards (4 shards for llama2 and 5 shards for llama3)
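  // v_stride = 1023 positions x 128 bytes per head. For each head the input
  // view covers the 1023 most recent cache rows and the matching output view
  // (set up in the next loop) is the single row right after it, so advancing
  // both pointers by 128 bytes per token in update_io appends the new value
  // row without any copying.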
  for (int offset = 0, shard_index = 0, v_stride = 1023 * 128;
       shard_index < modules_.size();
       offset += shard_layers_[shard_index], shard_index++) {
    for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) {
      for (int cache_group = 0; cache_group < 2; ++cache_group) {
        for (int head = 0; head < num_heads_; ++head) {
          // bypass hidden_state (input_ids), atten_mask, pos_cos, pos_sin
          int index =
              num_heads_ * 2 * layer + cache_group * num_heads_ + head + 4;
          Result<TensorInfo> kv_cache =
              methods_meta[shard_index]->input_tensor_meta(index);
          std::vector<std::unique_ptr<TensorImpl>>& cache =
              (cache_group == 0 ? k_cache_in_ : v_cache_in_);

          void* cache_ptr = (cache_group == 0)
              ? static_cast<void*>(ptr->k_cache[layer + offset][head])
              : static_cast<void*>(
                    ptr->v_cache[layer + offset] + head * v_stride);

          cache.emplace_back(std::make_unique<TensorImpl>(
              kv_cache->scalar_type(),
              kv_cache->sizes().size(),
              const_cast<TensorImpl::SizesType*>(kv_cache->sizes().data()),
              cache_ptr,
              const_cast<TensorImpl::DimOrderType*>(
                  kv_cache->dim_order().data())));
          input_tensors_[shard_index].push_back(cache.back().get());
        }
      }
    }
  }
  // [O] kv_cache for all shards (4 shards for llama2 and 5 shards for llama3)
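  // Output ordering per layer mirrors the index computation below: v_cache
  // heads first, then k_cache heads. The new v row is written directly behind
  // each head's input window (see the note above), while k updates land in
  // the k_cache_out scratch buffer and are merged into k_cache in update_io.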
  for (int offset = 0, shard_index = 0, v_stride = 1023 * 128;
       shard_index < modules_.size();
       offset += shard_layers_[shard_index], shard_index++) {
    for (int layer = 0; layer < shard_layers_[shard_index]; ++layer) {
      for (int cache_group = 0; cache_group < 2; ++cache_group) {
        for (int head = 0; head < num_heads_; ++head) {
          int index = num_heads_ * 2 * layer + cache_group * num_heads_ + head;
          Result<TensorInfo> kv_cache =
              methods_meta[shard_index]->output_tensor_meta(index);
          std::vector<std::unique_ptr<TensorImpl>>& cache =
              (cache_group == 0 ? v_cache_out_ : k_cache_out_);

          void* cache_ptr = (cache_group == 0)
              ? static_cast<void*>(
                    ptr->v_cache[layer + offset] + (head + 1) * v_stride)
              : static_cast<void*>(ptr->k_cache_out[layer + offset][head]);

          cache.emplace_back(std::make_unique<TensorImpl>(
              kv_cache->scalar_type(),
              kv_cache->sizes().size(),
              const_cast<TensorImpl::SizesType*>(kv_cache->sizes().data()),
              cache_ptr,
              const_cast<TensorImpl::DimOrderType*>(
                  kv_cache->dim_order().data())));
          output_tensors_[shard_index].push_back(cache.back().get());
        }
      }
    }
  }
  // [O]: hidden_state for shards 0 to n-2 (the last shard outputs logits)
  for (int shard_index = 0; shard_index < modules_.size() - 1; ++shard_index) {
    output_tensors_[shard_index].push_back(hidden_state_.get());
  }
  // [O]: logits
  {
    int output_index = shard_layers_[modules_.size() - 1] * 2 *
        num_heads_; // layers * (k + v caches) * heads
    Result<TensorInfo> logits =
        methods_meta[modules_.size() - 1]->output_tensor_meta(output_index);
    logits_ = std::make_unique<TensorImpl>(
        logits->scalar_type(),
        logits->sizes().size(),
        const_cast<TensorImpl::SizesType*>(logits->sizes().data()),
        ptr->logits,
        const_cast<TensorImpl::DimOrderType*>(logits->dim_order().data()));
    output_tensors_[modules_.size() - 1].push_back(logits_.get());
  }

  // QAIHub Llama2 has roughly 4x the IO tensors of QAIHub Llama3, so we use
  // multi-threading for Llama2 when updating IO
  if (num_heads_ == 32) {
    // thread pool jobs
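    // Partition the 1024 per-head v_cache entries (32 layers x 32 heads for
    // the Llama2 variant) evenly across the workers; each LoopRange describes
    // one worker's slice.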
    for (int i = 0, range = 1024 / thread_pool_.num_workers();
         i < thread_pool_.num_workers();
         ++i) {
      lr_update_kv_.push_back(
          {.start = i * range, .end = (i + 1) * range, .step = 1});
    }
  }
}

void KVCachedMemory::update_io(
    int64_t cur_token,
    int64_t pos,
    std::vector<std::vector<Tensor>>& output_tensors) {
  IO* ptr = static_cast<IO*>(data_ptr_.get());
  int seq_len = 1023;
  // update input_ids
  ptr->input_ids = static_cast<int32_t>(cur_token);
  // update causal mask for next token
  ptr->attention_mask[seq_len - pos] = 65535;
  // update position_ids
  position_ids_cos_->set_data(position_ids_cos_->mutable_data<uint16_t>() + 64);
  position_ids_sin_->set_data(position_ids_sin_->mutable_data<uint16_t>() + 64);

  // use multi-threading when there are many IO tensors (Llama2 in this case)
  if (num_heads_ == 32) {
    auto update_kv = [&](void* arg) {
      LoopRange* lr = static_cast<LoopRange*>(arg);
      // update v_cache
      for (int i = lr->start; i < lr->end; i += lr->step) {
        v_cache_in_[i]->set_data(v_cache_in_[i]->mutable_data<uint8_t>() + 128);
        v_cache_out_[i]->set_data(
            v_cache_out_[i]->mutable_data<uint8_t>() + 128);
      }
      // update output tensors of v_cache; 256 is the number of caches of each
      // type per shard
      int shard = lr->start >> 8, offset = shard << 8;
      int start = lr->start - offset, end = lr->end - offset;
      for (int cache_stride = start; cache_stride < end; cache_stride += 32) {
        for (int cache_group = 0; cache_group < 2; ++cache_group) {
          for (int head = 0; head < 32; ++head) {
            // k, v are placed interleaved
            int index = (cache_stride << 1) + (cache_group << 5) + head;
            ET_CHECK_MSG(
                modules_[shard]->set_output(
                    method_names_[shard],
                    output_tensors[shard][index],
                    index) == Error::Ok,
                "failed to set output tensor for module %d's %d'th output "
                "while updating kv_cache output tensors",
                shard,
                index);
          }
        }
      }
    };

    for (int i = 0; i < lr_update_kv_.size(); ++i) {
      futures_[i] = thread_pool_.issue(update_kv, &lr_update_kv_[i]);
    }
  } else {
    // update v_cache
    for (int i = 0; i < v_cache_in_.size(); i++) {
      v_cache_in_[i]->set_data(v_cache_in_[i]->mutable_data<uint8_t>() + 128);
      v_cache_out_[i]->set_data(v_cache_out_[i]->mutable_data<uint8_t>() + 128);
    }
    for (int shard = 0; shard < output_tensors.size(); shard++) {
      for (int index = 0; index < output_tensors[shard].size(); index++) {
        ET_CHECK_MSG(
            modules_[shard]->set_output(
                method_names_[shard], output_tensors[shard][index], index) ==
                Error::Ok,
            "failed to set output tensor for module %d's %d'th output "
            "while updating kv_cache output tensors",
            shard,
            index);
      }
    }
  }
  // update k_cache on a single thread; this part is CPU-cache sensitive
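  // Each k_cache head is stored transposed as [128 x seq_len], so the new
  // token's 128 k values from k_cache_out are scattered one column apart
  // (stride seq_len), and the input view then slides forward by one column.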
  for (int i = 0; i < k_cache_in_.size(); ++i) {
    uint8_t* ptr_in = k_cache_in_[i]->mutable_data<uint8_t>();
    const uint8_t* ptr_out = k_cache_out_[i]->data<uint8_t>();
    for (size_t j = 0, offset = seq_len; j < 128; ++j, offset += seq_len) {
      ptr_in[offset] = ptr_out[j];
    }
    k_cache_in_[i]->set_data(ptr_in + 1);
  }
  for (auto& future : futures_) {
    future.wait();
  }
}

ThreadPool::ThreadPool() : stop_(false) {
  size_t hc = (std::thread::hardware_concurrency() + 3) / 4;
  // cap the worker count at 32, matching the number of heads, so the KV
  // update ranges split evenly
  num_workers_ = std::min<size_t>(32, hc * 4);
  for (size_t i = 0; i < num_workers_; ++i) {
    threads_.emplace_back([this]() {
      while (1) {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [this] { return !jobs_.empty() || stop_; });

        if (stop_ && jobs_.empty())
          return;

        JobInfo job_info(std::move(jobs_.front()));
        jobs_.pop();
        lock.unlock();
        job_info.func(job_info.arg);
      }
    });
  }
}

ThreadPool::~ThreadPool() {
  std::unique_lock<std::mutex> lock(mutex_);
  stop_ = true;
  lock.unlock();
  cv_.notify_all();
  for (auto& thread : threads_) {
    thread.join();
  }
}

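// Queue a job and return a future that completes once a worker thread has
// executed it; KVCachedMemory::update_io waits on these futures after
// dispatching its slices.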
std::future<void> ThreadPool::issue(
    std::function<void(void*)> func,
    void* arg) {
  std::unique_lock<std::mutex> lock(mutex_);
  jobs_.push(JobInfo(std::packaged_task<void(void*)>(func), arg));
  std::future<void> f = jobs_.back().func.get_future();
  lock.unlock();
  cv_.notify_one();
  return f;
}

size_t ThreadPool::num_workers() {
  return num_workers_;
}

} // namespace example