xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

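// Validate the inputs to a matmul op: both matrices must be 2D (mm) or 3D
// (bmm) with matching ranks, the output must share mat1's packed dimension,
// and the inner dimensions must be compatible (mat1's last dim must equal
// mat2's second-to-last dim).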
void check_matmul_args(
    const ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out) {
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
  std::vector<int64_t> mat2_sizes = graph.sizes_of(mat2_data);

  VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3);
  VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size());

  VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

  VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes));
}

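// Resize callback invoked when input shapes change at runtime: recompute the
// output shape (rows from mat1, columns from mat2, respecting the transposed
// flag passed via extra_args) and virtually resize the output tensor.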
void resize_matmul_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
  vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);

  bool mat2_is_transposed = graph->get_bool(extra_args[0]);

  const int out_rows = utils::val_at(-2, mat1->sizes());
  const int out_cols = mat2_is_transposed ? utils::val_at(-2, mat2->sizes())
                                          : utils::val_at(-1, mat2->sizes());

  const int64_t out_dim = out->dim();
  std::vector<int64_t> new_out_sizes(mat1->sizes());
  new_out_sizes.at(out_dim - 1) = out_cols;
  new_out_sizes.at(out_dim - 2) = out_rows;

  out->virtual_resize(new_out_sizes);
}

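// Record a naive matmul on buffer-backed tensors: one shader invocation per
// output element, with sizes and strides passed as UBOs so the shader can
// index arbitrarily strided buffers. mat2 is prepacked into GPU memory first.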
void add_matmul_naive_buffer_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  ValueRef mat2 = prepack_standard(
      graph,
      mat2_data,
      graph.storage_type_of(out),
      utils::kHeightPacked,
      /*passthrough = */ true);

  std::string kernel_name = "matmul_naive_buffer";
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

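  // Launch one thread per output element; the z axis folds the channel and
  // batch dimensions together.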
  utils::uvec3 global_size = {
      graph.size_at<uint32_t>(-1, out),
      graph.size_at<uint32_t>(-2, out),
      graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};

  int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
                                graph.get_bool(mat2_is_transposed))
      ? 1
      : 0;

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      graph.create_local_wg_size(global_size),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.strides_ubo(out),
          graph.sizes_ubo(mat1),
          graph.strides_ubo(mat1),
          graph.sizes_ubo(mat2),
          graph.strides_ubo(mat2),
          graph.numel_ubo(out),
      },
      // Specialization Constants
      {mat2_is_transposed_val},
      // Resizing Logic
      resize_matmul_node,
      {mat2_is_transposed}));
}

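// Record a naive matmul on texture-backed tensors: one shader invocation per
// output texel, bounded by the output's logical limits. A transposed shader
// variant is selected when mat2 arrives pre-transposed.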
void add_matmul_naive_texture3d_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  ValueRef mat2 = prepack_standard(
      graph,
      mat2_data,
      graph.storage_type_of(out),
      utils::kHeightPacked,
      /*passthrough = */ true);

  std::string kernel_name = graph.get_bool(mat2_is_transposed)
      ? "matmul_transposed_naive"
      : "matmul_naive";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_wg_size,
      graph.create_local_wg_size(global_wg_size),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.logical_limits_ubo(out),
          graph.sizes_ubo(mat1),
          graph.sizes_ubo(mat2),
      },
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(mat1),
       graph.hashed_layout_of(mat2)},
      // Resizing Logic
      resize_matmul_node,
      {mat2_is_transposed}));
}

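// Record a tiled matmul on texture-backed tensors. mat1 is repacked to width
// packed and mat2 to height packed (width packed if transposed) via view_copy
// nodes, so that the texels of both inputs run along the shared reduction
// dimension and each invocation can compute a small output tile.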
void add_matmul_optimized_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  ValueRef mat2 = prepack_standard(
      graph,
      mat2_data,
      graph.storage_type_of(out),
      utils::kHeightPacked,
      /*passthrough = */ true);

  // Ensure mat1 is width packed
  ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});

  const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);

  // Ensure mat2 is height packed (or width packed, if it is transposed)
  ValueRef mat2_packed = mat2;
  const utils::GPUMemoryLayout mat2_layout =
      mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked;
  if (graph.estimate_memory_layout_of(mat2) != mat2_layout) {
    mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
    viewFn(graph, {mat2, graph.add_none(), mat2_packed});
  }

  std::string kernel_name = mat2_is_transposed_val
      ? "matmul_transposed_optimized"
      : "matmul_optimized";

  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
  int mat1_dims = mat1_sizes.size();
  if (mat1_dims == 3) {
    kernel_name = "batch_" + kernel_name;
  }
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    kernel_name += "_tile_row_2";
  } else {
    kernel_name += "_tile_row_4";
  }

  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // Each thread computes a W=4 x H=(2/4) x C=(1/4) output tile. Therefore, the
  // total number of threads is W/4 x H/(2 or 4) x C/1. Since the out tensor is
  // channels packed, C does not need to be divided by 4. The "identity" of each
  // thread is the (x, y, z) coordinate of the output tile it is computing, and
  // this identity can be used to compute the tensor index of the top left
  // element in the tile, which will be [W=x*4, H=y*(2 or 4), C=z*(1 or 4), N=0].
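  //
  // Worked example (illustrative sizes, not from the original source): for a
  // 2D output with W=64 and H=16, the logical limits are {64, 16, z}. Since
  // the output has 16 (>= 8) rows, the tile_row_4 variant is selected, and
  // the dispatch below launches divup({64, 16, z}, {4, 4, 1}) = {16, 4, z}
  // threads, each writing a 4 (W) x 4 (H) tile of output texels.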
  // Use `logical_limits` instead of `image_extents` because the workgroup
  // axes need to correspond to tensor dimensions.
  utils::uvec3 global_size = graph.logical_limits_of(out);
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    global_size = utils::divup_vec(global_size, {4, 2, 1});
  } else {
    global_size = utils::divup_vec(global_size, {4, 4, 1});
  }

  utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, mat2_packed}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.sizes_ubo(mat1_W_packed),
          graph.sizes_ubo(mat2_packed),
      },
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(mat1_W_packed),
       graph.hashed_layout_of(mat2_packed)},
      // Resizing Logic
      resize_matmul_node,
      {mat2_is_transposed}));
}

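// Select a matmul implementation: buffer-backed outputs use the naive buffer
// shader; texture-backed outputs use the optimized tiled shader when mat1 is
// channels packed, or the naive texture shader when mat1 is width packed.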
void add_matmul_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  if (graph.is_buffer_storage(out)) {
    add_matmul_naive_buffer_node(
        graph, mat1, mat2_data, out, mat2_is_transposed);
  } else if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) {
    add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed);
  } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) {
    add_matmul_naive_texture3d_node(
        graph, mat1, mat2_data, out, mat2_is_transposed);
  } else {
    VK_THROW("Input texture should be channel packed or width packed.");
  }
}

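// Entry point for aten.mm.default and aten.bmm.default. Neither op
// pre-transposes mat2, so the transposed flag is registered as false.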
void matmul(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  check_matmul_args(graph, args[0], args[1], args[2]);
  const ValueRef mat2_is_transposed = graph.add_scalar(false);
  add_matmul_node(graph, args[0], args[1], args[2], mat2_is_transposed);
}

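// mm (2D) and bmm (3D) share the same implementation; check_matmul_args
// accepts either rank.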
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.mm.default, matmul);
  VK_REGISTER_OP(aten.bmm.default, matmul);
}

} // namespace vkcompute