1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/backends/vulkan/runtime/api/api.h> 12 13 #include <executorch/backends/vulkan/runtime/graph/containers/Value.h> 14 15 namespace vkcompute { 16 17 class ComputeGraph; 18 19 /* 20 * Represents a group of shader arguments (images and/or buffers), with a common 21 * access permission. 22 */ 23 struct ArgGroup { ArgGroupArgGroup24 ArgGroup(const ValueRef ref, const vkapi::MemoryAccessFlags access) 25 : refs{ref}, access(access) {} 26 ArgGroupArgGroup27 ArgGroup( 28 const std::vector<ValueRef>& refs, 29 const vkapi::MemoryAccessFlags access) 30 : refs(refs), access(access) {} 31 32 const std::vector<ValueRef> refs; 33 const vkapi::MemoryAccessFlags access; 34 }; 35 36 /* 37 * Represents a single execution op in a ML model. In graph mode, ops will be 38 * implemented in a derived class that implements encode, which will implement 39 * encoding of the shader corresponding to the op into the command buffer of a 40 * ComputeGraph. 41 */ 42 class ExecuteNode { 43 friend class ComputeGraph; 44 45 public: 46 using ResizeFunction = const std::function<void( 47 ComputeGraph*, 48 const std::vector<ArgGroup>&, 49 const std::vector<ValueRef>&)>; 50 51 /* 52 * This overload of the DispatchNode constructor is used to register ops which 53 * update a tensor view. No shader is dispatched, but the node still needs to 54 * update the view's sizes and strides after a resize. 55 */ 56 explicit ExecuteNode( 57 const ResizeFunction& resize_fn = nullptr, 58 const std::vector<ValueRef>& resize_args = {}, 59 const std::vector<ArgGroup>& args = {}, 60 const std::string& name = "Graph Node"); 61 62 virtual ~ExecuteNode() = default; 63 encode(ComputeGraph * graph)64 virtual void encode(ComputeGraph* graph) { 65 (void)graph; 66 } 67 trigger_resize(ComputeGraph * graph)68 inline void trigger_resize(ComputeGraph* graph) { 69 if (resize_fn_ != nullptr) { 70 resize_fn_(graph, args_, resize_args_); 71 } 72 } 73 set_node_id(uint32_t node_id)74 inline void set_node_id(uint32_t node_id) { 75 node_id_ = node_id; 76 } 77 name()78 inline const std::string& name() const { 79 return name_; 80 } 81 82 protected: 83 uint32_t node_id_; 84 const ResizeFunction resize_fn_; 85 const std::vector<ValueRef> resize_args_; 86 const std::vector<ArgGroup> args_; 87 const std::string name_; 88 }; 89 90 } // namespace vkcompute 91