1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10
11 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12
13 #include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
14
15 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
16 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
17
18 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
19
20 namespace vkcompute {
21
resize_conv2d_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)22 void resize_conv2d_node(
23 ComputeGraph* graph,
24 const std::vector<ArgGroup>& args,
25 const std::vector<ValueRef>& extra_args) {
26 vTensorPtr out = graph->get_tensor(args[0].refs[0]);
27 vTensorPtr self = graph->get_tensor(args[1].refs[0]);
28
29 size_t ndim = self->sizes().size();
30 std::vector<int64_t> new_out_sizes(ndim);
31 const bool transposed = graph->get_bool(extra_args[4]);
32
33 // Batch, Channel
34 if (ndim == 4) {
35 new_out_sizes.at(ndim - 4) = self->sizes().at(ndim - 4);
36 }
37
38 TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
39 const auto& weight_sizes = weight_ref->sizes;
40 new_out_sizes.at(ndim - 3) =
41 transposed ? weight_sizes.at(ndim - 3) : weight_sizes.at(ndim - 4);
42
43 // Height, Width
44 const auto& new_out_sizes_hw = calc_out_sizes_hw(
45 *graph,
46 self->sizes(),
47 extra_args[0],
48 /*kernel_size_only = */ false,
49 {extra_args[1], extra_args[2], extra_args[3], extra_args[5]},
50 transposed);
51 new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
52 new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);
53
54 out->virtual_resize(new_out_sizes);
55 }
56
resize_conv1d_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)57 void resize_conv1d_node(
58 ComputeGraph* graph,
59 const std::vector<ArgGroup>& args,
60 const std::vector<ValueRef>& extra_args) {
61 vTensorPtr out = graph->get_tensor(args[0].refs[0]);
62 vTensorPtr self = graph->get_tensor(args[1].refs[0]);
63 TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
64
65 int64_t stride_size = graph->get_int_list(extra_args[1])->at(0);
66 int64_t padding_size = graph->get_int_list(extra_args[2])->at(0);
67 int64_t dilation_size = graph->get_int_list(extra_args[3])->at(0);
68
69 const std::vector<int64_t>& weight_sizes = weight_ref->sizes;
70
71 const std::vector<int64_t>& in_sizes = self->sizes();
72 size_t ndim = in_sizes.size();
73 std::vector<int64_t> new_out_sizes(ndim);
74
75 int64_t kernel_size = weight_sizes.at(2);
76 int64_t in_length = in_sizes.at(2);
77
78 new_out_sizes.at(0) = in_sizes.at(0);
79 new_out_sizes.at(1) = weight_sizes.at(0);
80 new_out_sizes.at(2) = calc_out_size(
81 in_length, kernel_size, stride_size, padding_size, dilation_size, false);
82
83 out->virtual_resize(new_out_sizes);
84 }
85
prepack_biases(ComputeGraph & graph,const ValueRef vref,const ValueRef weight,const bool transposed,const utils::StorageType storage_type,const utils::GPUMemoryLayout memory_layout)86 ValueRef prepack_biases(
87 ComputeGraph& graph,
88 const ValueRef vref,
89 const ValueRef weight,
90 const bool transposed,
91 const utils::StorageType storage_type,
92 const utils::GPUMemoryLayout memory_layout) {
93 auto sizes = graph.sizes_of(weight);
94 const int64_t out_channels = transposed ? sizes.at(1) : sizes.at(0);
95
96 ValueRef v = graph.add_tensor(
97 {out_channels}, graph.dtype_of(weight), storage_type, memory_layout);
98 vTensorPtr t = graph.get_tensor(v);
99
100 vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(*t);
101
102 graph.prepack_nodes().emplace_back(new PrepackNode(
103 graph,
104 shader,
105 graph.create_global_wg_size(v),
106 graph.create_local_wg_size(v),
107 vref,
108 v,
109 {t->sizes_ubo()},
110 // Specialization constants
111 {t->hashed_layout()}));
112
113 return v;
114 }
115
// Strategy used to compute a 2D convolution; selected by get_conv2d_method()
// from the weight shape, groups, and the transposed flag. The choice controls
// both the compute shader and the packed weight layout.
enum class Conv2dMethod : uint8_t {
  Depthwise,
  Pointwise,
  SlidingWindow,
  Transposed,
};
122
get_conv2d_shader(ComputeGraph & graph,const api::vTensor & t_out,const bool prepack_weights,const Conv2dMethod method,const ValueRef weight,const bool clamp_out=false)123 vkapi::ShaderInfo get_conv2d_shader(
124 ComputeGraph& graph,
125 const api::vTensor& t_out,
126 const bool prepack_weights,
127 const Conv2dMethod method,
128 const ValueRef weight,
129 const bool clamp_out = false) {
130 std::string kernel_name;
131 kernel_name.reserve(kShaderNameReserve);
132 switch (method) {
133 case Conv2dMethod::Depthwise:
134 kernel_name = "conv2d_dw";
135 if (!prepack_weights) {
136 const auto& weight_sizes = graph.get_tref(weight)->sizes;
137 if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
138 kernel_name += "_output_tile_3x3";
139 }
140 if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) {
141 kernel_name += "_output_tile_5x5";
142 }
143 }
144 break;
145 case Conv2dMethod::Pointwise:
146 if (prepack_weights) {
147 kernel_name = "conv2d";
148 } else {
149 kernel_name = "conv2d_pw";
150 }
151 break;
152 case Conv2dMethod::SlidingWindow:
153 kernel_name = "conv2d";
154 break;
155 case Conv2dMethod::Transposed:
156 kernel_name = "conv_transpose2d";
157 break;
158 }
159 if (prepack_weights) {
160 kernel_name += "_prepack_weights";
161 } else if (clamp_out) {
162 kernel_name += "_clamp";
163 }
164 add_dtype_suffix(kernel_name, t_out);
165
166 return VK_KERNEL_FROM_STR(kernel_name);
167 }
168
get_final_sizes(const std::vector<int64_t> & original_sizes,const Conv2dMethod method)169 std::vector<int64_t> get_final_sizes(
170 const std::vector<int64_t>& original_sizes,
171 const Conv2dMethod method) {
172 int64_t batch_padded = utils::align_up_4(utils::val_at(-4, original_sizes));
173 int64_t channels_padded =
174 utils::align_up_4(utils::val_at(-3, original_sizes));
175 int64_t height = utils::val_at(-2, original_sizes);
176 int64_t width = utils::val_at(-1, original_sizes);
177
178 switch (method) {
179 case Conv2dMethod::Depthwise:
180 return std::vector<int64_t>{4, batch_padded / 4, height * width};
181 case Conv2dMethod::Pointwise:
182 case Conv2dMethod::SlidingWindow:
183 return std::vector<int64_t>{
184 4, batch_padded * height / 4, channels_padded * width};
185 case Conv2dMethod::Transposed:
186 return std::vector<int64_t>{
187 4, channels_padded * height / 4, batch_padded * width};
188 }
189 }
190
prepack_weights(ComputeGraph & graph,const ValueRef vref,const Conv2dMethod method)191 ValueRef prepack_weights(
192 ComputeGraph& graph,
193 const ValueRef vref,
194 const Conv2dMethod method) {
195 const auto original_sizes = graph.sizes_of(vref);
196 const auto final_sizes = get_final_sizes(original_sizes, method);
197
198 ValueRef v = graph.add_tensor(
199 final_sizes,
200 graph.dtype_of(vref),
201 utils::kTexture2D,
202 utils::kChannelsPacked);
203 vTensorPtr t = graph.get_tensor(v);
204
205 vkapi::ShaderInfo shader =
206 get_conv2d_shader(graph, *t, /*prepack_weights = */ true, method, vref);
207
208 graph.prepack_nodes().emplace_back(new PrepackNode(
209 graph,
210 shader,
211 graph.create_global_wg_size(v),
212 graph.create_local_wg_size(v),
213 vref,
214 v,
215 {t->sizes_ubo(),
216 graph.create_params_buffer(
217 utils::make_ivec4(original_sizes, /*reverse = */ true))},
218 // Specialization constants
219 {SV(t->packed_dim())}));
220
221 return v;
222 }
223
// Validates that both the input and output tensors are channels-packed, which
// is the only memory layout the conv shaders support. Throws via VK_CHECK_COND
// on violation.
void check_conv_args(const api::vTensor& in, const api::vTensor& out) {
  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
}
228
// Shader parameter block for conv2d (uploaded via create_params_buffer).
struct Conv2dParams final {
  // Effective (dilated) kernel extent per spatial axis; see
  // create_conv2d_params.
  utils::ivec2 overlay_region;
  // Number of input channels per group, aligned up to a multiple of 4.
  int in_group_size;
};
233
// Shader parameter block holding the output clamp bounds. Both values are 0
// when clamping is not requested (the shader variant without "_clamp" ignores
// them).
struct OutputParams final {
  float out_min;
  float out_max;
};
238
create_conv2d_params(ComputeGraph & graph,const ValueRef weight,const Kernel2dParams & p,const bool transposed)239 Conv2dParams create_conv2d_params(
240 ComputeGraph& graph,
241 const ValueRef weight,
242 const Kernel2dParams& p,
243 const bool transposed) {
244 const auto& overlay_region = utils::make_ivec2({
245 p.kernel_size[0] + (p.kernel_size[0] - 1) * (p.dilation[0] - 1),
246 p.kernel_size[1] + (p.kernel_size[1] - 1) * (p.dilation[1] - 1),
247 });
248 const auto weight_sizes = graph.sizes_of(weight);
249 const int32_t in_group_size = utils::safe_downcast<int32_t>(
250 utils::align_up_4(transposed ? weight_sizes.at(0) : weight_sizes.at(1)));
251 return {overlay_region, in_group_size};
252 }
253
check_conv2d_params(const Kernel2dParams & p,const bool transposed)254 void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
255 if (transposed) {
256 if (p.dilation[0] > 1 || p.dilation[1] > 1) {
257 VK_THROW(
258 "aten.convolution.default: transposed = true, dilation > 1 is not supported yet!");
259 }
260 }
261 if ((p.padding[0] > 0 && p.kernel_size[0] > 1 && p.dilation[0] > 1) ||
262 (p.padding[1] > 0 && p.kernel_size[1] > 1 && p.dilation[1] > 1)) {
263 VK_THROW(
264 "aten.convolution.default: padding > 0 while dilation, kernel_size > 1 is not supported yet!");
265 }
266 }
267
get_conv2d_method(ComputeGraph & graph,const ValueRef weight,const int64_t groups,const bool transposed)268 Conv2dMethod get_conv2d_method(
269 ComputeGraph& graph,
270 const ValueRef weight,
271 const int64_t groups,
272 const bool transposed) {
273 const auto weight_sizes = graph.sizes_of(weight);
274 if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
275 return Conv2dMethod::Depthwise;
276 }
277 if (groups > 1) {
278 VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
279 }
280 if (transposed) {
281 return Conv2dMethod::Transposed;
282 }
283 if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) {
284 return Conv2dMethod::Pointwise;
285 }
286 return Conv2dMethod::SlidingWindow;
287 }
288
create_conv2d_global_wg_size(ComputeGraph & graph,const Conv2dMethod method,const ValueRef out)289 utils::uvec3 create_conv2d_global_wg_size(
290 ComputeGraph& graph,
291 const Conv2dMethod method,
292 const ValueRef out) {
293 if (method == Conv2dMethod::Pointwise) {
294 const utils::uvec3 image_extents = graph.logical_limits_of(out);
295 return {
296 utils::div_up(image_extents[0u], 2u),
297 utils::div_up(image_extents[1u], 2u),
298 image_extents[2u]};
299 } else {
300 return graph.create_global_wg_size(out);
301 }
302 }
303
// Builds and appends the graph nodes for a 2D convolution: prepacks the
// weight and bias, validates arguments, and schedules the execution dispatch.
// out_min/out_max may be kDummyValueRef when clamping is not requested;
// clamp_out selects the "_clamp" shader variant.
void add_conv2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef weight_data,
    const ValueRef bias,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef transposed,
    const ValueRef output_padding,
    const ValueRef groups,
    const ValueRef out_min,
    const ValueRef out_max,
    const ValueRef out,
    const bool clamp_out) {
  const bool transposed_val = graph.get_bool(transposed);

  // Clamp bounds default to 0; they are only read by the shader when the
  // "_clamp" variant is selected via clamp_out.
  float out_min_val = 0.0f;
  float out_max_val = 0.0f;
  if (out_min != kDummyValueRef) {
    out_min_val = graph.extract_scalar<float>(out_min);
  }
  if (out_max != kDummyValueRef) {
    out_max_val = graph.extract_scalar<float>(out_max);
  }

  const int64_t groups_val = graph.get_int(groups);

  // The method determines both the shader and the packed weight layout.
  const Conv2dMethod method =
      get_conv2d_method(graph, weight_data, groups_val, transposed_val);

  ValueRef arg_weight = prepack_weights(graph, weight_data, method);
  ValueRef arg_bias = prepack_biases(
      graph,
      bias,
      weight_data,
      transposed_val,
      /* storage_type = */ utils::kTexture2D,
      /* memory_layout = */ utils::kWidthPacked);

  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);
  // Batched input is not implemented by the conv2d shaders.
  if (t_in->sizes().at(0) > 1) {
    VK_THROW("conv2d: input batch size > 1 is not supported yet!");
  }
  check_conv_args(*t_in, *t_out);

  Kernel2dParams kernel_params = create_kernel2d_params(
      graph,
      weight_data,
      /*kernel_size_only = */ false,
      stride,
      padding,
      dilation);
  Conv2dParams extra_params =
      create_conv2d_params(graph, weight_data, kernel_params, transposed_val);

  OutputParams out_params = {out_min_val, out_max_val};

  check_conv2d_params(kernel_params, transposed_val);

  vkapi::ShaderInfo shader = get_conv2d_shader(
      graph,
      *t_out,
      /*prepack_weights = */ false,
      method,
      weight_data,
      clamp_out);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      shader,
      create_conv2d_global_wg_size(graph, method, out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          t_out->logical_limits_ubo(),
          t_in->sizes_ubo(),
          graph.create_params_buffer(kernel_params),
          graph.create_params_buffer(extra_params),
          graph.create_params_buffer(out_params),
      },
      // Specialization Constants
      {},
      // Resizing Logic: extra_args consumed by resize_conv2d_node
      resize_conv2d_node,
      {weight_data, stride, padding, dilation, transposed, output_padding}));
}
395
add_conv1d_node(ComputeGraph & graph,const ValueRef in,const ValueRef weight,const ValueRef bias,const ValueRef stride,const ValueRef padding,const ValueRef dilation,const ValueRef groups,const ValueRef out_min,const ValueRef out_max,const ValueRef out,const bool clamp_out)396 void add_conv1d_node(
397 ComputeGraph& graph,
398 const ValueRef in,
399 const ValueRef weight,
400 const ValueRef bias,
401 const ValueRef stride,
402 const ValueRef padding,
403 const ValueRef dilation,
404 const ValueRef groups,
405 const ValueRef out_min,
406 const ValueRef out_max,
407 const ValueRef out,
408 const bool clamp_out) {
409 ValueRef arg_weight = prepack_standard(
410 graph, weight, graph.storage_type_of(out), utils::kWidthPacked);
411 ValueRef arg_bias = prepack_biases(
412 graph,
413 bias,
414 weight,
415 /*transposed = */ false,
416 /*storage_type = */ utils::kTexture3D,
417 /*memory_layout = */ utils::kChannelsPacked);
418
419 float out_min_val = 0.0f;
420 float out_max_val = 0.0f;
421 if (out_min != kDummyValueRef) {
422 out_min_val = graph.extract_scalar<float>(out_min);
423 }
424 if (out_max != kDummyValueRef) {
425 out_max_val = graph.extract_scalar<float>(out_max);
426 }
427
428 vTensorPtr t_in = graph.get_tensor(in);
429 vTensorPtr t_weight = graph.get_tensor(arg_weight);
430 vTensorPtr t_bias = graph.get_tensor(arg_bias);
431 vTensorPtr t_out = graph.get_tensor(out);
432 const int64_t groups_val = graph.get_int(groups);
433
434 std::vector<int64_t> in_sizes = t_in->sizes();
435 std::vector<int64_t> weight_sizes = t_weight->sizes();
436 std::vector<int64_t> out_sizes = t_out->sizes();
437
438 check_conv_args(*t_in, *t_out);
439
440 int32_t in_channels = in_sizes.at(1);
441 int32_t out_channels = weight_sizes.at(0);
442 int32_t kernel_size = weight_sizes.at(2);
443 int32_t stride_size = graph.get_int_list(stride)->at(0);
444 int32_t padding_size = graph.get_int_list(padding)->at(0);
445 int32_t dilation_size = graph.get_int_list(dilation)->at(0);
446 int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
447 int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
448
449 utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
450 utils::uvec3 local_size = {1, 64, 1};
451
452 Kernel1dParams kernel_params = {
453 kernel_size,
454 stride_size,
455 padding_size,
456 dilation_size,
457 in_group_size,
458 out_group_size};
459
460 OutputParams out_params = {out_min_val, out_max_val};
461
462 std::string kernel_name("conv1d");
463 if (clamp_out) {
464 kernel_name += "_clamp";
465 }
466 kernel_name.reserve(kShaderNameReserve);
467
468 add_dtype_suffix(kernel_name, *t_out);
469
470 graph.execute_nodes().emplace_back(new DispatchNode(
471 graph,
472 VK_KERNEL_FROM_STR(kernel_name),
473 global_size,
474 local_size,
475 // Inputs and Outputs
476 {{out, vkapi::MemoryAccessType::WRITE},
477 {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
478 // Shader params buffers
479 {
480 t_out->logical_limits_ubo(),
481 t_in->sizes_ubo(),
482 graph.create_params_buffer(kernel_params),
483 graph.create_params_buffer(out_params),
484 },
485 // Specialization Constants
486 {t_out->hashed_layout(),
487 t_in->hashed_layout(),
488 t_weight->hashed_layout(),
489 t_bias->hashed_layout()},
490 // Resizing Logic
491 resize_conv1d_node,
492 {weight, stride, padding, dilation}));
493 }
494
conv(ComputeGraph & graph,const std::vector<ValueRef> & args)495 void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
496 int64_t in_ndim = graph.get_tensor(args[0])->sizes().size();
497 if (in_ndim == 4) {
498 if (args.size() == 10) {
499 // ordinary conv2d
500 return add_conv2d_node(
501 graph,
502 args[0],
503 args[1],
504 args[2],
505 args[3],
506 args[4],
507 args[5],
508 args[6],
509 args[7],
510 args[8],
511 /*out_min = */ kDummyValueRef,
512 /*out_max = */ kDummyValueRef,
513 args[9],
514 false);
515 } else {
516 // conv2d with clamp
517 return add_conv2d_node(
518 graph,
519 args[0],
520 args[1],
521 args[2],
522 args[3],
523 args[4],
524 args[5],
525 args[6],
526 args[7],
527 args[8],
528 args[9],
529 args[10],
530 args[11],
531 true);
532 }
533 } else {
534 if (args.size() == 10) {
535 // ordinary conv1d
536 return add_conv1d_node(
537 graph,
538 args[0],
539 args[1],
540 args[2],
541 args[3],
542 args[4],
543 args[5],
544 args[8],
545 /*out_min = */ kDummyValueRef,
546 /*out_max = */ kDummyValueRef,
547 args[9],
548 false);
549 } else {
550 // conv1d with clamp
551 return add_conv1d_node(
552 graph,
553 args[0],
554 args[1],
555 args[2],
556 args[3],
557 args[4],
558 args[5],
559 args[8],
560 args[9],
561 args[10],
562 args[11],
563 true);
564 }
565 }
566 }
567
REGISTER_OPERATORS {
  // All three op schemas route through the same entry point; `conv`
  // distinguishes the clamp variants by argument count.
  VK_REGISTER_OP(aten.convolution.default, conv);
  VK_REGISTER_OP(conv_with_clamp.default, conv);
  VK_REGISTER_OP(et_vk.conv_with_clamp.default, conv);
}
573
574 } // namespace vkcompute
575