1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>
10
11 namespace vkcompute {
12
bind_tensor_to_descriptor_set(api::vTensor & tensor,vkapi::PipelineBarrier & pipeline_barrier,const vkapi::MemoryAccessFlags accessType,vkapi::DescriptorSet & descriptor_set,const uint32_t idx)13 void bind_tensor_to_descriptor_set(
14 api::vTensor& tensor,
15 vkapi::PipelineBarrier& pipeline_barrier,
16 const vkapi::MemoryAccessFlags accessType,
17 vkapi::DescriptorSet& descriptor_set,
18 const uint32_t idx) {
19 if (tensor.buffer()) {
20 vkapi::VulkanBuffer& buffer = tensor.buffer(
21 pipeline_barrier, vkapi::PipelineStage::COMPUTE, accessType);
22 descriptor_set.bind(idx, buffer);
23 } else {
24 vkapi::VulkanImage& image = tensor.image(
25 pipeline_barrier, vkapi::PipelineStage::COMPUTE, accessType);
26 descriptor_set.bind(idx, image);
27 }
28 }
29
bind_values_to_descriptor_set(ComputeGraph * graph,const std::vector<ArgGroup> & args,vkapi::PipelineBarrier & pipeline_barrier,vkapi::DescriptorSet & descriptor_set,const uint32_t base_idx)30 uint32_t bind_values_to_descriptor_set(
31 ComputeGraph* graph,
32 const std::vector<ArgGroup>& args,
33 vkapi::PipelineBarrier& pipeline_barrier,
34 vkapi::DescriptorSet& descriptor_set,
35 const uint32_t base_idx) {
36 uint32_t idx = base_idx;
37 for (auto& arg : args) {
38 for (auto& ref : arg.refs) {
39 if (graph->val_is_tensor(ref)) {
40 bind_tensor_to_descriptor_set(
41 *(graph->get_tensor(ref)),
42 pipeline_barrier,
43 arg.access,
44 descriptor_set,
45 idx++);
46 } else if (graph->val_is_staging(ref)) {
47 bind_staging_to_descriptor_set(
48 *(graph->get_staging(ref)), descriptor_set, idx++);
49 } else {
50 VK_THROW("Unsupported type: ", graph->get_val_type(ref));
51 }
52 }
53 }
54 return idx;
55 }
56
bind_params_to_descriptor_set(const vkapi::ParamsBindList & params,vkapi::DescriptorSet & descriptor_set,const uint32_t base_idx)57 uint32_t bind_params_to_descriptor_set(
58 const vkapi::ParamsBindList& params,
59 vkapi::DescriptorSet& descriptor_set,
60 const uint32_t base_idx) {
61 uint32_t idx = base_idx;
62 for (auto& param : params.bind_infos) {
63 descriptor_set.bind(idx++, param);
64 }
65 return idx;
66 }
67
bind_staging_to_descriptor_set(api::StagingBuffer & staging,vkapi::DescriptorSet & descriptor_set,const uint32_t idx)68 void bind_staging_to_descriptor_set(
69 api::StagingBuffer& staging,
70 vkapi::DescriptorSet& descriptor_set,
71 const uint32_t idx) {
72 descriptor_set.bind(idx, staging.buffer());
73 }
74
75 } // namespace vkcompute
76