xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/DispatchNode.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h>
10 
11 #include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
12 
13 #include <executorch/backends/vulkan/runtime/graph/ops/utils/BindingUtils.h>
14 
15 namespace vkcompute {
16 
DispatchNode(ComputeGraph & graph,const vkapi::ShaderInfo & shader,const utils::uvec3 & global_workgroup_size,const utils::uvec3 & local_workgroup_size,const std::vector<ArgGroup> & args,const vkapi::ParamsBindList & params,const vkapi::SpecVarList & spec_vars,const ResizeFunction & resize_fn,const std::vector<ValueRef> & resize_args)17 DispatchNode::DispatchNode(
18     ComputeGraph& graph,
19     const vkapi::ShaderInfo& shader,
20     const utils::uvec3& global_workgroup_size,
21     const utils::uvec3& local_workgroup_size,
22     const std::vector<ArgGroup>& args,
23     const vkapi::ParamsBindList& params,
24     const vkapi::SpecVarList& spec_vars,
25     const ResizeFunction& resize_fn,
26     const std::vector<ValueRef>& resize_args)
27     : ExecuteNode(resize_fn, resize_args, args, shader.kernel_name),
28       shader_(shader),
29       global_workgroup_size_(global_workgroup_size),
30       local_workgroup_size_(local_workgroup_size),
31       params_(params),
32       spec_vars_(spec_vars) {
33   graph.update_descriptor_counts(shader, /*execute = */ true);
34 }
35 
encode(ComputeGraph * graph)36 void DispatchNode::encode(ComputeGraph* graph) {
37   if (!shader_) {
38     return;
39   }
40   api::Context* const context = graph->context();
41   vkapi::PipelineBarrier pipeline_barrier{};
42 
43   std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
44 
45   context->report_shader_dispatch_start(
46       shader_.kernel_name,
47       global_workgroup_size_,
48       local_workgroup_size_,
49       node_id_);
50 
51   vkapi::DescriptorSet descriptor_set =
52       context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);
53 
54   uint32_t idx = 0;
55   idx = bind_values_to_descriptor_set(
56       graph, args_, pipeline_barrier, descriptor_set, idx);
57 
58   bind_params_to_descriptor_set(params_, descriptor_set, idx);
59 
60   context->register_shader_dispatch(
61       descriptor_set, pipeline_barrier, shader_, global_workgroup_size_);
62 
63   context->report_shader_dispatch_end();
64 }
65 
66 } // namespace vkcompute
67