xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/backends/vulkan/runtime/api/api.h>
12 
13 #include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
14 
15 namespace vkcompute {
16 
17 class ComputeGraph;
18 
19 /*
20  * Represents a single prepacking op in a ML model. In graph mode, ops will be
21  * implemented in a derived class that implements encode, which will implement
22  * encoding of shaders transferring necessary data (such as weights and biases)
23  * to the GPU.
24  */
25 class PrepackNode final {
26   friend class ComputeGraph;
27 
28  public:
29   PrepackNode(
30       ComputeGraph& graph,
31       const vkapi::ShaderInfo& shader,
32       const utils::uvec3& global_workgroup_size,
33       const utils::uvec3& local_workgroup_size,
34       const ValueRef tref,
35       const ValueRef packed,
36       const vkapi::ParamsBindList& params,
37       const vkapi::SpecVarList& spec_vars = {});
38 
39   ~PrepackNode() = default;
40 
41   void encode(ComputeGraph* graph);
42 
set_node_id(uint32_t node_id)43   inline void set_node_id(uint32_t node_id) {
44     node_id_ = node_id;
45   }
46 
47  protected:
48   uint32_t node_id_;
49   const vkapi::ShaderInfo shader_;
50   vkapi::ShaderInfo noop_shader_;
51   const utils::uvec3 global_workgroup_size_;
52   const utils::uvec3 local_workgroup_size_;
53   const ValueRef tref_;
54   const ValueRef packed_;
55   const vkapi::ParamsBindList params_;
56   const vkapi::SpecVarList spec_vars_;
57 
58  private:
59   api::StagingBuffer create_staging_buffer(ComputeGraph* graph);
60 };
61 
62 } // namespace vkcompute
63