1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/backends/vulkan/runtime/api/api.h> 12 13 #include <executorch/backends/vulkan/runtime/graph/containers/Value.h> 14 15 #include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h> 16 17 namespace vkcompute { 18 19 class ComputeGraph; 20 21 /* 22 * Represents a single shader execution op in a ML model. 23 */ 24 class DispatchNode final : public ExecuteNode { 25 friend class ComputeGraph; 26 27 public: 28 explicit DispatchNode( 29 ComputeGraph& graph, 30 const vkapi::ShaderInfo& shader, 31 const utils::uvec3& global_workgroup_size, 32 const utils::uvec3& local_workgroup_size, 33 const std::vector<ArgGroup>& args, 34 const vkapi::ParamsBindList& params, 35 const vkapi::SpecVarList& spec_vars = {}, 36 const ResizeFunction& resize_fn = nullptr, 37 const std::vector<ValueRef>& resize_args = {}); 38 39 ~DispatchNode() override = default; 40 41 void encode(ComputeGraph* graph) override; 42 43 protected: 44 const vkapi::ShaderInfo shader_; 45 const utils::uvec3 global_workgroup_size_; 46 const utils::uvec3 local_workgroup_size_; 47 const vkapi::ParamsBindList params_; 48 const vkapi::SpecVarList spec_vars_; 49 50 public: 51 operator bool() const { 52 return shader_; 53 } 54 }; 55 56 } // namespace vkcompute 57