/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// @lint-ignore-every CLANGTIDY
// facebook-security-vulnerable-integer-sign-conversion

#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>

namespace vkcompute {

//
// VTensorPtr
//

#define VALUE_PTR_CLASS_IMPL(classname, ctype, type_name)                 \
  classname::classname(ComputeGraph* const graph, const ValueRef idx)     \
      : graph_(graph), ptr_(&(graph_->values_.at(idx).to##type_name())) { \
    graph_->values_in_use_++;                                             \
  }                                                                       \
  ctype* classname::operator->() const {                                  \
    return ptr_;                                                          \
  }                                                                       \
  ctype& classname::operator*() const {                                   \
    return *ptr_;                                                         \
  }                                                                       \
  classname::~classname() {                                               \
    graph_->values_in_use_--;                                             \
  }

VALUE_PTR_CLASS_IMPL(vTensorPtr, api::vTensor, Tensor)
VALUE_PTR_CLASS_IMPL(TensorRefPtr, TensorRef, TensorRef)
VALUE_PTR_CLASS_IMPL(StagingPtr, api::StagingBuffer, Staging)
VALUE_PTR_CLASS_IMPL(IntListPtr, std::vector<int64_t>, IntList)
VALUE_PTR_CLASS_IMPL(DoubleListPtr, std::vector<double>, DoubleList)
VALUE_PTR_CLASS_IMPL(BoolListPtr, std::vector<bool>, BoolList)
VALUE_PTR_CLASS_IMPL(ValueListPtr, std::vector<ValueRef>, ValueList)
VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt)

#undef VALUE_PTR_CLASS_IMPL
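
// Usage note (a minimal sketch; `graph` and `tensor_ref` are assumed names for
// illustration, not part of this file): the guarded pointer classes generated
// above bump values_in_use_ for as long as they are in scope, so they should
// be kept short-lived and must not be held across calls that add new Values:
//
//   {
//     vTensorPtr t = graph.get_tensor(tensor_ref); // values_in_use_ == 1
//     std::vector<int64_t> sizes = t->sizes();
//   } // destructor runs, values_in_use_ == 0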

//
// TmpTensor
//

TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(
          sizes,
          dtype,
          storage_type,
          memory_layout,
          sobj_idx)) {}

TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, storage_type, sobj_idx)) {}

TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::GPUMemoryLayout memory_layout)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, memory_layout, sobj_idx)) {}

TmpTensor::TmpTensor(
    ComputeGraph* const graph_ptr,
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype)
    : graph_p(graph_ptr),
      sobj_idx(get_sobj_idx()),
      vref(graph_p->add_tensor(sizes, dtype, sobj_idx)) {}

TmpTensor::~TmpTensor() {
  // The lifetime of this temporary tensor has expired; return the shared
  // object to the pool, as long as the sobj index is valid.
  if (sobj_idx >= 0) {
    graph_p->tmp_shared_object_idxs_.emplace(sobj_idx);
  }
}

int64_t TmpTensor::get_sobj_idx() {
  int64_t sobj_idx;
  // If there are no available temporary shared objects, request that a new one
  // be created.
  if (graph_p->tmp_shared_object_idxs_.empty()) {
    sobj_idx = graph_p->shared_objects_.size();
  } else {
    // Otherwise, reuse the first available shared object idx.
    sobj_idx = graph_p->tmp_shared_object_idxs_.top();
    graph_p->tmp_shared_object_idxs_.pop();
  }
  return sobj_idx;
}
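
// Usage note (a hedged sketch; `graph` and the sizes/dtype are placeholder
// values): a TmpTensor checks a shared object index out of
// tmp_shared_object_idxs_ for its lifetime and returns it on destruction, so
// temporaries in disjoint scopes can reuse the same memory:
//
//   {
//     TmpTensor tmp_a(&graph, {4, 64, 64}, vkapi::kFloat);
//     // ... record work that reads/writes tmp_a.vref ...
//   } // tmp_a's sobj index is returned to the pool
//   TmpTensor tmp_b(&graph, {4, 64, 64}, vkapi::kFloat); // may reuse the index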

//
// ComputeGraph
//

ComputeGraph::ComputeGraph(GraphConfig config)
    : config_{config},
      prepack_descriptor_counts_{},
      execute_descriptor_counts_{},
      context_{new api::Context(
          vkapi::runtime()->default_adapter_i(),
          config_.context_config)},
      shared_objects_{},
      values_{},
      param_ubos_{},
      prepack_nodes_{},
      execute_nodes_{},
      inputs_{},
      outputs_{} {
  // Ensure that descriptor counts are initialized to 0
  prepack_descriptor_counts_.descriptor_pool_max_sets = 0;
  prepack_descriptor_counts_.descriptor_uniform_buffer_count = 0;
  prepack_descriptor_counts_.descriptor_storage_buffer_count = 0;
  prepack_descriptor_counts_.descriptor_combined_sampler_count = 0;
  prepack_descriptor_counts_.descriptor_storage_image_count = 0;

  execute_descriptor_counts_.descriptor_pool_max_sets = 0;
  execute_descriptor_counts_.descriptor_uniform_buffer_count = 0;
  execute_descriptor_counts_.descriptor_storage_buffer_count = 0;
  execute_descriptor_counts_.descriptor_combined_sampler_count = 0;
  execute_descriptor_counts_.descriptor_storage_image_count = 0;

  context_->set_cmd(/*reusable = */ true);
}

ComputeGraph::~ComputeGraph() {
  values_.clear();

  prepack_nodes_.clear();
  execute_nodes_.clear();

  context_->flush();
}

utils::StorageType ComputeGraph::suggested_storage_type() {
  if (config_.enable_storage_type_override) {
    return config_.storage_type_override;
  }
  return utils::kTexture3D;
}

utils::GPUMemoryLayout ComputeGraph::suggested_memory_layout(
    const std::vector<int64_t>& sizes) {
  if (config_.enable_memory_layout_override) {
    return config_.memory_layout_override;
  }
  if (sizes.size() < 3) {
    return utils::kWidthPacked;
  }
  // For tensors with 3 or more dims, if the channels dim is 1, still prefer
  // width packed.
  if (utils::val_at(-3, sizes) == 1) {
    return utils::kWidthPacked;
  }
  return utils::kChannelsPacked;
}
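
// Examples of the heuristic above: sizes {2, 8, 8} (channels dim is 2) yield
// kChannelsPacked, while sizes {1, 8, 8} (channels dim is 1) and sizes {8, 8}
// (fewer than 3 dims) both yield kWidthPacked.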

void ComputeGraph::check_no_active_value_ptrs() {
  VK_CHECK_COND(
      values_in_use_ == 0,
      "Make sure that there are no pointers stored from the return values of "
      "`ComputeGraph::get_*()` functions in scope before adding Values to the "
      "graph. Modifying the graph's values may cause existing pointers to be "
      "invalidated.");
}
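
// Example of the hazard guarded against above (hypothetical names, for
// illustration only):
//
//   vTensorPtr t = graph.get_tensor(a); // values_in_use_ == 1
//   graph.add_tensor(...);              // would grow values_ and could
//                                       // invalidate t
//
// Because most add_* functions below call check_no_active_value_ptrs() before
// appending, the second line trips VK_CHECK_COND instead of silently leaving
// `t` dangling.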

std::vector<int64_t> ComputeGraph::sizes_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().sizes();
  } else if (val.isTensorRef()) {
    return val.toConstTensorRef().sizes;
  }
  VK_THROW("Could not get sizes of value with type ", val.type());
}

int64_t ComputeGraph::dim_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().dim();
  } else if (val.isTensorRef()) {
    return val.toConstTensorRef().sizes.size();
  }
  VK_THROW("Could not get dim of value with type ", val.type());
}

std::vector<int64_t> ComputeGraph::dim_order_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().dim_order();
  }
  VK_THROW("Could not get dim order of value with type ", val.type());
}

std::vector<int64_t> ComputeGraph::strides_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().strides();
  }
  VK_THROW("Could not get strides of value with type ", val.type());
}

vkapi::ScalarType ComputeGraph::dtype_of(const ValueRef idx) const {
  const Value& val = values_.at(idx);
  if (val.isTensor()) {
    return val.toConstTensor().dtype();
  } else if (val.isTensorRef()) {
    return val.toConstTensorRef().dtype;
  }
  VK_THROW("Could not get dtype of value with type ", val.type());
}

ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout,
    const int64_t shared_object_idx) {
  bool allocate_memory = shared_object_idx < 0;

  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(api::vTensor(
      context(), sizes, dtype, storage_type, memory_layout, allocate_memory));

  if (!allocate_memory) {
    get_shared_object(shared_object_idx).add_user(this, idx);
  }
  return idx;
}
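
// Note on shared_object_idx for the add_tensor overloads: a negative index
// means the tensor allocates and owns its own memory, while a non-negative
// index defers allocation and registers the tensor as a user of that
// SharedObject, whose memory is allocated and bound later in encode_execute().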

ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::StorageType storage_type,
    const int64_t shared_object_idx) {
  return add_tensor(
      sizes,
      dtype,
      storage_type,
      suggested_memory_layout(sizes),
      shared_object_idx);
}

ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const utils::GPUMemoryLayout memory_layout,
    const int64_t shared_object_idx) {
  return add_tensor(
      sizes, dtype, suggested_storage_type(), memory_layout, shared_object_idx);
}

ValueRef ComputeGraph::add_tensor_like(
    const ValueRef idx,
    const utils::StorageType storage_type,
    const utils::GPUMemoryLayout memory_layout) {
  return add_tensor(sizes_of(idx), dtype_of(idx), storage_type, memory_layout);
}

ValueRef ComputeGraph::add_tensor_like(
    const ValueRef idx,
    const utils::GPUMemoryLayout memory_layout) {
  return add_tensor(
      sizes_of(idx), dtype_of(idx), storage_type_of(idx), memory_layout);
}

ValueRef ComputeGraph::add_tensor(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const int64_t shared_object_idx) {
  return add_tensor(
      sizes, dtype, suggested_memory_layout(sizes), shared_object_idx);
}

ValueRef ComputeGraph::add_tensor(const vkapi::VulkanImage& image) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(api::vTensor(context(), image));
  return idx;
}

ValueRef ComputeGraph::add_tensor_view(const ValueRef vref) {
  const vTensorPtr t = get_tensor(vref);
  ValueRef idx(static_cast<int>(values_.size()));
  values_.emplace_back(api::vTensor(*t));
  for (SharedObject& sobj : shared_objects_) {
    if (sobj.has_user(vref)) {
      sobj.add_user(this, idx);
    }
  }
  return idx;
}

ValueRef ComputeGraph::add_tensor_view(
    const ValueRef vref,
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    const size_t offset_numel) {
  const vTensorPtr t = get_tensor(vref);
  ValueRef idx(static_cast<int>(values_.size()));
  values_.emplace_back(api::vTensor(*t, sizes, strides, offset_numel));
  for (SharedObject& sobj : shared_objects_) {
    if (sobj.has_user(vref)) {
      sobj.add_user(this, idx);
    }
  }
  return idx;
}

ValueRef ComputeGraph::add_tensorref(
    const std::vector<int64_t>& sizes,
    const vkapi::ScalarType dtype,
    const void* const data) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(TensorRef(sizes, dtype, data));
  return idx;
}

ValueRef ComputeGraph::add_staging(
    const vkapi::ScalarType dtype,
    const size_t numel) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(api::StagingBuffer(context(), dtype, numel));
  return idx;
}

ValueRef ComputeGraph::add_none() {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back();
  return idx;
}

ValueRef ComputeGraph::add_value_list(std::vector<ValueRef>&& value) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(std::move(value));
  return idx;
}

ValueRef ComputeGraph::add_string(std::string&& str) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(std::move(str));
  return idx;
}

ValueRef ComputeGraph::add_symint(const int32_t val) {
  ValueRef idx(static_cast<int>(values_.size()));
  check_no_active_value_ptrs();
  values_.emplace_back(SymInt(context(), val));
  return idx;
}

ValueRef ComputeGraph::set_input_tensor(
    const ValueRef idx,
    const bool use_staging) {
  if (use_staging) {
    vkapi::ScalarType dtype = get_tensor(idx)->dtype();
    // For texture storage, the buffer size needs to account for the zero
    // padding applied by unused texel elements.
    size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
    ValueRef staging_idx = add_staging(dtype, buf_numel);
    add_staging_to_tensor_node(*this, staging_idx, idx);
    inputs_.push_back({idx, staging_idx});
    return staging_idx;
  }
  inputs_.push_back({idx, kDummyValueRef});
  return idx;
}

ValueRef ComputeGraph::set_output_tensor(
    const ValueRef idx,
    const bool use_staging) {
  if (use_staging) {
    vkapi::ScalarType dtype = get_tensor(idx)->dtype();
    // For texture storage, the buffer size needs to account for the zero
    // padding applied by unused texel elements.
    size_t buf_numel = get_tensor(idx)->staging_buffer_numel();
    ValueRef staging_idx = add_staging(dtype, buf_numel);
    // Only add the transfer node when the tensor is non-empty. When the
    // underlying tensor is empty (e.g. padded_numel == 0), no VkImage is
    // allocated for it, so the node could not be bound for execution.
    if (buf_numel > 0) {
      add_tensor_to_staging_node(*this, idx, staging_idx);
    }
    outputs_.push_back({idx, staging_idx});
    return staging_idx;
  }
  outputs_.push_back({idx, kDummyValueRef});
  return idx;
}
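
// Typical use of the staging paths above (a hedged sketch; `graph`, the
// ValueRefs, and the host pointers are assumed names). Once the graph has been
// prepared and encoded, inputs are written through the staging buffer returned
// by set_input_tensor() and outputs are read back through the one returned by
// set_output_tensor():
//
//   graph.copy_into_staging(in_staging, host_src, src_numel);
//   graph.execute();
//   graph.copy_from_staging(out_staging, host_dst, dst_numel);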

vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
    const ValueRef idx) {
  if (values_.at(idx).isInt()) {
    const int32_t val = extract_scalar<int32_t>(idx);
    return create_params_buffer(val);
  } else if (values_.at(idx).isSymInt()) {
    SymIntPtr symint = get_symint(idx);
    return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
  }
  VK_THROW("Cannot create an int param buffer for the given value");
}

void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
  get_symint(idx)->set(val);
}

int32_t ComputeGraph::read_symint(const ValueRef idx) {
  return get_symint(idx)->get();
}

SharedObject& ComputeGraph::get_shared_object(const int64_t idx) {
  if (idx >= shared_objects_.size()) {
    shared_objects_.resize(static_cast<size_t>(idx + 1));
  }
  return shared_objects_.at(idx);
}

void ComputeGraph::update_descriptor_counts(
    const vkapi::ShaderInfo& shader_info,
    bool execute) {
  vkapi::DescriptorPoolConfig* config =
      execute ? &execute_descriptor_counts_ : &prepack_descriptor_counts_;

  config->descriptor_pool_max_sets += 1;
  for (const VkDescriptorType arg_type : shader_info.kernel_layout) {
    switch (arg_type) {
      case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
        config->descriptor_uniform_buffer_count += 1;
        break;
      case VK_DESCRIPTOR_TYPE_STORAGE_BUFFER:
        config->descriptor_storage_buffer_count += 1;
        break;
      case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
        config->descriptor_combined_sampler_count += 1;
        break;
      case VK_DESCRIPTOR_TYPE_STORAGE_IMAGE:
        config->descriptor_storage_image_count += 1;
        break;
      default:
        VK_THROW("Unsupported descriptor type!");
    }
  }
}

utils::uvec3 ComputeGraph::create_global_wg_size(const ValueRef idx) {
  if (is_buffer_storage(idx)) {
    return {uint32_t(numel_of(idx)), 1u, 1u};
  }
  return logical_limits_of(idx);
}

utils::uvec3 ComputeGraph::create_local_wg_size(
    const utils::uvec3 global_wg_size) {
  if (config_.enable_local_wg_size_override) {
    return config_.local_wg_size_override;
  }

  // array containing axis index and global workgroup size
  std::pair<uint32_t, uint32_t> global_wg_size_desc[] = {
      {0u, global_wg_size[0]},
      {1u, global_wg_size[1]},
      {2u, global_wg_size[2]}};

  // sort the global workgroup size in descending order
  if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
    std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
  }
  if (global_wg_size_desc[1].second < global_wg_size_desc[2].second) {
    std::swap(global_wg_size_desc[1], global_wg_size_desc[2]);
  }
  if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
    std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
  }

  utils::uvec3 local_group_size = {
      8,
      std::max(1u, std::min(4u, global_wg_size_desc[1].second)),
      std::max(1u, std::min(2u, global_wg_size_desc[2].second))};

  if (global_wg_size_desc[2u].second == 1) {
    if (global_wg_size_desc[1u].second == 1) {
      local_group_size[0u] = 64;
      local_group_size[1u] = 1;
    } else if (global_wg_size_desc[1u].second % 4 == 0) {
      local_group_size[0u] = 16;
      local_group_size[1u] = 4;
    } else {
      local_group_size[0u] = 32;
      local_group_size[1u] = 2;
    }
  }

  return {
      local_group_size[global_wg_size_desc[0].first],
      local_group_size[global_wg_size_desc[1].first],
      local_group_size[global_wg_size_desc[2].first]};
}
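
// Worked examples for the heuristic above, using global sizes whose axes are
// already in descending order (so the final axis remapping is the identity):
//   global {64, 64, 8} -> local {8, 4, 2}   (3D dispatch)
//   global {128, 4, 1} -> local {16, 4, 1}  (2D dispatch, second axis % 4 == 0)
//   global {7, 1, 1}   -> local {64, 1, 1}  (1D dispatch)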

utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
  return create_local_wg_size(create_global_wg_size(idx));
}

void ComputeGraph::copy_into_staging(
    const ValueRef idx,
    const void* data,
    const size_t numel) {
  StagingPtr staging = get_staging(idx);
  size_t nbytes = numel * vkapi::element_size(staging->dtype());
  staging->copy_from(data, nbytes);
}

void ComputeGraph::copy_from_staging(
    const ValueRef idx,
    void* data,
    const size_t numel) {
  StagingPtr staging = get_staging(idx);
  size_t nbytes = numel * vkapi::element_size(staging->dtype());
  staging->copy_to(data, nbytes);
}

void ComputeGraph::prepare() {
#define MERGE_FIELD(field)                    \
  static_cast<uint32_t>(std::ceil(            \
      std::max(                               \
          execute_descriptor_counts_.field,   \
          prepack_descriptor_counts_.field) * \
      config_.descriptor_pool_safety_factor))

  uint32_t max_sets = MERGE_FIELD(descriptor_pool_max_sets);
  vkapi::DescriptorPoolConfig config{
      max_sets,
      std::max(MERGE_FIELD(descriptor_uniform_buffer_count), max_sets),
      std::max(MERGE_FIELD(descriptor_storage_buffer_count), max_sets),
      std::max(MERGE_FIELD(descriptor_combined_sampler_count), max_sets),
      std::max(MERGE_FIELD(descriptor_storage_image_count), max_sets),
      1u,
  };

  if (!context_->descriptor_pool()) {
    context_->descriptor_pool().init(config);
  }
#undef MERGE_FIELD

  if (config_.enable_querypool) {
    context_->initialize_querypool();
  }
}
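
// Example of the pool sizing above, with illustrative numbers: if the execute
// pass counted 40 uniform buffers, the prepack pass counted 10, and
// descriptor_pool_safety_factor is 1.25, then MERGE_FIELD yields
// ceil(max(40, 10) * 1.25) == 50, which is then clamped to at least max_sets.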

void ComputeGraph::encode_prepack() {
  for (std::unique_ptr<PrepackNode>& node : prepack_nodes_) {
    node->encode(this);
  }
}

void ComputeGraph::prepack() const {
  // Submit and execute the command buffer
  vkapi::VulkanFence fence = context_->fences().get_fence();
  context_->submit_cmd_to_gpu(fence.get_submit_handle(), /*final_use = */ true);
  fence.wait();

  context_->flush();
}

void ComputeGraph::encode_execute() {
  context_->flush();
  context_->set_cmd(/*reusable = */ true);

  context_->cmd_reset_querypool();

  for (SharedObject& shared_object : shared_objects_) {
    shared_object.allocate(this);
    shared_object.bind_users(this);
  }

  for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
    node->encode(this);
  }
}

void ComputeGraph::execute() const {
  vkapi::VulkanFence fence = context_->fences().get_fence();
  context_->submit_cmd_to_gpu(fence.get_submit_handle());
  fence.wait();
}
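
// Typical end-to-end ordering of the entry points above (a sketch; resize
// handling and error checking omitted):
//
//   graph.prepare();         // size and initialize the descriptor pool
//   graph.encode_prepack();  // record prepacking work
//   graph.prepack();         // submit it, wait, and flush
//   graph.encode_execute();  // bind shared memory and record execution
//   graph.execute();         // submit and wait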

void ComputeGraph::resize_input(
    const int64_t idx,
    const std::vector<int64_t>& new_sizes) {
  IOValueRef io_val = inputs_.at(idx);
  get_tensor(io_val.value)->virtual_resize(new_sizes);
}

void ComputeGraph::propagate_resize() {
  for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
    node->trigger_resize(this);
  }
}

} // namespace vkcompute