xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10 
11 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
12 
13 #include <executorch/backends/vulkan/runtime/graph/ops/utils/StagingUtils.h>
14 
15 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
16 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
17 
18 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
19 
20 namespace vkcompute {
21 
resize_conv2d_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)22 void resize_conv2d_node(
23     ComputeGraph* graph,
24     const std::vector<ArgGroup>& args,
25     const std::vector<ValueRef>& extra_args) {
26   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
27   vTensorPtr self = graph->get_tensor(args[1].refs[0]);
28 
29   size_t ndim = self->sizes().size();
30   std::vector<int64_t> new_out_sizes(ndim);
31   const bool transposed = graph->get_bool(extra_args[4]);
32 
33   // Batch, Channel
34   if (ndim == 4) {
35     new_out_sizes.at(ndim - 4) = self->sizes().at(ndim - 4);
36   }
37 
38   TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
39   const auto& weight_sizes = weight_ref->sizes;
40   new_out_sizes.at(ndim - 3) =
41       transposed ? weight_sizes.at(ndim - 3) : weight_sizes.at(ndim - 4);
42 
43   // Height, Width
44   const auto& new_out_sizes_hw = calc_out_sizes_hw(
45       *graph,
46       self->sizes(),
47       extra_args[0],
48       /*kernel_size_only = */ false,
49       {extra_args[1], extra_args[2], extra_args[3], extra_args[5]},
50       transposed);
51   new_out_sizes.at(ndim - 2) = new_out_sizes_hw.at(0);
52   new_out_sizes.at(ndim - 1) = new_out_sizes_hw.at(1);
53 
54   out->virtual_resize(new_out_sizes);
55 }
56 
resize_conv1d_node(ComputeGraph * graph,const std::vector<ArgGroup> & args,const std::vector<ValueRef> & extra_args)57 void resize_conv1d_node(
58     ComputeGraph* graph,
59     const std::vector<ArgGroup>& args,
60     const std::vector<ValueRef>& extra_args) {
61   vTensorPtr out = graph->get_tensor(args[0].refs[0]);
62   vTensorPtr self = graph->get_tensor(args[1].refs[0]);
63   TensorRefPtr weight_ref = graph->get_tref(extra_args[0]);
64 
65   int64_t stride_size = graph->get_int_list(extra_args[1])->at(0);
66   int64_t padding_size = graph->get_int_list(extra_args[2])->at(0);
67   int64_t dilation_size = graph->get_int_list(extra_args[3])->at(0);
68 
69   const std::vector<int64_t>& weight_sizes = weight_ref->sizes;
70 
71   const std::vector<int64_t>& in_sizes = self->sizes();
72   size_t ndim = in_sizes.size();
73   std::vector<int64_t> new_out_sizes(ndim);
74 
75   int64_t kernel_size = weight_sizes.at(2);
76   int64_t in_length = in_sizes.at(2);
77 
78   new_out_sizes.at(0) = in_sizes.at(0);
79   new_out_sizes.at(1) = weight_sizes.at(0);
80   new_out_sizes.at(2) = calc_out_size(
81       in_length, kernel_size, stride_size, padding_size, dilation_size, false);
82 
83   out->virtual_resize(new_out_sizes);
84 }
85 
prepack_biases(ComputeGraph & graph,const ValueRef vref,const ValueRef weight,const bool transposed,const utils::StorageType storage_type,const utils::GPUMemoryLayout memory_layout)86 ValueRef prepack_biases(
87     ComputeGraph& graph,
88     const ValueRef vref,
89     const ValueRef weight,
90     const bool transposed,
91     const utils::StorageType storage_type,
92     const utils::GPUMemoryLayout memory_layout) {
93   auto sizes = graph.sizes_of(weight);
94   const int64_t out_channels = transposed ? sizes.at(1) : sizes.at(0);
95 
96   ValueRef v = graph.add_tensor(
97       {out_channels}, graph.dtype_of(weight), storage_type, memory_layout);
98   vTensorPtr t = graph.get_tensor(v);
99 
100   vkapi::ShaderInfo shader = get_nchw_to_tensor_shader(*t);
101 
102   graph.prepack_nodes().emplace_back(new PrepackNode(
103       graph,
104       shader,
105       graph.create_global_wg_size(v),
106       graph.create_local_wg_size(v),
107       vref,
108       v,
109       {t->sizes_ubo()},
110       // Specialization constants
111       {t->hashed_layout()}));
112 
113   return v;
114 }
115 
// Strategy used to implement a 2D convolution on the GPU. Selected by
// get_conv2d_method() from the weight shape, group count, and transposed
// flag; determines both the compute shader and the weight prepack layout.
enum class Conv2dMethod : uint8_t {
  Depthwise,      // weight has out_channels == groups and 1 input channel
  Pointwise,      // 1x1 kernel
  SlidingWindow,  // general case
  Transposed,     // transposed convolution
};
122 
get_conv2d_shader(ComputeGraph & graph,const api::vTensor & t_out,const bool prepack_weights,const Conv2dMethod method,const ValueRef weight,const bool clamp_out=false)123 vkapi::ShaderInfo get_conv2d_shader(
124     ComputeGraph& graph,
125     const api::vTensor& t_out,
126     const bool prepack_weights,
127     const Conv2dMethod method,
128     const ValueRef weight,
129     const bool clamp_out = false) {
130   std::string kernel_name;
131   kernel_name.reserve(kShaderNameReserve);
132   switch (method) {
133     case Conv2dMethod::Depthwise:
134       kernel_name = "conv2d_dw";
135       if (!prepack_weights) {
136         const auto& weight_sizes = graph.get_tref(weight)->sizes;
137         if (weight_sizes.at(2) == 3 && weight_sizes.at(3) == 3) {
138           kernel_name += "_output_tile_3x3";
139         }
140         if (weight_sizes.at(2) == 5 && weight_sizes.at(3) == 5) {
141           kernel_name += "_output_tile_5x5";
142         }
143       }
144       break;
145     case Conv2dMethod::Pointwise:
146       if (prepack_weights) {
147         kernel_name = "conv2d";
148       } else {
149         kernel_name = "conv2d_pw";
150       }
151       break;
152     case Conv2dMethod::SlidingWindow:
153       kernel_name = "conv2d";
154       break;
155     case Conv2dMethod::Transposed:
156       kernel_name = "conv_transpose2d";
157       break;
158   }
159   if (prepack_weights) {
160     kernel_name += "_prepack_weights";
161   } else if (clamp_out) {
162     kernel_name += "_clamp";
163   }
164   add_dtype_suffix(kernel_name, t_out);
165 
166   return VK_KERNEL_FROM_STR(kernel_name);
167 }
168 
get_final_sizes(const std::vector<int64_t> & original_sizes,const Conv2dMethod method)169 std::vector<int64_t> get_final_sizes(
170     const std::vector<int64_t>& original_sizes,
171     const Conv2dMethod method) {
172   int64_t batch_padded = utils::align_up_4(utils::val_at(-4, original_sizes));
173   int64_t channels_padded =
174       utils::align_up_4(utils::val_at(-3, original_sizes));
175   int64_t height = utils::val_at(-2, original_sizes);
176   int64_t width = utils::val_at(-1, original_sizes);
177 
178   switch (method) {
179     case Conv2dMethod::Depthwise:
180       return std::vector<int64_t>{4, batch_padded / 4, height * width};
181     case Conv2dMethod::Pointwise:
182     case Conv2dMethod::SlidingWindow:
183       return std::vector<int64_t>{
184           4, batch_padded * height / 4, channels_padded * width};
185     case Conv2dMethod::Transposed:
186       return std::vector<int64_t>{
187           4, channels_padded * height / 4, batch_padded * width};
188   }
189 }
190 
prepack_weights(ComputeGraph & graph,const ValueRef vref,const Conv2dMethod method)191 ValueRef prepack_weights(
192     ComputeGraph& graph,
193     const ValueRef vref,
194     const Conv2dMethod method) {
195   const auto original_sizes = graph.sizes_of(vref);
196   const auto final_sizes = get_final_sizes(original_sizes, method);
197 
198   ValueRef v = graph.add_tensor(
199       final_sizes,
200       graph.dtype_of(vref),
201       utils::kTexture2D,
202       utils::kChannelsPacked);
203   vTensorPtr t = graph.get_tensor(v);
204 
205   vkapi::ShaderInfo shader =
206       get_conv2d_shader(graph, *t, /*prepack_weights = */ true, method, vref);
207 
208   graph.prepack_nodes().emplace_back(new PrepackNode(
209       graph,
210       shader,
211       graph.create_global_wg_size(v),
212       graph.create_local_wg_size(v),
213       vref,
214       v,
215       {t->sizes_ubo(),
216        graph.create_params_buffer(
217            utils::make_ivec4(original_sizes, /*reverse = */ true))},
218       // Specialization constants
219       {SV(t->packed_dim())}));
220 
221   return v;
222 }
223 
// Validates that both the input and output tensors are channels-packed, which
// is the layout every conv shader in this file assumes.
void check_conv_args(const api::vTensor& in, const api::vTensor& out) {
  VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
  VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
}
228 
// Extra conv geometry passed to the conv2d shaders as a params buffer.
struct Conv2dParams final {
  // Effective kernel footprint per axis once dilation is applied:
  // k + (k - 1) * (d - 1). See create_conv2d_params().
  utils::ivec2 overlay_region;
  // Weight input-channel count aligned up to a multiple of 4.
  int in_group_size;
};

