#include <ATen/Context.h>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <ATen/native/vulkan/impl/Packing.h>
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Convolution.h>
#include <ATen/native/vulkan/ops/Copy.h>
#include <ATen/native/vulkan/ops/Utils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/dequantize.h>
#include <ATen/ops/pad.h>
#include <ATen/ops/permute.h>
#include <ATen/ops/quantize_per_tensor.h>
#include <ATen/ops/zeros.h>
#endif

namespace at {
namespace native {
namespace vulkan {
namespace ops {

namespace conv2d {

//
// Convolution type classification
//

inline bool is_depthwise(const IntArrayRef weight_size, const int64_t groups) {
  uint32_t groups_uint = api::utils::safe_downcast<uint32_t>(groups);
  if (get_dim<DimConv2DKernel::OutChannels>(weight_size) != groups_uint) {
    return false;
  }
  if (get_dim<DimConv2DKernel::InChannels>(weight_size) != 1) {
    return false;
  }
  return true;
}

inline bool is_pointwise(const IntArrayRef weight_size) {
  if (get_dim<DimConv2DKernel::Width>(weight_size) != 1) {
    return false;
  }
  if (get_dim<DimConv2DKernel::Height>(weight_size) != 1) {
    return false;
  }
  return true;
}

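/*
 * Selects the convolution algorithm based on the arguments: transposed
 * convolutions always fall back to the sliding window shader, while depthwise
 * and pointwise weights dispatch to their specialized shaders.
 */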
static Conv2dMethod determine_method(
    const IntArrayRef weight_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const int64_t groups,
    const bool transposed,
    const bool quantized) {
  if (transposed) {
    return Conv2dSlidingWindow;
  }
  if (is_depthwise(weight_size, groups)) {
    return Conv2dDepthwise;
  }
  if (is_pointwise(weight_size)) {
    return Conv2dPointwise;
  }
  return Conv2dSlidingWindow;
}

//
// Rearrangement functions for pre-packing
//

/*
 * Rearranges a convolution weight tensor to a layout that can be used by
 * convolution compute shaders. The goal of this packing is to arrange the data
 * such that data access in the compute shader is as linear as possible. The
 * reasoning behind the packing pattern will be described in the shader kernel
 * code.
 *
 * To understand the transformations performed by this function, consider an
 * example input of size {11, 1, 3, 3}. The following transformations will be
 * applied to this weight tensor:
 *
 * 1. First, apply padding to the N dim so that it is a multiple of 4.
 * In this case, 1 batch is added, producing a tensor of size {12,1,3,3}.
 *
 * 2. Next, flatten the last two dims of the tensor. This is done by reshaping
 * the tensor to size {12,1,9}.
 *
 * 3. Finally, we want to "fold" the batch dim into the channel dim. We start by
 * splitting the tensor along the N dim so that each split has 4 batches. This
 * is done by reshaping the tensor to size {3,4,1,9}.
 *
 * 4. Normally, we would be done, but we want to stack the splits vertically.
 * This is done by permuting the N and C dims and reshaping the tensor to size
 * {4,3,9}.
 */
at::Tensor rearrange_weights_dw(const Tensor& weight_in) {
  at::Tensor weight = weight_in.clone();

  uint32_t N = ops::get_dim<DimConv2DKernel::OutChannels>(weight);
  uint32_t C = ops::get_dim<DimConv2DKernel::InChannels>(weight);
  uint32_t H = ops::get_dim<DimConv2DKernel::Height>(weight);
  uint32_t W = ops::get_dim<DimConv2DKernel::Width>(weight);

  uint32_t N_aligned = api::utils::align_up(N, 4u);

  // Add padding to the N dimension so that it's a multiple of 4
  uint32_t N_padding_needed = N_aligned - N;
  weight =
      at::pad(weight, {0, 0, 0, 0, 0, 0, 0, N_padding_needed}, "constant", 0);

  // Flatten so the H and W dim are on one row
  weight = weight.reshape({N_aligned, C, H * W});

  // Split batch dim to make groups of 4
  uint32_t N4 = N_aligned / 4u;
  weight = weight.reshape({N4, 4, C, H * W});

  // Permute the groups of 4 so they are arranged along the channel dim, then
  // reshape to stack the resulting batches vertically
  weight = weight.permute({1, 0, 2, 3}).reshape({4, N4 * C, H * W});

  return weight.contiguous();
}

/*
 * Rearranges a convolution weight tensor to a layout that can be used by
 * convolution compute shaders. The goal of this packing is to arrange the data
 * such that data access in the compute shader is as linear as possible. The
 * reasoning behind the packing pattern will be described in the shader kernel
 * code.
 *
 * To understand the transformations performed by this function, consider an
 * example input of size {10, 7, 3, 3}. The following transformations will be
 * applied to this weight tensor:
 *
 * 1. First, apply padding to the N and C dims so that both are a multiple of 4.
 * In this case, 2 batches and 1 channel of padding are added, producing a
 * tensor of size {12,8,3,3}.
 *
 * 2. Next, split the tensor along the C dim so that each split has 4 channels.
 * This is done by reshaping the tensor to size {12,2,(4,3,3)}, where the ()
 * brackets denote the size of each split.
 *
 * 3. For each split, we want to "fold" the C dim into the W dim. Suppose the
 * first rows at H=0 of a split have the values
 *
 * 0,1,2 | 10,11,12 | 20,21,22 | 30,31,32
 *
 * where | denotes a channel boundary; then the goal is to combine those rows
 * into one row with the values
 *
 * 0, 10, 20, 30, 1, 11, 21, 31, 2, 12, 22, 32
 *
 * This is done in code by permuting and reshaping the tensor, producing a
 * tensor of size {12,2,(3,12)}.
 *
 * 4. Next, we want to stack the splits belonging to the same batch
 * horizontally, which is done by swapping the C and H dims of the intermediate
 * tensor and reshaping to produce a tensor of size {12,3,24}.
 *
 * 5. Now we will repeat a similar process of "folding" the N dim into the C
 * dim. We start by splitting along the N dim so that each split has 4 batches.
 * To do this, the tensor is reshaped to {3,4,3,24}.
 *
 * 6. Normally we would be done, but we also want to stack the batches on top
 * of each other vertically. Therefore the final step is another permute that
 * swaps the N and C dims, followed by a reshape to the output shape of
 * {4, 9, 24}.
 *
 * For transposed convolutions, there are some slight differences to reflect the
 * data access pattern in the shader. The first major difference is that the
 * weight tensor is flipped along the H and W dims. The second major difference
 * is that steps 3 and 4 are slightly different so that the splits are
 * interleaved.
 */
at::Tensor rearrange_weights_2d(const Tensor& weight_in, bool tconv) {
  at::Tensor weight = weight_in.clone();

  // Flip values along the H and W axes for transposed convolutions
  if (tconv) {
    weight = weight.flip(3).flip(2);
  }

  uint32_t N = get_dim<DimConv2DKernel::OutChannels>(weight);
  uint32_t C = get_dim<DimConv2DKernel::InChannels>(weight);
  uint32_t H = get_dim<DimConv2DKernel::Height>(weight);
  uint32_t W = get_dim<DimConv2DKernel::Width>(weight);

  uint32_t N_aligned = api::utils::align_up(N, 4u);
  uint32_t C_aligned = api::utils::align_up(C, 4u);

  // Add padding to the N and C dimensions so that each is a multiple of 4
  uint32_t C_padding_needed = C_aligned - C;
  uint32_t N_padding_needed = N_aligned - N;
  weight = at::pad(
      weight,
      {0, 0, 0, 0, 0, C_padding_needed, 0, N_padding_needed},
      "constant",
      0);

  // Split the C dim into groups of 4
  uint32_t C4 = C_aligned / 4u;
  weight = weight.reshape({N_aligned, C4, 4, H, W});

  if (!tconv) {
    // Collapse each group of 4 channels onto the width axis
    weight = weight.permute({0, 1, 3, 4, 2}).reshape({N_aligned, C4, H, 4 * W});
    // Next, stack the groups of channels horizontally along the width axis
    weight =
        weight.permute({0, 2, 1, 3}).reshape({N_aligned, H, C_aligned * W});
  } else {
    // For tconv, do the same as above, but interleave the groups of 4 channels
    // so that all channels at a given width position end up adjacent
    weight = weight.permute({0, 3, 4, 1, 2}).reshape({N_aligned, H, W, 4 * C4});
    // Next, reshape to combine the last two dims into a single row
    weight = weight.reshape({N_aligned, H, C_aligned * W});
  }

  // Split the N dim into groups of 4
  uint32_t N4 = N_aligned / 4u;
  weight = weight.reshape({N4, 4, H, C_aligned * W});

  // Collapse the outermost dim so that each group of 4 is stacked vertically
  weight = weight.permute({1, 0, 2, 3}).reshape({4, N4 * H, C_aligned * W});

  return weight.contiguous();
}

/*
 * Rearranges a convolution bias tensor to a layout that can be used by
 * convolution compute shaders. The goal of this packing is to arrange the data
 * such that data access in the compute shader is as linear as possible. The
 * reasoning behind the packing pattern will be described in the shader kernel
 * code.
 *
 * The rearrangement is straightforward: every 4 consecutive bias values are
 * grouped into one texel, and the texels are arranged along the x axis.
 */
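/*
 * For example (shapes only, assuming a hypothetical bias of length 10): the
 * bias is padded to length 12, reshaped to {3, 4}, transposed to {4, 3}, and
 * finally reshaped to {4, 1, 3}, i.e. 3 texels of 4 values each.
 */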
at::Tensor rearrange_bias(
    const std::optional<Tensor>& bias_in,
    const at::Tensor& weight_in,
    bool tconv) {
  // If optional is empty, just return zeros
  if (!bias_in) {
    uint32_t L = tconv ? get_dim<DimTConv2DKernel::OutChannels>(weight_in)
                       : get_dim<DimConv2DKernel::OutChannels>(weight_in);
    const uint32_t L4 = api::utils::div_up(L, 4u);

    at::Tensor bias = at::zeros({4, 1, L4}, weight_in.options());
    return bias;
  }

  at::Tensor bias = bias_in->clone();

  // Bias should just be a 1D tensor
  uint32_t L = get_dim<Dim1D::Length>(bias);

  uint32_t L_aligned = api::utils::align_up(L, 4u);

  // Add padding so that the length is a multiple of 4
  uint32_t padding_needed = L_aligned - L;
  bias = at::pad(bias, {0, padding_needed}, "constant", 0);

  // Reshape + permute to group every 4 consecutive elements along the same
  // channel
  uint32_t L4 = L_aligned / 4u;
  bias = bias.reshape({L4, 4}).permute({1, 0});
  bias = bias.reshape({4, 1, L4});

  return bias.contiguous();
}

//
// Shader and Workgroup size determination
//

static api::ShaderInfo get_shader(
    const IntArrayRef kernel_size,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const Conv2dMethod method,
    const bool transposed,
    const bool quantized) {
  api::ShaderInfo shader;

  if (quantized) {
    if (transposed) {
      shader = VK_KERNEL(quantized_conv_transpose2d);
      return shader;
    }

    switch (method) {
      case Conv2dSlidingWindow:
        shader = VK_KERNEL(quantized_conv2d);
        break;
      case Conv2dDepthwise:
        shader = VK_KERNEL(quantized_conv2d_dw);
        break;
      case Conv2dPointwise:
        shader = VK_KERNEL(quantized_conv2d_pw_2x2);
        break;
        // TODO: fail for quantized transposed conv
    }
    return shader;
  }

  if (transposed) {
    shader = VK_KERNEL(conv_transpose2d);
    return shader;
  }

  switch (method) {
    case Conv2dSlidingWindow:
      shader = VK_KERNEL(conv2d);
      break;
    case Conv2dDepthwise:
      shader = VK_KERNEL(conv2d_dw);
      if (kernel_size.size() == 4 && kernel_size[2] == 3 &&
          kernel_size[3] == 3) {
        // Use the specialized shader when the kernel is 3x3
        shader = VK_KERNEL(conv2d_dw_output_tile_3x3);
      }
      if (kernel_size.size() == 4 && kernel_size[2] == 5 &&
          kernel_size[3] == 5) {
        // Use the specialized shader when the kernel is 5x5
        shader = VK_KERNEL(conv2d_dw_output_tile_5x5);
      }
      break;
    case Conv2dPointwise:
      shader = VK_KERNEL(conv2d_pw_output_tile_2x2);
      break;
  }
  return shader;
}

//
// Op Recording
//

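// Parameter block consumed by the conv2d compute shaders. The fill0/fill1
// members appear to pad each ivec3 to 16 bytes so the host struct lines up
// with the shader's uniform block layout (std140-style alignment is an
// assumption here, inferred from the explicit padding fields).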
struct Params final {
  api::utils::ivec3 out_extents;
  int32_t fill0;
  api::utils::ivec3 in_extents;
  int32_t fill1;
  api::utils::ivec4 overlay_region;
  api::utils::ivec2 kernel_size;
  api::utils::ivec2 stride;
  api::utils::ivec2 padding;
  api::utils::ivec2 dilate;
  api::utils::vec2 clamp;
};

static void record_op(
    api::Context* const context,
    api::ShaderInfo& compute_shader,
    vTensor& v_output,
    const vTensor& v_input,
    const vTensor& v_weight,
    const vTensor& v_bias,
    const IntArrayRef overlay_region,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const float output_min,
    const float output_max,
    const IntArrayRef kernel_size,
    const Conv2dMethod method,
    const bool transposed) {
  api::PipelineBarrier pipeline_barrier{};

  api::utils::uvec3 global_size = v_output.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  Params block{
      api::utils::make_ivec3(v_output.extents()),
      0u,
      api::utils::make_ivec3(v_input.extents()),
      0u,
      utils::make_ivec4(overlay_region, /*reverse=*/true),
      utils::make_ivec2({kernel_size[3], kernel_size[2]}),
      utils::make_ivec2(stride, /*reverse=*/true),
      utils::make_ivec2(padding, /*reverse=*/true),
      utils::make_ivec2(dilation, /*reverse=*/true),
      {output_min, output_max},
  };
  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());
}

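// Uniform block for the quantized conv2d shaders. Same layout as Params, but
// prefixed with the per-tensor quantization scales and zero points in the
// order (output, input, weight, bias).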
struct QParams final {
  api::utils::vec4 scales;
  api::utils::ivec4 zero_points;
  api::utils::ivec3 out_extents;
  int32_t fill0;
  api::utils::ivec3 in_extents;
  int32_t fill1;
  api::utils::ivec4 overlay_region;
  api::utils::ivec2 kernel_size;
  api::utils::ivec2 stride;
  api::utils::ivec2 padding;
  api::utils::ivec2 dilate;
  api::utils::vec2 clamp;
};

static void record_quantized_op(
    api::Context* const context,
    api::ShaderInfo& compute_shader,
    vTensor& v_output,
    const vTensor& v_input,
    const vTensor& v_weight,
    const vTensor& v_bias,
    const IntArrayRef overlay_region,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const float output_min,
    const float output_max,
    const IntArrayRef kernel_size,
    const Conv2dMethod method,
    const bool transposed) {
  api::PipelineBarrier pipeline_barrier{};

  api::utils::uvec3 global_size = v_output.extents();
  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);

  QParams block{
      {
          v_output.get_scale_float(),
          v_input.get_scale_float(),
          v_weight.get_scale_float(),
          v_bias.get_scale_float(),
      },
      {
          v_output.get_zero_point_int32(),
          v_input.get_zero_point_int32(),
          v_weight.get_zero_point_int32(),
          v_bias.get_zero_point_int32(),
      },
      api::utils::make_ivec3(v_output.extents()),
      0u,
      api::utils::make_ivec3(v_input.extents()),
      0u,
      utils::make_ivec4(overlay_region, /*reverse=*/true),
      utils::make_ivec2({kernel_size[3], kernel_size[2]}),
      utils::make_ivec2(stride, /*reverse=*/true),
      utils::make_ivec2(padding, /*reverse=*/true),
      utils::make_ivec2(dilation, /*reverse=*/true),
      {output_min, output_max},
  };
  api::UniformParamsBuffer params(context, block);

  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      global_size,
      // local work group size
      local_size,
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());
}

} // namespace conv2d

namespace {

using namespace api::utils;

vTensor pack_weights(
    const Tensor& weight_inp,
    const bool transposed,
    const bool quantized,
    const Conv2dMethod conv_method) {
  if (weight_inp.is_vulkan()) {
    return convert(weight_inp);
  }

  const Tensor weight_arg = quantized ? at::dequantize(weight_inp) : weight_inp;

  const Tensor weight = transposed
      ? at::permute(weight_arg, {1, 0, 2, 3}).contiguous()
      : weight_arg.contiguous();

  at::Tensor weight_rearranged;
  if (conv_method == Conv2dDepthwise) {
    weight_rearranged = conv2d::rearrange_weights_dw(weight);
  } else {
    weight_rearranged = conv2d::rearrange_weights_2d(weight, transposed);
  }

  vTensor v_weight{
      api::context(),
      weight_rearranged.sizes().vec(),
      convert_dtype(weight_rearranged.scalar_type()),
      api::StorageType::TEXTURE_2D,
  };

  pack_cpu_to_vulkan(weight_rearranged, v_weight);

  return v_weight;
}

vTensor pack_biases(
    const std::optional<Tensor>& bias,
    const Tensor& weight,
    const bool transposed,
    const bool quantized) {
  at::Tensor bias_arg = conv2d::rearrange_bias(bias, weight, transposed);
  at::Tensor bias_rearranged =
      (quantized &&
       (bias_arg.scalar_type() == kQUInt8 || bias_arg.scalar_type() == kQInt8 ||
        bias_arg.scalar_type() == kQInt32))
      ? at::dequantize(bias_arg)
      : bias_arg;

  vTensor v_bias{
      api::context(),
      bias_rearranged.sizes().vec(),
      convert_dtype(bias_rearranged.scalar_type()),
      api::StorageType::TEXTURE_2D,
  };

  pack_cpu_to_vulkan(bias_rearranged, v_bias);

  return v_bias;
}

/*
 * Computes the size of the overlay region when computing a convolution output.
 */
std::array<int64_t, 4> compute_overlay_region(
    const Tensor& weight,
    const IntArrayRef dilation,
    const bool transposed) {
  const IntArrayRef filter = weight.sizes();

  const auto overlay_length = [](const int64_t k, const int64_t d) {
    return k + (k - 1) * (d - 1);
  };
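  // overlay_length computes the effective (dilated) kernel extent:
  // k + (k - 1) * (d - 1) == d * (k - 1) + 1. For example, a 3x3 kernel with
  // dilation 2 overlays a 5x5 input region.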

  return {
      align_up(
          transposed ? filter[Layout::TransposedFilter::output]
                     : filter[Layout::Filter::output],
          INT64_C(4)),
      align_up(
          transposed ? filter[Layout::TransposedFilter::input]
                     : filter[Layout::Filter::input],
          INT64_C(4)),
      overlay_length(
          filter[Layout::Filter::height], dilation[Layout::Parameter::height]),
      overlay_length(
          filter[Layout::Filter::width], dilation[Layout::Parameter::width]),
  };
}

std::array<int64_t, 2> pack_params(const std::vector<int64_t>& vector) {
  TORCH_INTERNAL_ASSERT(2u == vector.size(), "Invalid usage!");

  return {
      vector[0],
      vector[1],
  };
}

bool weight_valid(const Tensor& weight, const bool quantized) {
  if (4 != weight.ndimension()) {
    return false;
  }
  if (get_dim<DimConv2DKernel::Height>(weight) == 0) {
    return false;
  }
  if (get_dim<DimConv2DKernel::Width>(weight) == 0) {
    return false;
  }
  if (!weight.device().is_cpu() &&
      weight.device().type() != c10::DeviceType::Vulkan) {
    return false;
  }
  if (quantized &&
      (weight.scalar_type() != c10::kQUInt8 &&
       weight.scalar_type() != c10::kQInt8)) {
    return false;
  }

  return true;
}

bool bias_valid(
    const std::optional<Tensor>& bias,
    const Tensor& weight,
    const bool transposed,
    const bool quantized) {
  if (!bias) {
    return true;
  }

  if (bias->ndimension() != 1) {
    return false;
  }
  if (!bias->device().is_cpu() &&
      bias->device().type() != c10::DeviceType::Vulkan) {
    return false;
  }
  uint32_t L = get_dim<Dim1D::Length>(*bias);
  uint32_t OC = transposed ? get_dim<DimTConv2DKernel::OutChannels>(weight)
                           : get_dim<DimConv2DKernel::OutChannels>(weight);
  if (L != OC) {
    return false;
  }

  return true;
}

bool available(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const bool transposed,
    const bool quantized,
    const IntArrayRef /* output_padding */,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  if (!weight_valid(weight, quantized)) {
    return false;
  }
  if (!bias_valid(bias, weight, transposed, quantized)) {
    return false;
  }
  if (get_dim<Dim4D::Height>(stride) == 0 ||
      get_dim<Dim4D::Width>(stride) == 0) {
    return false;
  }
  if (transposed) {
    if (get_dim<Dim4D::Height>(dilation) != 1 ||
        get_dim<Dim4D::Width>(dilation) != 1) {
      return false;
    }
  } else {
    if (get_dim<Dim4D::Height>(dilation) == 0 ||
        get_dim<Dim4D::Width>(dilation) == 0) {
      return false;
    }
  }
  if (groups <= 0) {
    return false;
  }
  if (transposed) {
    if ((get_dim<DimTConv2DKernel::OutChannels>(weight) % groups) != 0) {
      return false;
    }
  } else {
    if ((get_dim<DimConv2DKernel::OutChannels>(weight) % groups) != 0) {
      return false;
    }
  }
  if (get_dim<DimConv2DKernel::InChannels>(weight) == 0 ||
      get_dim<DimConv2DKernel::OutChannels>(weight) == 0) {
    return false;
  }
  if (output_min && !output_min->isFloatingPoint()) {
    return false;
  }
  if (output_max && !output_max->isFloatingPoint()) {
    return false;
  }
  return true;
}

bool usable(const Tensor& input, const bool quantized) {
  if (input.ndimension() != 4) {
    return false;
  }
  if (input.device().type() != c10::DeviceType::Vulkan) {
    return false;
  }
  if (!quantized && input.scalar_type() != at::kFloat) {
    return false;
  }
  if (quantized && input.scalar_type() != c10::kQUInt8) {
    return false;
  }
  if (get_dim<Dim4D::Batch>(input) == 0) {
    return false;
  }
  if (get_dim<Dim4D::Channel>(input) == 0) {
    return false;
  }
  if (get_dim<Dim4D::Height>(input) == 0) {
    return false;
  }
  if (get_dim<Dim4D::Width>(input) == 0) {
    return false;
  }
  if (input.requires_grad()) {
    return false;
  }

  return true;
}

static inline std::vector<int64_t> get_conv_transpose_output_size(
    IntArrayRef input_size,
    IntArrayRef weight_size,
    IntArrayRef padding,
    IntArrayRef output_padding,
    IntArrayRef stride,
    IntArrayRef dilation = IntArrayRef()) {
  auto dim = input_size.size();
  std::vector<int64_t> output_size(dim);
  output_size[0] = input_size[input_batch_size_dim];
  output_size[1] = weight_size[weight_input_channels_dim];
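  // For each spatial dim: out = stride * (in - 1) + kernel - 2 * padding +
  // output_padding. Note that the dilation argument is accepted for interface
  // parity but is not used in this size computation.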
  for (const auto d : c10::irange(2, dim)) {
    output_size[d] = stride[d - 2] * (input_size[d] - 1) + weight_size[d] -
        2 * padding[d - 2] + output_padding[d - 2];
  }
  return output_size;
}

Tensor convolution(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const bool transposed,
    const IntArrayRef output_padding,
    const int64_t groups) {
  Conv2dPackedContext conv_context = Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      transposed,
      false,
      output_padding,
      groups);

  return run_conv2d_context(
      input, c10::make_intrusive<Conv2dPackedContext>(conv_context));
}

} // namespace

namespace conv1d {

static vTensor pack_weights_using_width_packing(const Tensor& weight_arg) {
  Tensor weight = weight_arg;

  if (weight.is_cpu()) {
    weight = weight.vulkan();
  }

  TORCH_CHECK(weight.is_vulkan(), "Weight must be on Vulkan device!");

  vTensor v_weight = convert(weight);
  if (v_weight.gpu_memory_layout() ==
      api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED) {
    v_weight = packing::convert_image_channels_packed_to_width_packed(v_weight);
  }

  TORCH_CHECK(
      v_weight.gpu_memory_layout() == api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      "After packing, the v_weight must be in TENSOR_WIDTH_PACKED format");

  return v_weight;
}

/*
 * Full implementation of the conv1d op. For algorithm details, refer to the
 * shader kernel code.
 */
static Tensor run_conv1d_context_impl(
    const Tensor& input_arg,
    const Tensor& weight_arg,
    const std::optional<Tensor>& bias_arg_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups) {
  api::Context* const context = api::context();
  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  const Tensor weight =
      weight_arg.is_vulkan() ? weight_arg : weight_arg.vulkan();

  const IntArrayRef& input_sizes = input.sizes();
  const IntArrayRef& weight_sizes = weight.sizes();

  int32_t in_channels = static_cast<int32_t>(input_sizes[1]);
  int32_t out_channels = static_cast<int32_t>(weight_sizes[0]);
  int32_t kernel_size = static_cast<int32_t>(weight_sizes[2]);

  Tensor bias;
  if (bias_arg_opt) {
    if (bias_arg_opt->is_vulkan()) {
      bias = bias_arg_opt.value();
    } else {
      bias = bias_arg_opt.value().vulkan();
    }
  } else {
    bias = at::zeros({out_channels}).vulkan();
  }

  TORCH_CHECK(input.dim() == 3, "input must be a 3-dim tensor");
  TORCH_CHECK(weight.dim() == 3, "weight must be a 3-dim tensor");
  TORCH_CHECK(
      in_channels % groups == 0, "in_channels must be divisible by groups");
  TORCH_CHECK(
      out_channels % groups == 0, "out_channels must be divisible by groups");

  const vTensor& v_input = convert(input);
  const vTensor& v_weight = convert(weight);
  const vTensor& v_bias = convert(bias);

  vTensor v_output{
      context,
      conv_output_size(input_sizes, weight_sizes, padding, stride, dilation),
      v_input.dtype(),
  };

  const struct Block final {
    int32_t in_length;
    int32_t kernel_size;
    int32_t stride;
    int32_t padding;
    int32_t dilation;
    int32_t in_group_size;
    int32_t out_group_size;
    int32_t batch_size;
  } block{
      static_cast<int32_t>(input_sizes[2]),
      kernel_size,
      static_cast<int32_t>(stride[0]),
      static_cast<int32_t>(padding[0]),
      static_cast<int32_t>(dilation[0]),
      static_cast<int32_t>(in_channels / groups),
      static_cast<int32_t>(out_channels / groups),
      static_cast<int32_t>(input_sizes[0]),
  };

  api::UniformParamsBuffer params(context, block);
  api::PipelineBarrier pipeline_barrier{};

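  // Dispatch one invocation per output channel ({1, out_channels, 1}); the
  // shader is expected to iterate over the batch and output length for its
  // channel.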
  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(conv1d),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      {1, static_cast<uint32_t>(out_channels), 1},
      // local work group size
      {1, 1, 1},
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  return convert(v_output);
}

} // namespace conv1d

Conv2dPackedContext::Conv2dPackedContext(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride_arg,
    const IntArrayRef padding_arg,
    const IntArrayRef dilation_arg,
    const bool transposed,
    const bool quantized,
    const IntArrayRef output_padding_arg,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max)
    : unpacked_{c10::AnyType::get()} {
  const auto stride = expand_param_if_needed(stride_arg, "stride", 2);
  const auto padding = expand_param_if_needed(padding_arg, "padding", 2);
  const auto dilation = expand_param_if_needed(dilation_arg, "dilation", 2);
  const auto output_padding =
      expand_param_if_needed(output_padding_arg, "output_padding", 2);

  TORCH_CHECK(
      available(
          weight,
          bias,
          stride,
          padding,
          dilation,
          transposed,
          quantized,
          output_padding,
          groups,
          output_min,
          output_max),
      "Vulkan::convolution not available! "
      "Reason: The provided (weight, bias, stride, padding, dilation, groups, "
      "transposed, output_padding, output_min, output_max) parameters are either "
      "invalid individually or their combination is not supported by Vulkan impl.");

  const auto method = conv2d::determine_method(
      weight.sizes(), stride, padding, dilation, groups, transposed, quantized);

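  // packed_ holds the GPU-ready arguments consumed at run time; the
  // emplacement order must match the Packed enum indices used in
  // run_conv2d_context_impl. unpacked_ keeps the original CPU arguments for
  // re-packing/serialization unless weight release is requested.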
  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(
      convert(pack_weights(weight, transposed, quantized, method)));
  packed_.emplace_back(
      convert(pack_biases(bias, weight, transposed, quantized)));
  packed_.emplace_back(compute_overlay_region(weight, dilation, transposed));
  packed_.emplace_back(pack_params(stride));
  packed_.emplace_back(pack_params(padding));
  packed_.emplace_back(output_padding);
  packed_.emplace_back(pack_params(dilation));
  packed_.emplace_back(transposed);
  packed_.emplace_back(quantized);
  packed_.emplace_back(safe_downcast<int32_t>(groups));
  packed_.emplace_back(
      output_min ? output_min->template to<float>()
                 : -std::numeric_limits<float>::infinity());
  packed_.emplace_back(
      output_max ? output_max->template to<float>()
                 : +std::numeric_limits<float>::infinity());
  packed_.emplace_back(method);
  packed_.emplace_back(weight.sizes().vec());

  compute_shader_ = conv2d::get_shader(
      weight.sizes(), stride, padding, dilation, method, transposed, quantized);

  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
    unpacked_.reserve(Unpacked::NumArgs);
    unpacked_.emplace_back(weight);
    unpacked_.emplace_back(bias);
    unpacked_.emplace_back(stride_arg.vec());
    unpacked_.emplace_back(padding_arg.vec());
    unpacked_.emplace_back(dilation_arg.vec());
    unpacked_.emplace_back(transposed);
    unpacked_.emplace_back(quantized);
    unpacked_.emplace_back(output_padding_arg.vec());
    unpacked_.emplace_back(groups);
    unpacked_.emplace_back(output_min);
    unpacked_.emplace_back(output_max);
  }
}

Conv2dPackedContext Conv2dPackedContext::pack(c10::impl::GenericList unpacked) {
  return Conv2dPackedContext(
      unpacked.get(Unpacked::Weight).toTensor(),
      get_optional_tensor(unpacked, Unpacked::Bias),
      unpacked.get(Unpacked::Stride).toIntVector(),
      unpacked.get(Unpacked::Padding).toIntVector(),
      unpacked.get(Unpacked::Dilation).toIntVector(),
      unpacked.get(Unpacked::isTransposed).toBool(),
      unpacked.get(Unpacked::isQuantized).toBool(),
      unpacked.get(Unpacked::OutputPadding).toIntVector(),
      unpacked.get(Unpacked::Groups).toInt(),
      get_optional_scalar(unpacked, Unpacked::OutputMin),
      get_optional_scalar(unpacked, Unpacked::OutputMax));
}

c10::intrusive_ptr<Conv2dPackedContext> create_conv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dPackedContext>(Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      /* transposed = */ false,
      /* quantized = */ false,
      /* output_padding_arg = */ {0},
      groups,
      output_min,
      output_max));
}

c10::intrusive_ptr<Conv2dPackedContext> create_tconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dPackedContext>(Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      /* transposed = */ true,
      /* quantized = */ false,
      output_padding,
      groups,
      output_min,
      output_max));
}

c10::intrusive_ptr<Conv2dPackedContext> create_qconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dPackedContext>(Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      /* transposed = */ false,
      /* quantized = */ true,
      /* output_padding_arg = */ {0},
      groups,
      output_min,
      output_max));
}

c10::intrusive_ptr<Conv2dPackedContext> create_qtconv2d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& output_padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dPackedContext>(Conv2dPackedContext(
      weight,
      bias,
      stride,
      padding,
      dilation,
      /* transposed = */ true,
      /* quantized = */ true,
      output_padding,
      groups,
      output_min,
      output_max));
}

static Tensor run_conv2d_context_impl(
    const Tensor& input_arg,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context,
    double scale,
    int64_t zero_point) {
  api::Context* const context = api::context();
  // Validate input tensor is a Vulkan tensor, then convert to vTensor
  TORCH_CHECK(input_arg.is_vulkan(), "Input tensor must be Vulkan!");
  const vTensor& v_input = convert(input_arg);

  // Extract everything from the PackedContext
  const Tensor weight =
      conv_context->get_val(Conv2dPackedContext::Packed::Weight).toTensor();
  const vTensor& v_weight = convert(weight);

  const auto quantized =
      conv_context->get_val(Conv2dPackedContext::Packed::isQuantized).toBool();

  Tensor bias =
      conv_context->get_val(Conv2dPackedContext::Packed::Bias).toTensor();

  const vTensor& v_bias = convert(bias);

  const auto overlay_region =
      conv_context->get_val(Conv2dPackedContext::Packed::OverlayRegion)
          .toIntVector();

  const auto stride =
      conv_context->get_val(Conv2dPackedContext::Packed::Stride).toIntVector();
  const auto padding =
      conv_context->get_val(Conv2dPackedContext::Packed::Padding).toIntVector();
  const auto output_padding =
      conv_context->get_val(Conv2dPackedContext::Packed::OutputPadding)
          .toIntVector();
  const auto dilation =
      conv_context->get_val(Conv2dPackedContext::Packed::Dilation)
          .toIntVector();

  const auto transposed =
      conv_context->get_val(Conv2dPackedContext::Packed::isTransposed).toBool();

  const float output_min = safe_downcast<float>(
      conv_context->get_val(Conv2dPackedContext::Packed::OutputMin).toDouble());
  const float output_max = safe_downcast<float>(
      conv_context->get_val(Conv2dPackedContext::Packed::OutputMax).toDouble());

  const Conv2dMethod method_ = static_cast<Conv2dMethod>(
      conv_context->get_val(Conv2dPackedContext::Packed::ConvMethod).toInt());

  const auto kernel_size =
      conv_context->get_val(Conv2dPackedContext::Packed::WeightSizes)
          .toIntVector();

  TORCH_CHECK(
      usable(input_arg, quantized), "Input tensor not usable for convolution!");

  std::vector<int64_t> output_size;
  if (transposed) {
    output_size = get_conv_transpose_output_size(
        v_input.sizes(),
        kernel_size,
        padding,
        output_padding,
        stride,
        dilation);
  } else {
    output_size = conv_output_size(
        v_input.sizes(), kernel_size, padding, stride, dilation);
  }

  vTensor v_output{
      context,
      output_size,
      v_input.dtype(),
  };

  if (quantized) {
    v_output.set_is_quantized();
    v_output.set_scale(scale);
    v_output.set_zero_point(zero_point);
  }

  if (quantized) {
    conv2d::record_quantized_op(
        context,
        conv_context->compute_shader(),
        v_output,
        v_input,
        v_weight,
        v_bias,
        overlay_region,
        stride,
        padding,
        dilation,
        output_min,
        output_max,
        kernel_size,
        method_,
        transposed);
  } else {
    conv2d::record_op(
        context,
        conv_context->compute_shader(),
        v_output,
        v_input,
        v_weight,
        v_bias,
        overlay_region,
        stride,
        padding,
        dilation,
        output_min,
        output_max,
        kernel_size,
        method_,
        transposed);
  }

  return convert(v_output);
}

Tensor run_conv2d_context(
    const Tensor& input_arg,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context) {
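  // scale and zero_point are only consumed on the quantized path, so neutral
  // values are passed here.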
  return run_conv2d_context_impl(input_arg, conv_context, 1.0f, 0u);
}

Tensor run_tconv2d_context(
    const Tensor& input_arg,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context) {
  return run_conv2d_context_impl(input_arg, conv_context, 1.0f, 0u);
}

Tensor run_qconv2d_context(
    const Tensor& input_arg,
    double scale,
    int64_t zero_point,
    const c10::intrusive_ptr<Conv2dPackedContext>& conv_context) {
  return run_conv2d_context_impl(input_arg, conv_context, scale, zero_point);
}

/* Backwards compatibility */
Conv2dOpContext::Conv2dOpContext(Conv2dPackedContext conv_context)
    : conv_context_{std::move(conv_context)} {}

Conv2dOpContext Conv2dOpContext::create(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride_arg,
    const IntArrayRef padding_arg,
    const IntArrayRef dilation_arg,
    const bool transposed,
    const IntArrayRef output_padding_arg,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return Conv2dOpContext{Conv2dPackedContext(
      weight,
      bias,
      stride_arg,
      padding_arg,
      dilation_arg,
      transposed,
      /* quantized = */ false,
      output_padding_arg,
      groups,
      output_min,
      output_max)};
}

Tensor Conv2dOpContext::run(const Tensor& input_arg) const {
  return run_conv2d_context(
      input_arg, c10::make_intrusive<Conv2dPackedContext>(conv_context_));
}

Conv2dOpContext::State Conv2dOpContext::unpack() const {
  const c10::impl::GenericList unpacked_ = conv_context_.unpack();

  TORCH_CHECK(unpacked_.size() > 0u, "unpacked_ does not have any elements!");

  return Conv2dOpContext::State(
      unpacked_.get(Conv2dPackedContext::Unpacked::Weight).toTensor(),
      get_optional_tensor(unpacked_, Conv2dPackedContext::Unpacked::Bias),
      unpacked_.get(Conv2dPackedContext::Unpacked::Stride).toIntVector(),
      unpacked_.get(Conv2dPackedContext::Unpacked::Padding).toIntVector(),
      unpacked_.get(Conv2dPackedContext::Unpacked::Dilation).toIntVector(),
      unpacked_.get(Conv2dPackedContext::Unpacked::Groups).toInt(),
      get_optional_scalar(unpacked_, Conv2dPackedContext::Unpacked::OutputMin),
      get_optional_scalar(unpacked_, Conv2dPackedContext::Unpacked::OutputMax));
}

c10::intrusive_ptr<Conv2dOpContext> conv2d_clamp_prepack(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups,
    const std::optional<Scalar>& output_min,
    const std::optional<Scalar>& output_max) {
  return c10::make_intrusive<Conv2dOpContext>(Conv2dOpContext::create(
      std::move(weight),
      std::move(bias),
      std::move(stride),
      std::move(padding),
      std::move(dilation),
      /* transposed = */ false,
      /* output_padding = */ {0},
      groups,
      output_min,
      output_max));
}

Tensor conv2d_clamp_run(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dOpContext>& context) {
  return context->run(input);
}

Conv1dPackedContext::Conv1dPackedContext(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride_arg,
    const IntArrayRef padding_arg,
    const IntArrayRef dilation_arg,
    const int64_t groups)
    : unpacked_{c10::AnyType::get()} {
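  // Emplacement order must match Conv1dPackedContext::Packed. Note that a bias
  // tensor is required here: the optional is dereferenced directly below.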
  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(
      convert(conv1d::pack_weights_using_width_packing(weight.vulkan())));
  packed_.emplace_back(bias->vulkan());
  packed_.emplace_back(stride_arg);
  packed_.emplace_back(padding_arg);
  packed_.emplace_back(dilation_arg);
  packed_.emplace_back(safe_downcast<int32_t>(groups));

  compute_shader_ = VK_KERNEL(conv1d);

  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
    unpacked_.reserve(Unpacked::NumArgs);
    unpacked_.emplace_back(weight);
    unpacked_.emplace_back(bias);
    unpacked_.emplace_back(stride_arg.vec());
    unpacked_.emplace_back(padding_arg.vec());
    unpacked_.emplace_back(dilation_arg.vec());
    unpacked_.emplace_back(safe_downcast<int32_t>(groups));
  }
}

Conv1dPackedContext Conv1dPackedContext::pack(c10::impl::GenericList unpacked) {
  return Conv1dPackedContext(
      unpacked.get(Unpacked::Weight).toTensor(),
      get_optional_tensor(unpacked, Unpacked::Bias),
      unpacked.get(Unpacked::Stride).toIntVector(),
      unpacked.get(Unpacked::Padding).toIntVector(),
      unpacked.get(Unpacked::Dilation).toIntVector(),
      unpacked.get(Unpacked::Groups).toInt());
}

c10::intrusive_ptr<Conv1dPackedContext> create_conv1d_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    const int64_t groups) {
  return c10::make_intrusive<Conv1dPackedContext>(
      Conv1dPackedContext(weight, bias, stride, padding, dilation, groups));
}

static Tensor convolution1d(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const IntArrayRef stride,
    const IntArrayRef padding,
    const IntArrayRef dilation,
    const int64_t groups) {
  Conv1dPackedContext conv1d_context =
      Conv1dPackedContext(weight, bias, stride, padding, dilation, groups);

  return run_conv1d_context(
      input, c10::make_intrusive<Conv1dPackedContext>(conv1d_context));
}

Tensor run_conv1d_context(
    const Tensor& input,
    const c10::intrusive_ptr<Conv1dPackedContext>& context) {
  const Tensor weight =
      context->get_val(Conv1dPackedContext::Packed::Weight).toTensor();
  const std::optional<Tensor>& bias_opt =
      context->get_val(Conv1dPackedContext::Packed::Bias).toTensor();
  const auto stride =
      context->get_val(Conv1dPackedContext::Packed::Stride).toIntVector();
  const auto padding =
      context->get_val(Conv1dPackedContext::Packed::Padding).toIntVector();
  const auto dilation =
      context->get_val(Conv1dPackedContext::Packed::Dilation).toIntVector();
  const auto groups =
      context->get_val(Conv1dPackedContext::Packed::Groups).toInt();
  return conv1d::run_conv1d_context_impl(
      input, weight, bias_opt, stride, padding, dilation, groups);
}

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl("convolution_overrideable", convolution);
  m.impl(TORCH_SELECTIVE_NAME("aten::conv1d"), TORCH_FN(convolution1d));
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at