/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/MatMul.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>

#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/ScalarUtils.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

namespace vkcompute {

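// Validate the inputs to a matmul op: both matrices must be 2D (mm) or 3D
// (bmm) with matching dimensionality, mat1 must share its packed dim with the
// output, and the inner dims must be compatible for multiplication, e.g.
// mat1 [B, M, K] x mat2 [B, K, N] -> out [B, M, N].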
void check_matmul_args(
    const ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out) {
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1);
  std::vector<int64_t> mat2_sizes = graph.sizes_of(mat2_data);

  VK_CHECK_COND(mat1_sizes.size() == 2 || mat1_sizes.size() == 3);
  VK_CHECK_COND(mat1_sizes.size() == mat2_sizes.size());

  VK_CHECK_COND(graph.packed_dim_of(mat1) == graph.packed_dim_of(out));

  VK_CHECK_COND(utils::val_at(-1, mat1_sizes) == utils::val_at(-2, mat2_sizes));
}

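// Recompute the output sizes from the current input sizes during dynamic
// shape inference and resize the output tensor accordingly. The output
// inherits mat1's sizes except for the last two dims, which become
// [rows of mat1, cols of mat2] (or [rows of mat1, rows of mat2] when mat2 is
// transposed).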
void resize_matmul_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]);
  vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]);

  bool mat2_is_transposed = graph->get_bool(extra_args[0]);

  const int out_rows = utils::val_at(-2, mat1->sizes());
  const int out_cols = mat2_is_transposed ? utils::val_at(-2, mat2->sizes())
                                          : utils::val_at(-1, mat2->sizes());

  const int64_t out_dim = out->dim();
  std::vector<int64_t> new_out_sizes(mat1->sizes());
  new_out_sizes.at(out_dim - 1) = out_cols;
  new_out_sizes.at(out_dim - 2) = out_rows;

  out->virtual_resize(new_out_sizes);
}

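// Dispatch a naive buffer-backed matmul shader that launches one invocation
// per output element, with the channel and batch dims folded into the z axis.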
void add_matmul_naive_buffer_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  ValueRef mat2 = prepack_standard(
      graph,
      mat2_data,
      graph.storage_type_of(out),
      utils::kHeightPacked,
      /*passthrough = */ true);

  std::string kernel_name = "matmul_naive_buffer";
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  utils::uvec3 global_size = {
      graph.size_at<uint32_t>(-1, out),
      graph.size_at<uint32_t>(-2, out),
      graph.size_at<uint32_t>(-3, out) * graph.size_at<uint32_t>(-4, out)};

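  // Encode the optional transpose flag as a 0/1 specialization constant;
  // kDummyValueRef means no flag was provided, which is treated as false.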
  int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef &&
                                graph.get_bool(mat2_is_transposed))
      ? 1
      : 0;

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      graph.create_local_wg_size(global_size),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.strides_ubo(out),
          graph.sizes_ubo(mat1),
          graph.strides_ubo(mat1),
          graph.sizes_ubo(mat2),
          graph.strides_ubo(mat2),
          graph.numel_ubo(out),
      },
      // Specialization Constants
      {mat2_is_transposed_val},
      // Resizing Logic
      resize_matmul_node,
      {mat2_is_transposed}));
}

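// Dispatch a naive texture-backed matmul shader that launches one invocation
// per output texel.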
void add_matmul_naive_texture3d_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  ValueRef mat2 = prepack_standard(
      graph,
      mat2_data,
      graph.storage_type_of(out),
      utils::kHeightPacked,
      /*passthrough = */ true);

  std::string kernel_name = graph.get_bool(mat2_is_transposed)
      ? "matmul_transposed_naive"
      : "matmul_naive";
  kernel_name.reserve(kShaderNameReserve);
  add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  utils::uvec3 global_wg_size = graph.logical_limits_of(out);
  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_wg_size,
      graph.create_local_wg_size(global_wg_size),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1, mat2}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.logical_limits_ubo(out),
          graph.sizes_ubo(mat1),
          graph.sizes_ubo(mat2),
      },
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(mat1),
       graph.hashed_layout_of(mat2)},
      // Resizing Logic
      resize_matmul_node,
      {mat2_is_transposed}));
}

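// Dispatch a tiled matmul shader for texture storage. Inputs are repacked as
// needed (mat1 width packed; mat2 height packed, or width packed when
// transposed) so that texel fetches for both inputs run along the reduction
// (K) dimension.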
void add_matmul_optimized_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  ValueRef mat2 = prepack_standard(
      graph,
      mat2_data,
      graph.storage_type_of(out),
      utils::kHeightPacked,
      /*passthrough = */ true);

  // Ensure mat1 is width packed
  ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
  viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});

  const bool mat2_is_transposed_val = graph.get_bool(mat2_is_transposed);

  // Ensure mat2 is width packed if transposed, height packed otherwise
  ValueRef mat2_packed = mat2;
  const utils::GPUMemoryLayout mat2_layout =
      mat2_is_transposed_val ? utils::kWidthPacked : utils::kHeightPacked;
  if (graph.estimate_memory_layout_of(mat2) != mat2_layout) {
    mat2_packed = graph.add_tensor_like(mat2, mat2_layout);
    viewFn(graph, {mat2, graph.add_none(), mat2_packed});
  }

  std::string kernel_name = mat2_is_transposed_val
      ? "matmul_transposed_optimized"
      : "matmul_optimized";

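  // Select the shader variant: a batched variant for 3D inputs, and a tile
  // height of 2 rows (instead of 4) when the output has fewer than 8 rows.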
  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
  int mat1_dims = mat1_sizes.size();
  if (mat1_dims == 3) {
    kernel_name = "batch_" + kernel_name;
  }
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    kernel_name += "_tile_row_2";
  } else {
    kernel_name += "_tile_row_4";
  }

  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // Each thread computes a W=4 x H=(2 or 4) x C=(1 or 4) output tile, so the
  // total number of threads is W/4 x H/(2 or 4) x C/1. Since the out tensor is
  // channels packed, the z axis already counts texels of 4 channels each, so C
  // does not need to be divided by 4 again. The "identity" of each thread is
  // the (x, y, z) coordinate of the output tile it computes, which can be used
  // to compute the tensor index of the top left element in the tile:
  // [W=x*4, H=y*(2 or 4), C=z*(1 or 4), N=0].
  //
  // Use `logical_limits` instead of `image_extents` because the workgroup axes
  // need to correspond to tensor dimensions.
  utils::uvec3 global_size = graph.logical_limits_of(out);
  if (mat1_sizes.at(mat1_dims - 2) < 8) {
    global_size = utils::divup_vec(global_size, {4, 2, 1});
  } else {
    global_size = utils::divup_vec(global_size, {4, 4, 1});
  }

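  // Let the helper pick a local workgroup size suited to the global extents.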
  utils::uvec3 local_size = adaptive_work_group_size(global_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      global_size,
      local_size,
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{mat1_W_packed, mat2_packed}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          graph.sizes_ubo(out),
          graph.sizes_ubo(mat1_W_packed),
          graph.sizes_ubo(mat2_packed),
      },
      // Specialization Constants
      {graph.hashed_layout_of(out),
       graph.hashed_layout_of(mat1_W_packed),
       graph.hashed_layout_of(mat2_packed)},
      // Resizing Logic
      resize_matmul_node,
      {mat2_is_transposed}));
}

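// Select a matmul implementation based on how the output is stored: buffer
// storage uses the naive buffer shader, channels packed textures use the
// optimized tiled shader, and width packed textures use the naive texture
// shader.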
void add_matmul_node(
    ComputeGraph& graph,
    const ValueRef mat1,
    const ValueRef mat2_data,
    const ValueRef out,
    const ValueRef mat2_is_transposed) {
  if (graph.is_buffer_storage(out)) {
    add_matmul_naive_buffer_node(
        graph, mat1, mat2_data, out, mat2_is_transposed);
  } else if (graph.packed_dim_of(mat1) == WHCN::kChannelsDim) {
    add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed);
  } else if (graph.packed_dim_of(mat1) == WHCN::kWidthDim) {
    add_matmul_naive_texture3d_node(
        graph, mat1, mat2_data, out, mat2_is_transposed);
  } else {
    VK_THROW("Input texture should be channel packed or width packed.");
  }
}

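// Entry point registered for aten.mm.default and aten.bmm.default. Neither op
// transposes mat2, so the transpose flag is hardcoded to false.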
void matmul(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  check_matmul_args(graph, args[0], args[1], args[2]);
  const ValueRef mat2_is_transposed = graph.add_scalar(false);
  add_matmul_node(graph, args[0], args[1], args[2], mat2_is_transposed);
}

REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.mm.default, matmul);
  VK_REGISTER_OP(aten.bmm.default, matmul);
}

} // namespace vkcompute