// Output clamp bounds passed to the *_clamp shader variants. Both are zero
// when no clamp is requested (the non-clamp shaders ignore them).
struct OutputParams final {
  float out_min;
  float out_max;
};
238 
create_conv2d_params(ComputeGraph & graph,const ValueRef weight,const Kernel2dParams & p,const bool transposed)239 Conv2dParams create_conv2d_params(
240     ComputeGraph& graph,
241     const ValueRef weight,
242     const Kernel2dParams& p,
243     const bool transposed) {
244   const auto& overlay_region = utils::make_ivec2({
245       p.kernel_size[0] + (p.kernel_size[0] - 1) * (p.dilation[0] - 1),
246       p.kernel_size[1] + (p.kernel_size[1] - 1) * (p.dilation[1] - 1),
247   });
248   const auto weight_sizes = graph.sizes_of(weight);
249   const int32_t in_group_size = utils::safe_downcast<int32_t>(
250       utils::align_up_4(transposed ? weight_sizes.at(0) : weight_sizes.at(1)));
251   return {overlay_region, in_group_size};
252 }
253 
check_conv2d_params(const Kernel2dParams & p,const bool transposed)254 void check_conv2d_params(const Kernel2dParams& p, const bool transposed) {
255   if (transposed) {
256     if (p.dilation[0] > 1 || p.dilation[1] > 1) {
257       VK_THROW(
258           "aten.convolution.default: transposed = true, dilation > 1 is not supported yet!");
259     }
260   }
261   if ((p.padding[0] > 0 && p.kernel_size[0] > 1 && p.dilation[0] > 1) ||
262       (p.padding[1] > 0 && p.kernel_size[1] > 1 && p.dilation[1] > 1)) {
263     VK_THROW(
264         "aten.convolution.default: padding > 0 while dilation, kernel_size > 1 is not supported yet!");
265   }
266 }
267 
get_conv2d_method(ComputeGraph & graph,const ValueRef weight,const int64_t groups,const bool transposed)268 Conv2dMethod get_conv2d_method(
269     ComputeGraph& graph,
270     const ValueRef weight,
271     const int64_t groups,
272     const bool transposed) {
273   const auto weight_sizes = graph.sizes_of(weight);
274   if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
275     return Conv2dMethod::Depthwise;
276   }
277   if (groups > 1) {
278     VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
279   }
280   if (transposed) {
281     return Conv2dMethod::Transposed;
282   }
283   if (weight_sizes.at(2) == 1 && weight_sizes.at(3) == 1) {
284     return Conv2dMethod::Pointwise;
285   }
286   return Conv2dMethod::SlidingWindow;
287 }
288 
create_conv2d_global_wg_size(ComputeGraph & graph,const Conv2dMethod method,const ValueRef out)289 utils::uvec3 create_conv2d_global_wg_size(
290     ComputeGraph& graph,
291     const Conv2dMethod method,
292     const ValueRef out) {
293   if (method == Conv2dMethod::Pointwise) {
294     const utils::uvec3 image_extents = graph.logical_limits_of(out);
295     return {
296         utils::div_up(image_extents[0u], 2u),
297         utils::div_up(image_extents[1u], 2u),
298         image_extents[2u]};
299   } else {
300     return graph.create_global_wg_size(out);
301   }
302 }
303 
// Builds the graph nodes for a (possibly clamped) 2D convolution: prepacks
// the weights and biases, validates the arguments, and appends the compute
// dispatch. The out_min/out_max refs may be kDummyValueRef when no clamp is
// requested; clamp_out selects the *_clamp shader variant.
// NOTE(review): the UBO list order below must match the shader's binding
// order — do not reorder.
void add_conv2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef weight_data,
    const ValueRef bias,
    const ValueRef stride,
    const ValueRef padding,
    const ValueRef dilation,
    const ValueRef transposed,
    const ValueRef output_padding,
    const ValueRef groups,
    const ValueRef out_min,
    const ValueRef out_max,
    const ValueRef out,
    const bool clamp_out) {
  const bool transposed_val = graph.get_bool(transposed);

  // Clamp bounds default to 0 when not provided; the non-clamp shaders
  // ignore them.
  float out_min_val = 0.0f;
  float out_max_val = 0.0f;
  if (out_min != kDummyValueRef) {
    out_min_val = graph.extract_scalar<float>(out_min);
  }
  if (out_max != kDummyValueRef) {
    out_max_val = graph.extract_scalar<float>(out_max);
  }

  const int64_t groups_val = graph.get_int(groups);

  // Pick the shader strategy (depthwise / pointwise / sliding-window /
  // transposed) from the weight shape before prepacking, since the prepack
  // layout depends on it.
  const Conv2dMethod method =
      get_conv2d_method(graph, weight_data, groups_val, transposed_val);

  ValueRef arg_weight = prepack_weights(graph, weight_data, method);
  ValueRef arg_bias = prepack_biases(
      graph,
      bias,
      weight_data,
      transposed_val,
      /* storage_type = */ utils::kTexture2D,
      /* memory_layout = */ utils::kWidthPacked);

  vTensorPtr t_in = graph.get_tensor(in);
  vTensorPtr t_out = graph.get_tensor(out);
  if (t_in->sizes().at(0) > 1) {
    VK_THROW("conv2d: input batch size > 1 is not supported yet!");
  }
  check_conv_args(*t_in, *t_out);

  Kernel2dParams kernel_params = create_kernel2d_params(
      graph,
      weight_data,
      /*kernel_size_only = */ false,
      stride,
      padding,
      dilation);
  Conv2dParams extra_params =
      create_conv2d_params(graph, weight_data, kernel_params, transposed_val);

  OutputParams out_params = {out_min_val, out_max_val};

  check_conv2d_params(kernel_params, transposed_val);

  vkapi::ShaderInfo shader = get_conv2d_shader(
      graph,
      *t_out,
      /*prepack_weights = */ false,
      method,
      weight_data,
      clamp_out);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      shader,
      create_conv2d_global_wg_size(graph, method, out),
      graph.create_local_wg_size(out),
      // Inputs and Outputs
      {{out, vkapi::MemoryAccessType::WRITE},
       {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
      // Shader params buffers
      {
          t_out->logical_limits_ubo(),
          t_in->sizes_ubo(),
          graph.create_params_buffer(kernel_params),
          graph.create_params_buffer(extra_params),
          graph.create_params_buffer(out_params),
      },
      // Specialization Constants
      {},
      // Resizing Logic — extra_args order must match what resize_conv2d_node
      // indexes: weight, stride, padding, dilation, transposed, output_padding.
      resize_conv2d_node,
      {weight_data, stride, padding, dilation, transposed, output_padding}));
}
395 
add_conv1d_node(ComputeGraph & graph,const ValueRef in,const ValueRef weight,const ValueRef bias,const ValueRef stride,const ValueRef padding,const ValueRef dilation,const ValueRef groups,const ValueRef out_min,const ValueRef out_max,const ValueRef out,const bool clamp_out)396 void add_conv1d_node(
397     ComputeGraph& graph,
398     const ValueRef in,
399     const ValueRef weight,
400     const ValueRef bias,
401     const ValueRef stride,
402     const ValueRef padding,
403     const ValueRef dilation,
404     const ValueRef groups,
405     const ValueRef out_min,
406     const ValueRef out_max,
407     const ValueRef out,
408     const bool clamp_out) {
409   ValueRef arg_weight = prepack_standard(
410       graph, weight, graph.storage_type_of(out), utils::kWidthPacked);
411   ValueRef arg_bias = prepack_biases(
412       graph,
413       bias,
414       weight,
415       /*transposed = */ false,
416       /*storage_type = */ utils::kTexture3D,
417       /*memory_layout = */ utils::kChannelsPacked);
418 
419   float out_min_val = 0.0f;
420   float out_max_val = 0.0f;
421   if (out_min != kDummyValueRef) {
422     out_min_val = graph.extract_scalar<float>(out_min);
423   }
424   if (out_max != kDummyValueRef) {
425     out_max_val = graph.extract_scalar<float>(out_max);
426   }
427 
428   vTensorPtr t_in = graph.get_tensor(in);
429   vTensorPtr t_weight = graph.get_tensor(arg_weight);
430   vTensorPtr t_bias = graph.get_tensor(arg_bias);
431   vTensorPtr t_out = graph.get_tensor(out);
432   const int64_t groups_val = graph.get_int(groups);
433 
434   std::vector<int64_t> in_sizes = t_in->sizes();
435   std::vector<int64_t> weight_sizes = t_weight->sizes();
436   std::vector<int64_t> out_sizes = t_out->sizes();
437 
438   check_conv_args(*t_in, *t_out);
439 
440   int32_t in_channels = in_sizes.at(1);
441   int32_t out_channels = weight_sizes.at(0);
442   int32_t kernel_size = weight_sizes.at(2);
443   int32_t stride_size = graph.get_int_list(stride)->at(0);
444   int32_t padding_size = graph.get_int_list(padding)->at(0);
445   int32_t dilation_size = graph.get_int_list(dilation)->at(0);
446   int32_t in_group_size = static_cast<int64_t>(in_channels / groups_val);
447   int32_t out_group_size = static_cast<int64_t>(out_channels / groups_val);
448 
449   utils::uvec3 global_size = {1, static_cast<uint32_t>(out_channels), 1};
450   utils::uvec3 local_size = {1, 64, 1};
451 
452   Kernel1dParams kernel_params = {
453       kernel_size,
454       stride_size,
455       padding_size,
456       dilation_size,
457       in_group_size,
458       out_group_size};
459 
460   OutputParams out_params = {out_min_val, out_max_val};
461 
462   std::string kernel_name("conv1d");
463   if (clamp_out) {
464     kernel_name += "_clamp";
465   }
466   kernel_name.reserve(kShaderNameReserve);
467 
468   add_dtype_suffix(kernel_name, *t_out);
469 
470   graph.execute_nodes().emplace_back(new DispatchNode(
471       graph,
472       VK_KERNEL_FROM_STR(kernel_name),
473       global_size,
474       local_size,
475       // Inputs and Outputs
476       {{out, vkapi::MemoryAccessType::WRITE},
477        {{in, arg_weight, arg_bias}, vkapi::MemoryAccessType::READ}},
478       // Shader params buffers
479       {
480           t_out->logical_limits_ubo(),
481           t_in->sizes_ubo(),
482           graph.create_params_buffer(kernel_params),
483           graph.create_params_buffer(out_params),
484       },
485       // Specialization Constants
486       {t_out->hashed_layout(),
487        t_in->hashed_layout(),
488        t_weight->hashed_layout(),
489        t_bias->hashed_layout()},
490       // Resizing Logic
491       resize_conv1d_node,
492       {weight, stride, padding, dilation}));
493 }
494 
conv(ComputeGraph & graph,const std::vector<ValueRef> & args)495 void conv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
496   int64_t in_ndim = graph.get_tensor(args[0])->sizes().size();
497   if (in_ndim == 4) {
498     if (args.size() == 10) {
499       // ordinary conv2d
500       return add_conv2d_node(
501           graph,
502           args[0],
503           args[1],
504           args[2],
505           args[3],
506           args[4],
507           args[5],
508           args[6],
509           args[7],
510           args[8],
511           /*out_min = */ kDummyValueRef,
512           /*out_max = */ kDummyValueRef,
513           args[9],
514           false);
515     } else {
516       // conv2d with clamp
517       return add_conv2d_node(
518           graph,
519           args[0],
520           args[1],
521           args[2],
522           args[3],
523           args[4],
524           args[5],
525           args[6],
526           args[7],
527           args[8],
528           args[9],
529           args[10],
530           args[11],
531           true);
532     }
533   } else {
534     if (args.size() == 10) {
535       // ordinary conv1d
536       return add_conv1d_node(
537           graph,
538           args[0],
539           args[1],
540           args[2],
541           args[3],
542           args[4],
543           args[5],
544           args[8],
545           /*out_min = */ kDummyValueRef,
546           /*out_max = */ kDummyValueRef,
547           args[9],
548           false);
549     } else {
550       // conv1d with clamp
551       return add_conv1d_node(
552           graph,
553           args[0],
554           args[1],
555           args[2],
556           args[3],
557           args[4],
558           args[5],
559           args[8],
560           args[9],
561           args[10],
562           args[11],
563           true);
564     }
565   }
566 }
567 
// Register conv() as the handler for the standard aten convolution op and
// both namespaced spellings of the fused conv+clamp op.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.convolution.default, conv);
  VK_REGISTER_OP(conv_with_clamp.default, conv);
  VK_REGISTER_OP(et_vk.conv_with_clamp.default, conv);
}
573 
574 } // namespace vkcompute
575