xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
#include <executorch/backends/vulkan/runtime/api/api.h>

#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>

#include <functional>
#include <string>
#include <utility>
#include <vector>
14 
15 namespace vkcompute {
16 
17 class ComputeGraph;
18 
19 /*
20  * Represents a group of shader arguments (images and/or buffers), with a common
21  * access permission.
22  */
23 struct ArgGroup {
ArgGroupArgGroup24   ArgGroup(const ValueRef ref, const vkapi::MemoryAccessFlags access)
25       : refs{ref}, access(access) {}
26 
ArgGroupArgGroup27   ArgGroup(
28       const std::vector<ValueRef>& refs,
29       const vkapi::MemoryAccessFlags access)
30       : refs(refs), access(access) {}
31 
32   const std::vector<ValueRef> refs;
33   const vkapi::MemoryAccessFlags access;
34 };
35 
36 /*
37  * Represents a single execution op in a ML model. In graph mode, ops will be
38  * implemented in a derived class that implements encode, which will implement
39  * encoding of the shader corresponding to the op into the command buffer of a
40  * ComputeGraph.
41  */
42 class ExecuteNode {
43   friend class ComputeGraph;
44 
45  public:
46   using ResizeFunction = const std::function<void(
47       ComputeGraph*,
48       const std::vector<ArgGroup>&,
49       const std::vector<ValueRef>&)>;
50 
51   /*
52    * This overload of the DispatchNode constructor is used to register ops which
53    * update a tensor view. No shader is dispatched, but the node still needs to
54    * update the view's sizes and strides after a resize.
55    */
56   explicit ExecuteNode(
57       const ResizeFunction& resize_fn = nullptr,
58       const std::vector<ValueRef>& resize_args = {},
59       const std::vector<ArgGroup>& args = {},
60       const std::string& name = "Graph Node");
61 
62   virtual ~ExecuteNode() = default;
63 
encode(ComputeGraph * graph)64   virtual void encode(ComputeGraph* graph) {
65     (void)graph;
66   }
67 
trigger_resize(ComputeGraph * graph)68   inline void trigger_resize(ComputeGraph* graph) {
69     if (resize_fn_ != nullptr) {
70       resize_fn_(graph, args_, resize_args_);
71     }
72   }
73 
set_node_id(uint32_t node_id)74   inline void set_node_id(uint32_t node_id) {
75     node_id_ = node_id;
76   }
77 
name()78   inline const std::string& name() const {
79     return name_;
80   }
81 
82  protected:
83   uint32_t node_id_;
84   const ResizeFunction resize_fn_;
85   const std::vector<ValueRef> resize_args_;
86   const std::vector<ArgGroup> args_;
87   const std::string name_;
88 };
89 
90 } // namespace vkcompute
91