xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ComputeGraph.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// @lint-ignore-every CLANGTIDY
// facebook-security-vulnerable-integer-sign-conversion

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>

namespace vkcompute {

//
// VTensorPtr
//

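// Each pointer class generated by VALUE_PTR_CLASS_IMPL below wraps a raw
// pointer into the graph's values_ vector. Construction increments
// values_in_use_ and destruction decrements it, which lets
// check_no_active_value_ptrs() detect live pointers into values_ before the
// vector is grown (growing the vector may invalidate those pointers).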
#define VALUE_PTR_CLASS_IMPL(classname, ctype, type_name)                 \
  classname::classname(ComputeGraph* const graph, const ValueRef idx)     \
      : graph_(graph), ptr_(&(graph_->values_.at(idx).to##type_name())) { \
    graph_->values_in_use_++;                                             \
  }                                                                       \
  ctype* classname::operator->() const {                                  \
    return ptr_;                                                          \
  }                                                                       \
  ctype& classname::operator*() const {                                   \
    return *ptr_;                                                         \
  }                                                                       \
  classname::~classname() {                                               \
    graph_->values_in_use_--;                                             \
  }

VALUE_PTR_CLASS_IMPL(vTensorPtr, api::vTensor, Tensor)
VALUE_PTR_CLASS_IMPL(TensorRefPtr, TensorRef, TensorRef)
VALUE_PTR_CLASS_IMPL(StagingPtr, api::StagingBuffer, Staging)
VALUE_PTR_CLASS_IMPL(IntListPtr, std::vector<int64_t>, IntList)
VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector<bool>, BoolList)
VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector<ValueRef>, ValueList)
VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt)

#undef VALUE_PTR_CLASS_IMPL

//
// TmpTensor
//

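// A TmpTensor reserves a shared object index for its lifetime: the constructor
// reuses a recycled index from tmp_shared_object_idxs_ when one is available
// (otherwise it claims a new shared object slot), and the destructor returns
// the index to the pool so that later temporaries can reuse the same memory.
//
// Minimal usage sketch (an illustration, not taken from this file; assumes a
// ComputeGraph `graph`, a `sizes` vector, and that vkapi::kFloat is the
// desired dtype):
//
//   {
//     TmpTensor tmp_a(&graph, sizes, vkapi::kFloat);
//     // ... add nodes that consume tmp_a ...
//   } // tmp_a's shared object index is returned to the pool here
//   TmpTensor tmp_b(&graph, sizes, vkapi::kFloat); // may reuse tmp_a's memory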
TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(
          sizes,
          dtype,
          storage_type,
          memory_layout,
          sobj_idx)) {}

TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, storage_type, sobj_idx)) {}

TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::GPUMemoryLayout memory_layout)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, memory_layout, sobj_idx)) {}

TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, sobj_idx)) {}

TmpTensor::~TmpTensor() {
  // The lifetime of this temporary tensor has expired; return the shared
  // object to the pool, as long as the sobj index is valid.
  if (sobj_idx >= 0) {
    graph_p->tmp_shared_object_idxs_.emplace(sobj_idx);
  }
}

int64_t TmpTensor::get_sobj_idx() {
  int64_t sobj_idx;
  // If no available temporary shared objects, request a new one to be created
  if (graph_p->tmp_shared_object_idxs_.empty()) {
    sobj_idx = graph_p->shared_objects_.size();
  } else {
    // Get the first available shared object idx
    sobj_idx = graph_p->tmp_shared_object_idxs_.top();
    graph_p->tmp_shared_object_idxs_.pop();
  }
  return sobj_idx;
}

//
// ComputeGraph
//

ComputeGraph::ComputeGraph(GraphConfig config)
    : config_{config},
      prepack_descriptor_counts_{},
      execute_descriptor_counts_{},
      context_{new api::Context(
          vkapi::runtime()->default_adapter_i(),
          config_.context_config)},
      shared_objects_{},
      values_{},
      param_ubos_{},
      prepack_nodes_{},
      execute_nodes_{},
      inputs_{},
      outputs_{} {
  // Ensure that descriptor counts are initialized to 0
  prepack_descriptor_counts_.descriptor_pool_max_sets = 0;
  prepack_descriptor_counts_.descriptor_uniform_buffer_count = 0;
  prepack_descriptor_counts_.descriptor_storage_buffer_count = 0;
  prepack_descriptor_counts_.descriptor_combined_sampler_count = 0;
  prepack_descriptor_counts_.descriptor_storage_image_count = 0;

  execute_descriptor_counts_.descriptor_pool_max_sets = 0;
  execute_descriptor_counts_.descriptor_uniform_buffer_count = 0;
  execute_descriptor_counts_.descriptor_storage_buffer_count = 0;
  execute_descriptor_counts_.descriptor_combined_sampler_count = 0;
  execute_descriptor_counts_.descriptor_storage_image_count = 0;

  context_->set_cmd(/*reusable = */ true);
}

ComputeGraph::~ComputeGraph() {
  values_.clear();

  prepack_nodes_.clear();
  execute_nodes_.clear();

  context_->flush();
}

utils::StorageType ComputeGraph::suggested_storage_type() {
  if (config_.enable_storage_type_override) {
    return config_.storage_type_override;
  }
  return utils::kTexture3D;
}

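// Heuristic for choosing the default memory layout: width packed for tensors
// with fewer than 3 dims or with a channels dim of 1, channels packed
// otherwise, unless the config provides an explicit override.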
utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
    const std::vector<int64_t>& sizes) {
  if (config_.enable_memory_layout_override) {
    return config_.memory_layout_override;
  }
  if (sizes.size() < 3) {
    return utils::kWidthPacked;
  }
  // For tensors with 3 or more dims whose channels dimension is only 1, still
  // prefer width packed.
  if (utils::val_at(-3, sizes) == 1) {
    return utils::kWidthPacked;
  }
  return utils::kChannelsPacked;
}

void ComputeGraph::check_no_active_value_ptrs() {
  VK_CHECK_COND(
      values_in_use_ == 0,
      "Make sure that there are no pointers stored from the return values of "
      "`ComputeGraph::get_*()` functions in scope before adding Values to the "
      "graph. Modifying the graph's values may cause existing pointers to be "
      "invalidated.");
}

std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().sizes();
  } else if (val.isTensorRef()) {
    return val.toConstTensorRef().sizes;
  }
  VK_THROW("Could not get sizes of value with type ", val.type());
}

int64_t ComputeGraph::dim_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().dim();
  } else if (val.isTensorRef()) {
    return val.toConstTensorRef().sizes.size();
  }
  VK_THROW("Could not get dim of value with type ", val.type());
}

std::vector<int64_t> ComputeGraph::dim_order_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().dim_order();
  }
  VK_THROW("Could not get dim order of value with type ", val.type());
}

std::vector<int64_t> ComputeGraph::strides_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().strides();
  }
  VK_THROW("Could not get strides of value with type ", val.type());
}

vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().dtype();
  } else if (val.isTensorRef()) {
    return val.toConstTensorRef().dtype;
  }
  VK_THROW("Could not get dtype of value with type ", val.type());
}

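// Creates a new vTensor value. If shared_object_idx is non-negative the tensor
// is constructed without backing memory and is registered as a user of that
// SharedObject; memory is allocated and bound later (see encode_execute()).
// Otherwise the tensor allocates its own memory immediately.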
ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout,
    const int64_t shared_object_idx) {
  bool allocate_memory = shared_object_idx < 0;

  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(api::vTensor(
      context(), sizes, dtype, storage_type, memory_layout, allocate_memory));

  if (!allocate_memory) {
    get_shared_object(shared_object_idx).add_user(this, idx);
  }
  return idx;
}

ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const int64_t shared_object_idx) {
  return add_tensor(
      sizes,
      dtype,
      storage_type,
      suggested_memory_layout(sizes),
      shared_object_idx);
}

ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::GPUMemoryLayout memory_layout,
    const int64_t shared_object_idx) {
  return add_tensor(
      sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
}

ValueRef ComputeGraph::add_tensor_like(
    const ValueRef idx,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout) {
  return add_tensor(sizes_of(idx), dtype_of(idx), storage_type, memory_layout);
}

ValueRef ComputeGraph::add_tensor_like(
    const ValueRef idx,
    const utils::GPUMemoryLayout memory_layout) {
  return add_tensor(
      sizes_of(idx), dtype_of(idx), storage_type_of(idx), memory_layout);
}

ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const int64_t shared_object_idx) {
  return add_tensor(
      sizes, dtype, suggested_memory_layout(sizes), shared_object_idx);
}

ValueRef ComputeGraph::add_tensor(const vkapi::VulkanImage& image) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(api::vTensor(context(), image));
  return idx;
}

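// Creates a new vTensor value that is a view of the tensor at vref: it refers
// to the same underlying storage, and it is added as a user of any
// SharedObject that the source tensor belongs to.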
ValueRef ComputeGraph::add_tensor_view(const ValueRef vref) {
  const vTensorPtr t = get_tensor(vref);
  ValueRef idx(static_cast<int>(values_.size()));
  values_.emplace_back(api::vTensor(*t));
  for (SharedObject& sobj : shared_objects_) {
    if (sobj.has_user(vref)) {
      sobj.add_user(this, idx);
    }
  }
  return idx;
}

ValueRef ComputeGraph::add_tensor_view(
    const ValueRef vref,
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    const size_t offset_numel) {
  const vTensorPtr t = get_tensor(vref);
  ValueRef idx(static_cast<int>(values_.size()));
  values_.emplace_back(api::vTensor(*t, sizes, strides, offset_numel));
  for (SharedObject& sobj : shared_objects_) {
    if (sobj.has_user(vref)) {
      sobj.add_user(this, idx);
    }
  }
  return idx;
}

ValueRef ComputeGraph::add_tensorref(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const void* const data) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(TensorRef(sizes, dtype, data));
  return idx;
}

ValueRef ComputeGraph::add_staging(
    const vkapi::ScalarType dtype,
    const size_t numel) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(api::StagingBuffer(context(), dtype, numel));
  return idx;
}

ValueRef ComputeGraph::add_none() {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back();
  return idx;
}

ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(std::move(value));
  return idx;
}

ValueRef ComputeGraph::add_string(std::string&& str) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(std::move(str));
  return idx;
}

ValueRef ComputeGraph::add_symint(const int32_t val) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(SymInt(context(), val));
  return idx;
}

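// When use_staging is true, set_input_tensor() and set_output_tensor() create
// a staging buffer sized for the tensor (including any texel padding for
// texture storage), add the corresponding transfer node, record the
// {tensor, staging} pair in inputs_/outputs_, and return the staging ValueRef.
// Host data is then exchanged via copy_into_staging()/copy_from_staging().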
ValueRef ComputeGraph::set_input_tensor(
    const ValueRef idx,
    const bool use_staging) {
  if (use_staging) {
    vkapi::ScalarType dtype = get_tensor(idx)->dtype();
    // For texture storage, the buffer size needs to account for the zero
    // padding applied by unused texel elements.
    size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
    ValueRef staging_idx = add_staging(dtype, buf_numel);
    add_staging_to_tensor_node(*this, staging_idx, idx);
    inputs_.push_back({idx, staging_idx});
    return staging_idx;
  }
  inputs_.push_back({idx, kDummyValueRef});
  return idx;
}

ValueRef ComputeGraph::set_output_tensor(
    const ValueRef idx,
    const bool use_staging) {
  if (use_staging) {
    vkapi::ScalarType dtype = get_tensor(idx)->dtype();
    // For texture storage, the buffer size needs to account for the zero
    // padding applied by unused texel elements.
    size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
    ValueRef staging_idx = add_staging(dtype, buf_numel);
    // Only add the transfer node when the tensor is non-empty. When the
    // underlying tensor is empty (e.g. padded_numel == 0), no VkImage is
    // allocated for it, so the node could not be bound for execution.
    if (buf_numel > 0) {
      add_tensor_to_staging_node(*this, idx, staging_idx);
    }
    outputs_.push_back({idx, staging_idx});
    return staging_idx;
  }
  outputs_.push_back({idx, kDummyValueRef});
  return idx;
}

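// Returns a GPU buffer binding holding the integer at idx: for a plain Int
// value a new params UBO containing the value is created; for a SymInt the
// binding refers to its existing GPU buffer, so subsequent set_symint()
// updates are reflected in the bound buffer.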
vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
    const ValueRef idx) {
  if (values_.at(idx).isInt()) {
    const int32_t val = extract_scalar<int32_t>(idx);
    return create_params_buffer(val);
  } else if (values_.at(idx).isSymInt()) {
    SymIntPtr symint = get_symint(idx);
    return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
  }
  VK_THROW("Cannot create an int param buffer for the given value");
}

void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
  get_symint(idx)->set(val);
}

int32_t ComputeGraph::read_symint(const ValueRef idx) {
  return get_symint(idx)->get();
}

SharedObject& ComputeGraph::get_shared_object(const int64_t idx) {
  if (idx >= shared_objects_.size()) {
    shared_objects_.resize(static_cast<size_t>(idx + 1));
  }
  return shared_objects_.at(idx);
}

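// Tallies the descriptor requirements of one shader dispatch into either the
// prepack or the execute descriptor counts; prepare() later uses these totals
// (scaled by the safety factor) to size the descriptor pool.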
void ComputeGraph::update_descriptor_counts(
    const vkapi::ShaderInfo& shader_info,
    bool execute) {
  vkapi::DescriptorPoolConfig* config =
      execute ? &execute_descriptor_counts_ : &prepack_descriptor_counts_;

  config->descriptor_pool_max_sets += 1;
  for (const VkDescriptorType arg_type : shader_info.kernel_layout) {
    switch (arg_type) {
      case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
        config->descriptor_uniform_buffer_count += 1;
        break;
      case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
        config->descriptor_storage_buffer_count += 1;
        break;
      case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
        config->descriptor_combined_sampler_count += 1;
        break;
      case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
        config->descriptor_storage_image_count += 1;
        break;
      default:
        VK_THROW("Unsupported descriptor type!");
    }
  }
}

utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
  if (is_buffer_storage(idx)) {
    return {uint32_t(numel_of(idx)), 1u, 1u};
  }
  return logical_limits_of(idx);
}

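// Heuristic for picking a local workgroup size of at most 64 invocations: the
// global workgroup dims are sorted in descending order, the largest axis gets
// the largest local extent (8, 16, 32, or 64 depending on how degenerate the
// other axes are), and the result is mapped back to the original axis order.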
utils::uvec3 ComputeGraph::create_local_wg_size(
    const utils::uvec3 global_wg_size) {
  if (config_.enable_local_wg_size_override) {
    return config_.local_wg_size_override;
  }

  // array containing axis index and global workgroup size
  std::pair<uint32_t, uint32_t> global_wg_size_desc[] = {
      {0u, global_wg_size[0]},
      {1u, global_wg_size[1]},
      {2u, global_wg_size[2]}};

  // sort the global workgroup size in descending order
  if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
    std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
  }
  if (global_wg_size_desc[1].second < global_wg_size_desc[2].second) {
    std::swap(global_wg_size_desc[1], global_wg_size_desc[2]);
  }
  if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
    std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
  }

  utils::uvec3 local_group_size = {
      8,
      std::max(1u, std::min(4u, global_wg_size_desc[1].second)),
      std::max(1u, std::min(2u, global_wg_size_desc[2].second))};

  if (global_wg_size_desc[2u].second == 1) {
    if (global_wg_size_desc[1u].second == 1) {
      local_group_size[0u] = 64;
      local_group_size[1u] = 1;
    } else if (global_wg_size_desc[1u].second % 4 == 0) {
      local_group_size[0u] = 16;
      local_group_size[1u] = 4;
    } else {
      local_group_size[0u] = 32;
      local_group_size[1u] = 2;
    }
  }

  return {
      local_group_size[global_wg_size_desc[0].first],
      local_group_size[global_wg_size_desc[1].first],
      local_group_size[global_wg_size_desc[2].first]};
}

utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
  return create_local_wg_size(create_global_wg_size(idx));
}

void ComputeGraph::copy_into_staging(
    const ValueRef idx,
    const void* data,
    const size_t numel) {
  StagingPtr staging = get_staging(idx);
  size_t nbytes = numel * vkapi::element_size(staging->dtype());
  staging->copy_from(data, nbytes);
}

void ComputeGraph::copy_from_staging(
    const ValueRef idx,
    void* data,
    const size_t numel) {
  StagingPtr staging = get_staging(idx);
  size_t nbytes = numel * vkapi::element_size(staging->dtype());
  staging->copy_to(data, nbytes);
}

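// Merges the prepack and execute descriptor counts (taking the max of each
// field and applying descriptor_pool_safety_factor) and initializes the
// context's descriptor pool with the result if it has not been initialized
// yet. Also sets up the querypool when profiling is enabled in the config.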
void ComputeGraph::prepare() {
#define MERGE_FIELD(field)                    \
  static_cast<uint32_t>(std::ceil(            \
      std::max(                               \
          execute_descriptor_counts_.field,   \
          prepack_descriptor_counts_.field) * \
      config_.descriptor_pool_safety_factor))

  uint32_t max_sets = MERGE_FIELD(descriptor_pool_max_sets);
  vkapi::DescriptorPoolConfig config{
      max_sets,
      std::max(MERGE_FIELD(descriptor_uniform_buffer_count), max_sets),
      std::max(MERGE_FIELD(descriptor_storage_buffer_count), max_sets),
      std::max(MERGE_FIELD(descriptor_combined_sampler_count), max_sets),
      std::max(MERGE_FIELD(descriptor_storage_image_count), max_sets),
      1u,
  };

  if (!context_->descriptor_pool()) {
    context_->descriptor_pool().init(config);
  }
#undef MERGE_FIELD

  if (config_.enable_querypool) {
    context_->initialize_querypool();
  }
}

void ComputeGraph::encode_prepack() {
  for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
    node->encode(this);
  }
}

void ComputeGraph::prepack() const {
  // Submit and execute the command buffer
  vkapi::VulkanFence fence = context_->fences().get_fence();
  context_->submit_cmd_to_gpu(fence.get_submit_handle(), /*final_use = */ true);
  fence.wait();

  context_->flush();
}

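// encode_execute() re-records the reusable execute command buffer: it flushes
// the context, allocates and binds memory for all shared objects, and then
// encodes every execute node. execute() submits the recorded command buffer
// and waits on a fence.
//
// Rough lifecycle sketch (an assumption based on how the API above fits
// together, not prescribed by this file; `graph`, `in_staging`, `out_staging`,
// `host_in`, `host_out`, and `numel` are illustrative names):
//
//   graph.prepare();
//   graph.encode_prepack();
//   graph.prepack();
//   graph.encode_execute();
//   // per inference:
//   graph.copy_into_staging(in_staging, host_in, numel);
//   graph.execute();
//   graph.copy_from_staging(out_staging, host_out, numel);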
void ComputeGraph::encode_execute() {
  context_->flush();
  context_->set_cmd(/*reusable = */ true);

  context_->cmd_reset_querypool();

  for (SharedObject& shared_object : shared_objects_) {
    shared_object.allocate(this);
    shared_object.bind_users(this);
  }

  for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
    node->encode(this);
  }
}

void ComputeGraph::execute() const {
  vkapi::VulkanFence fence = context_->fences().get_fence();
  context_->submit_cmd_to_gpu(fence.get_submit_handle());
  fence.wait();
}

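// Dynamic shape support: resize_input() updates the metadata of the idx-th
// input tensor to new_sizes, and propagate_resize() triggers each execute
// node's resize logic so that downstream tensor metadata is updated to match.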
void ComputeGraph::resize_input(
    const int64_t idx,
    const std::vector<int64_t>& new_sizes) {
  IOValueRef io_val = inputs_.at(idx);
  get_tensor(io_val.value)->virtual_resize(new_sizes);
}

void ComputeGraph::propagate_resize() {
  for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
    node->trigger_resize(this);
  }
}

} // namespace vkcompute