xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/DispatchNode.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/backends/vulkan/runtime/api/api.h>
12 
13 #include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
14 
15 #include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
16 
17 namespace vkcompute {
18 
19 class ComputeGraph;
20 
21 /*
22  * Represents a single shader execution op in a ML model.
23  */
24 class DispatchNode final : public ExecuteNode {
25   friend class ComputeGraph;
26 
27  public:
28   explicit DispatchNode(
29       ComputeGraph& graph,
30       const vkapi::ShaderInfo& shader,
31       const utils::uvec3& global_workgroup_size,
32       const utils::uvec3& local_workgroup_size,
33       const std::vector<ArgGroup>& args,
34       const vkapi::ParamsBindList& params,
35       const vkapi::SpecVarList& spec_vars = {},
36       const ResizeFunction& resize_fn = nullptr,
37       const std::vector<ValueRef>& resize_args = {});
38 
39   ~DispatchNode() override = default;
40 
41   void encode(ComputeGraph* graph) override;
42 
43  protected:
44   const vkapi::ShaderInfo shader_;
45   const utils::uvec3 global_workgroup_size_;
46   const utils::uvec3 local_workgroup_size_;
47   const vkapi::ParamsBindList params_;
48   const vkapi::SpecVarList spec_vars_;
49 
50  public:
51   operator bool() const {
52     return shader_;
53   }
54 };
55 
56 } // namespace vkcompute
57