xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>
10 
11 namespace vkcompute {
12 
bind_tensor_to_descriptor_set(api::vTensor & tensor,vkapi::PipelineBarrier & pipeline_barrier,const vkapi::MemoryAccessFlags accessType,vkapi::DescriptorSet & descriptor_set,const uint32_t idx)13 void bind_tensor_to_descriptor_set(
14     api::vTensor& tensor,
15     vkapi::PipelineBarrier& pipeline_barrier,
16     const vkapi::MemoryAccessFlags accessType,
17     vkapi::DescriptorSet& descriptor_set,
18     const uint32_t idx) {
19   if (tensor.buffer()) {
20     vkapi::VulkanBuffer& buffer = tensor.buffer(
21         pipeline_barrier, vkapi::PipelineStage::COMPUTE, accessType);
22     descriptor_set.bind(idx, buffer);
23   } else {
24     vkapi::VulkanImage& image = tensor.image(
25         pipeline_barrier, vkapi::PipelineStage::COMPUTE, accessType);
26     descriptor_set.bind(idx, image);
27   }
28 }
29 
bind_values_to_descriptor_set(ComputeGraph * graph,const std::vector<ArgGroup> & args,vkapi::PipelineBarrier & pipeline_barrier,vkapi::DescriptorSet & descriptor_set,const uint32_t base_idx)30 uint32_t bind_values_to_descriptor_set(
31     ComputeGraph* graph,
32     const std::vector<ArgGroup>& args,
33     vkapi::PipelineBarrier& pipeline_barrier,
34     vkapi::DescriptorSet& descriptor_set,
35     const uint32_t base_idx) {
36   uint32_t idx = base_idx;
37   for (auto& arg : args) {
38     for (auto& ref : arg.refs) {
39       if (graph->val_is_tensor(ref)) {
40         bind_tensor_to_descriptor_set(
41             *(graph->get_tensor(ref)),
42             pipeline_barrier,
43             arg.access,
44             descriptor_set,
45             idx++);
46       } else if (graph->val_is_staging(ref)) {
47         bind_staging_to_descriptor_set(
48             *(graph->get_staging(ref)), descriptor_set, idx++);
49       } else {
50         VK_THROW("Unsupported type: ", graph->get_val_type(ref));
51       }
52     }
53   }
54   return idx;
55 }
56 
bind_params_to_descriptor_set(const vkapi::ParamsBindList & params,vkapi::DescriptorSet & descriptor_set,const uint32_t base_idx)57 uint32_t bind_params_to_descriptor_set(
58     const vkapi::ParamsBindList& params,
59     vkapi::DescriptorSet& descriptor_set,
60     const uint32_t base_idx) {
61   uint32_t idx = base_idx;
62   for (auto& param : params.bind_infos) {
63     descriptor_set.bind(idx++, param);
64   }
65   return idx;
66 }
67 
bind_staging_to_descriptor_set(api::StagingBuffer & staging,vkapi::DescriptorSet & descriptor_set,const uint32_t idx)68 void bind_staging_to_descriptor_set(
69     api::StagingBuffer& staging,
70     vkapi::DescriptorSet& descriptor_set,
71     const uint32_t idx) {
72   descriptor_set.bind(idx, staging.buffer());
73 }
74 
75 } // namespace vkcompute
76