1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/backends/vulkan/runtime/api/api.h> 12 13 #include <executorch/backends/vulkan/runtime/graph/containers/Value.h> 14 15 namespace vkcompute { 16 17 class ComputeGraph; 18 19 /* 20 * Represents a single prepacking op in a ML model. In graph mode, ops will be 21 * implemented in a derived class that implements encode, which will implement 22 * encoding of shaders transferring necessary data (such as weights and biases) 23 * to the GPU. 24 */ 25 class PrepackNode final { 26 friend class ComputeGraph; 27 28 public: 29 PrepackNode( 30 ComputeGraph& graph, 31 const vkapi::ShaderInfo& shader, 32 const utils::uvec3& global_workgroup_size, 33 const utils::uvec3& local_workgroup_size, 34 const ValueRef tref, 35 const ValueRef packed, 36 const vkapi::ParamsBindList& params, 37 const vkapi::SpecVarList& spec_vars = {}); 38 39 ~PrepackNode() = default; 40 41 void encode(ComputeGraph* graph); 42 set_node_id(uint32_t node_id)43 inline void set_node_id(uint32_t node_id) { 44 node_id_ = node_id; 45 } 46 47 protected: 48 uint32_t node_id_; 49 const vkapi::ShaderInfo shader_; 50 vkapi::ShaderInfo noop_shader_; 51 const utils::uvec3 global_workgroup_size_; 52 const utils::uvec3 local_workgroup_size_; 53 const ValueRef tref_; 54 const ValueRef packed_; 55 const vkapi::ParamsBindList params_; 56 const vkapi::SpecVarList spec_vars_; 57 58 private: 59 api::StagingBuffer create_staging_buffer(ComputeGraph* graph); 60 }; 61 62 } // namespace vkcompute 63