// aten/src/ATen/native/vulkan/ops/Mm.cpp
#include <ATen/native/vulkan/ops/Mm.h>
#include <ATen/native/vulkan/ops/Utils.h>

#include <ATen/Context.h>
#include <ATen/Functions.h>
#include <ATen/native/vulkan/api/Tensor.h>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/impl/Packing.h>
#include <c10/util/irange.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;
using namespace at::native::vulkan::ops;

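// Moves the (non-quantized, 2D or 3D) input tensor onto the Vulkan backend if
// needed and returns it as a vTensor in TENSOR_WIDTH_PACKED layout, repacking
// from the default TENSOR_CHANNELS_PACKED image layout when necessary.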
vTensor pack_inputs_using_width_packing(const Tensor& input_arg) {
  TORCH_INTERNAL_ASSERT(
      !input_arg.is_quantized(),
      "Vulkan Linear not usable! "
      "Reason: Input packing only supports non-quantized tensors.");
  TORCH_INTERNAL_ASSERT(
      input_arg.dim() == 2 || input_arg.dim() == 3,
      "Vulkan Linear not usable! "
      "Reason: Input packing only supports 2D or 3D tensors.");

  Tensor input = input_arg;
  if (input.is_cpu()) {
    input = input.vulkan();
  }

  TORCH_CHECK(input.is_vulkan(), "Input must be on Vulkan device!");

  vTensor v_input = convert(input);
  if (v_input.gpu_memory_layout() ==
      api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED) {
    v_input = packing::convert_image_channels_packed_to_width_packed(v_input);
  }

  TORCH_CHECK(
      v_input.gpu_memory_layout() == api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      "After packing, the v_input must be in TENSOR_WIDTH_PACKED format");

  return v_input;
}

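// Moves the (non-quantized, 2D or 3D) weight tensor onto the Vulkan backend if
// needed and returns it as a vTensor in TENSOR_HEIGHT_PACKED layout, repacking
// from TENSOR_CHANNELS_PACKED when necessary.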
vTensor pack_weights_using_height_packing(const Tensor& weight_arg) {
  // Only non-batch, non-quantized tensors are supported
  TORCH_INTERNAL_ASSERT(
      !weight_arg.is_quantized(),
      "Vulkan Linear not usable! "
      "Reason: Weight packing only supports non-quantized tensors.");
  TORCH_INTERNAL_ASSERT(
      weight_arg.dim() == 2 || weight_arg.dim() == 3,
      "Vulkan Linear not usable! "
      "Reason: Weight packing only supports 2D or 3D tensors.");

  Tensor weight = weight_arg;

  if (weight.is_cpu()) {
    weight = weight.vulkan();
  }

  TORCH_CHECK(weight.is_vulkan(), "Weight must be on Vulkan device!");

  vTensor v_weight = convert(weight);
  if (v_weight.gpu_memory_layout() ==
      api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED) {
    v_weight =
        packing::convert_image_channels_packed_to_height_packed(v_weight);
  }

  TORCH_CHECK(
      v_weight.gpu_memory_layout() ==
          api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED,
      "After packing, the v_weight must be in TENSOR_HEIGHT_PACKED format");

  return v_weight;
}

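// Packs the weight for the Vulkan linear/matmul kernels. Non-quantized weights
// are height packed via pack_weights_using_height_packing. Quantized (int8)
// weights are staged into a vTensor of sizes {batch, 4, ceil(H / 2),
// ceil(W / 2)}, carrying over the source's quantization scale and zero point
// (the per-texel layout itself is handled by stage_pack_weights).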
vTensor pack_weights(const Tensor& weight_arg, const bool use_batch = false) {
  if (!weight_arg.is_quantized()) {
    return pack_weights_using_height_packing(weight_arg);
  }

  TORCH_CHECK(
      weight_arg.is_quantized(), "Only quantized weights logic after here");

  // The rest of the logic handles quantized (and possibly batched) weights.

  api::Context* const context = api::context();

  const Tensor weight = weight_arg.contiguous();
  const IntArrayRef w_sizes = weight.sizes();
  if (use_batch) {
    TORCH_CHECK(
        w_sizes.size() == 3,
        "Vulkan Linear not usable! "
        "Reason: Unable to perform weight packing with batch; the input tensor of a batch of matrices should contain 3 dimensions: batch, height, width.");
  }
  /* Source */
  int64_t src_kb_sz = 0;
  int64_t src_kw_sz = 0;
  int64_t src_kh_sz = 0;
  /* Destination */
  int64_t dst_kb_sz = 0;
  int64_t dst_kw_sz = 0;
  int64_t dst_kh_sz = 0;
  std::vector<int64_t> dst_vtensor_sizes;
  /* Source */
  src_kb_sz = use_batch ? w_sizes[Layout::BatchMatrices::batch] : 1;
  src_kw_sz = use_batch ? w_sizes[Layout::BatchMatrices::width]
                        : w_sizes[Layout::Parameter::width];
  src_kh_sz = use_batch ? w_sizes[Layout::BatchMatrices::height]
                        : w_sizes[Layout::Parameter::height];

  /* Destination */
  dst_kb_sz = src_kb_sz;
  dst_kw_sz = div_up(src_kw_sz, INT64_C(2));
  dst_kh_sz = div_up(src_kh_sz, INT64_C(2));
  dst_vtensor_sizes = {
      dst_kb_sz,
      4,
      dst_kh_sz,
      dst_kw_sz,
  };

  vTensor v_weight{
      context, dst_vtensor_sizes, convert_dtype(weight_arg.scalar_type())};

  v_weight.set_is_quantized();
  v_weight.set_scale(weight_arg.q_scale());
  v_weight.set_zero_point(weight_arg.q_zero_point());

  stage_pack_weights<int8_t>(
      context,
      v_weight,
      weight,
      src_kb_sz,
      src_kh_sz,
      src_kw_sz,
      dst_kh_sz,
      dst_kw_sz);
  return v_weight;
}

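// Packs the bias for the non-quantized path: the bias tensor is simply moved
// onto the Vulkan backend (or a zero scalar Vulkan tensor is used when no bias
// is provided); alpha/beta scaling and the actual addition happen on the
// tensor level after the mm shader runs.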
vTensor pack_biases(
    const Tensor& weight_arg,
    const std::optional<Tensor>& bias_arg,
    const bool use_batch = false) {
  if (bias_arg) {
    Tensor bias = *bias_arg;
    if (bias.is_cpu()) {
      bias = bias.vulkan();
    }
    return convert(bias);
  } else {
    return convert(at::zeros({}, at::device(at::kVulkan).dtype(at::kFloat)));
  }
}

// Old version of pack_biases that fixes issues with quantization; to be
// removed in the future.
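// The bias is staged into a {batch, 4, ceil(H / 2), ceil(W / 2)} layout (the
// same destination sizes used for quantized weights above) and is consumed as
// an image by the quantized_addmm_* shaders.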
vTensor pack_biases_quantized_weights(
    const Tensor& weight_arg,
    const std::optional<Tensor>& bias_arg,
    const bool use_batch = false) {
  TORCH_CHECK(
      weight_arg.is_quantized(),
      "pack_biases_quantized to be used only when using quantized linear ops");

  if (bias_arg && bias_arg->is_vulkan()) {
    return convert(*bias_arg);
  }

  api::Context* const context = api::context();

  if (bias_arg) {
    const Tensor bias = bias_arg->contiguous();
    const IntArrayRef b_sizes = bias.sizes();
    const float* const src_bias_ptr = bias.const_data_ptr<float>();

    /* Source */
    int64_t src_kb_sz = 0;
    int64_t src_kw_sz = 0;
    int64_t src_kh_sz = 0;
    if (use_batch) {
      if (bias.sizes().size() == 3) {
        src_kb_sz = b_sizes[Layout::BatchMatrices::batch];
        src_kw_sz = b_sizes[Layout::BatchMatrices::width];
        src_kh_sz = b_sizes[Layout::BatchMatrices::height];
      } else if (bias.sizes().size() == 2) {
        // skip batch dim for broadcasting; index -1
        src_kb_sz = 1;
        src_kw_sz = b_sizes[Layout::BatchMatrices::height];
        src_kh_sz = b_sizes[Layout::BatchMatrices::batch];
      } else {
        // skip batch & height dims for broadcasting; index -2
        src_kb_sz = 1;
        src_kw_sz = b_sizes[Layout::BatchMatrices::batch];
        src_kh_sz = 1;
      }
    } else {
      src_kb_sz = 1;
      if (bias.sizes().size() == 2) {
        src_kw_sz = b_sizes[Layout::Parameter::width];
        src_kh_sz = b_sizes[Layout::Parameter::height];
      } else {
        src_kw_sz = b_sizes[Layout::Parameter::height];
        src_kh_sz = 1;
      }
    }
    const int64_t src_matrix_sz = src_kw_sz * src_kh_sz;

    /* Destination */
    const int64_t dst_kw_sz = div_up(src_kw_sz, INT64_C(2));
    const int64_t dst_kh_sz = div_up(src_kh_sz, INT64_C(2));
    const int64_t dst_plane_sz = dst_kw_sz * dst_kh_sz;
    const int64_t dst_matrix_sz = dst_plane_sz * 4;

    vTensor v_bias{
        context,
        {
            src_kb_sz,
            4,
            dst_kh_sz,
            dst_kw_sz,
        },
        convert_dtype(bias_arg->scalar_type()),
    };

    api::StorageBuffer staging(
        context, api::ScalarType::Float, v_bias.gpu_numel());
    {
      api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);

      float* dst_bias_ptr = mapping.template data<float>();

      memset(dst_bias_ptr, 0, v_bias.nbytes());

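      // Scatter each source element into one of the four planes: plane
      // 2 * (src_h % 2) + (src_w % 2), at offset (src_h / 2) * dst_kw_sz +
      // (src_w / 2), so every 4-channel texel gathers a 2x2 block of the
      // source matrix. Size-1 height/width dims are broadcast by iterating
      // twice while reading from offset 0.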
      for (const auto src_b : c10::irange(src_kb_sz)) {
        for (const auto src_h : c10::irange(src_kh_sz == 1 ? 2 : src_kh_sz)) {
          for (const auto src_w :
               c10::irange((use_batch && src_kw_sz == 1) ? 2 : src_kw_sz)) {
            int64_t dst_plane = 2 * (src_h % 2) + (src_w % 2);
            int64_t dst_index = (src_h / 2) * dst_kw_sz + (src_w / 2);
            memcpy(
                dst_bias_ptr + src_b * dst_matrix_sz +
                    dst_plane * dst_plane_sz + dst_index,
                src_bias_ptr + src_b * src_matrix_sz +
                    (src_kh_sz == 1 ? 0 : src_h * src_kw_sz) +
                    ((use_batch && src_kw_sz == 1) ? 0 : src_w),
                sizeof(float));
          }
        }
      }
    }
    utils::pack_staging_to_vtensor(staging.buffer(), v_bias);

    return v_bias;
  } else {
    vTensor v_bias{
        api::context(),
        {1},
        convert_dtype(weight_arg.scalar_type()),
    };

    api::StorageBuffer staging(
        context, api::ScalarType::Float, v_bias.gpu_numel());
    {
      api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);

      float* data_ptr = mapping.template data<float>();

      memset(
          data_ptr,
          // 2's complement integers and IEEE-754 floating point numbers both
          // have identical bit representations for 0, so we can use memset,
          // which only accepts a uint8_t parameter.
          0,
          v_bias.nbytes());
    }
    utils::pack_staging_to_vtensor(staging.buffer(), v_bias);

    return v_bias;
  }
}

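// Availability check for the batched (bmm/baddbmm) path: the weight must be a
// 3D float tensor with positive batch/height/width sizes, live on CPU or
// Vulkan, and not require grad; a defined bias must be float, on CPU or
// Vulkan, and broadcast-compatible with the weight's batch and width dims.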
bool available_check_with_batch(
    const Tensor& weight,
    const std::optional<Tensor>& bias) {
  const bool weight_available = (3 == weight.ndimension()) &&
      (weight.size(Layout::BatchMatrices::batch) > 0) &&
      (weight.size(Layout::BatchMatrices::height) > 0) &&
      (weight.size(Layout::BatchMatrices::width) > 0) &&
      ((weight.device().is_cpu()) ||
       (c10::DeviceType::Vulkan == weight.device().type())) &&
      (kFloat == weight.scalar_type()) && !weight.requires_grad();
  if (!weight_available) {
    return false;
  }

  if (!bias || !bias->defined()) {
    // no need to check bias since it is not used.
    return true;
  }

  bool bias_available = true;
  bias_available &= (bias->ndimension() > 0);
  bias_available &=
      ((bias->device().is_cpu()) ||
       (c10::DeviceType::Vulkan == bias->device().type()));
  bias_available &= (kFloat == bias->scalar_type());
  // Only check the consistency of the batch and width dimensions. The height
  // dimension is left unchecked because the 2nd input, which determines the
  // height, is not passed into LinearPackedContext.
  if (bias->ndimension() == 3) {
    bias_available &=
        (bias->size(Layout::BatchMatrices::width) ==
             weight.size(Layout::BatchMatrices::width) ||
         bias->size(Layout::BatchMatrices::width) == 1);
    bias_available &=
        (bias->size(Layout::BatchMatrices::batch) ==
             weight.size(Layout::BatchMatrices::batch) ||
         bias->size(Layout::BatchMatrices::batch) == 1);
  } else if (bias->ndimension() == 2) {
    // skip batch dim for broadcasting; index -1
    bias_available &=
        (bias->size(Layout::BatchMatrices::height) ==
             weight.size(Layout::BatchMatrices::width) ||
         bias->size(Layout::BatchMatrices::height) == 1);
  } else {
    // skip batch & height dims for broadcasting; index -2
    bias_available &=
        (bias->size(Layout::BatchMatrices::batch) ==
             weight.size(Layout::BatchMatrices::width) ||
         bias->size(Layout::BatchMatrices::batch) == 1);
  }
  bias_available &= !bias->requires_grad();
  return bias_available;
}

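// Checks whether the (weight, bias) pair can be prepacked for the Vulkan
// backend: the Vulkan API must be available, the weight must be a 2D float or
// qint8 tensor on CPU or Vulkan without grad (the batched case is delegated to
// available_check_with_batch), and a defined bias must be a float tensor whose
// width matches the weight's width when it has more than one dimension.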
bool available(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const bool use_batch = false) {
  if (!api::available()) {
    return false;
  }

  if (use_batch) {
    return available_check_with_batch(weight, bias);
  }

  const bool weight_available = (2 == weight.ndimension()) &&
      (weight.size(Layout::Parameter::height) > 0) &&
      (weight.size(Layout::Parameter::width) > 0) &&
      ((weight.device().is_cpu()) ||
       (c10::DeviceType::Vulkan == weight.device().type())) &&
      (kFloat == weight.scalar_type() || kQInt8 == weight.scalar_type()) &&
      !weight.requires_grad();
  if (!weight_available) {
    return false;
  }

  const bool bias_available =
      ((bias && bias.has_value() && bias->defined())
           ? ((bias->ndimension() > 0) &&
              ((bias->device().is_cpu()) ||
               (c10::DeviceType::Vulkan == bias->device().type())) &&
              (kFloat == bias->scalar_type()) &&
              ((bias->ndimension() > 1)
                   ? (bias->size(Layout::Parameter::width) ==
                      weight.size(Layout::Parameter::width))
                   : true) &&
              !bias->requires_grad())
           : true);
  return bias_available;
}

bool usable_check_with_batch(
    const Tensor& input,
    const IntArrayRef unpacked_weight_sizes) {
  return (3 == input.ndimension()) &&
      (c10::DeviceType::Vulkan == input.device().type()) &&
      (kFloat == input.scalar_type()) &&
      (input.size(Layout::BatchMatrices::width) ==
       unpacked_weight_sizes[Layout::BatchMatrices::height]) &&
      (input.size(Layout::BatchMatrices::batch) ==
       unpacked_weight_sizes[Layout::BatchMatrices::batch]) &&
      !input.requires_grad() && true;
}

bool usable(
    const Tensor& input,
    const IntArrayRef unpacked_weight_sizes,
    const bool use_batch = false) {
  if (use_batch) {
    return usable_check_with_batch(input, unpacked_weight_sizes);
  }
  const auto v_input = convert(input);
  return (2 == input.ndimension()) &&
      (c10::DeviceType::Vulkan == input.device().type()) &&
      ((kFloat == input.scalar_type()) ||
       (v_input.is_quantized() &&
        (kQUInt8 == input.scalar_type() || kQInt8 == input.scalar_type()))) &&
      (input.size(Layout::Parameter::width) ==
       unpacked_weight_sizes[Layout::Parameter::height]) &&
      !input.requires_grad() && true;
}

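// Flattens an N-D input (N >= 1) to 2D for the matmul shaders: a 1-D tensor
// becomes {1, W}; otherwise all leading dimensions are collapsed into one,
// e.g. {B, H, W} -> {B * H, W}. The original leading dimensions are restored
// after the matmul via reshape.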
static Tensor reshape_to_2d(const Tensor& input_arg) {
  TORCH_CHECK(
      input_arg.dim() >= 1,
      "Vulkan Linear op only supports input tensor with dim >= 1");

  if (input_arg.dim() == 1) {
    return input_arg.unsqueeze(0);
  }
  const IntArrayRef input_sizes = input_arg.sizes();
  const auto d =
      c10::multiply_integers(input_sizes.cbegin(), input_sizes.end() - 1);
  return input_arg.reshape({d, input_arg.size(-1)});
}

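// Quantized addmm/mm: reshapes the input to 2D if needed, validates it against
// the packed weight, then dispatches the quantized_addmm_{qint8,quint8} shader
// when a bias is defined or quantized_mm_{qint8,quint8} otherwise. The output
// vTensor is created with the requested output_scale / output_zero_point and
// reshaped back to the input's leading dimensions.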
Tensor run_quantized_addmm_context(
    const Tensor& input_arg,
    const float alpha,
    const float beta,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context,
    double output_scale,
    int64_t output_zero_point) {
  api::Context* const context = api::context();

  const Tensor input_arg_2d =
      input_arg.dim() == 2 ? input_arg : reshape_to_2d(input_arg);
  const Tensor input =
      input_arg_2d.is_vulkan() ? input_arg_2d : input_arg_2d.vulkan();
  const vTensor& v_input = convert(input);
  const vTensor& packed_v_weight = convert(
      linear_context->get_val(LinearPackedContext::Packed::Weight).toTensor());
  const vTensor& packed_v_bias = convert(
      linear_context->get_val(LinearPackedContext::Packed::Bias).toTensor());
  const std::vector<int64_t> unpacked_weight_sizes =
      linear_context->get_val(LinearPackedContext::Packed::WeightSizes)
          .toIntVector();
  const bool bias_defined =
      linear_context->get_val(LinearPackedContext::Packed::BiasDefined)
          .toBool();

  TORCH_CHECK(
      usable(input, unpacked_weight_sizes),
      "Vulkan Linear not usable! "
      "Reason: The provided input tensor is either invalid on its own, or its "
      "combination with the provided weight and bias tensors are unsupported by "
      "Vulkan impl.");

  TORCH_CHECK(
      (packed_v_weight.is_quantized() && v_input.is_quantized()),
      "run_quantized_addmm_context called for quantized version with unquantized input");

  vTensor v_output{
      context,
      {
          input_arg_2d.sizes()[Layout::Parameter::height],
          unpacked_weight_sizes[Layout::Parameter::width],
      },
      v_input.dtype(),
  };

  v_output.set_is_quantized();
  v_output.set_scale(output_scale);
  v_output.set_zero_point(output_zero_point);

  if (bias_defined) {
    api::UniformParamsBuffer params;
    api::ShaderInfo compute_shader;
    compute_shader = (kQInt8 == input_arg.scalar_type())
        ? VK_KERNEL(quantized_addmm_qint8)
        : VK_KERNEL(quantized_addmm_quint8);
    const struct {
      uvec3 size;
      int32_t K;
      uvec3 um1_size;
      int32_t K1;
      uvec3 um2_size;
      int32_t K2;
      uvec3 ut_size;
      int32_t K3;
      vec2 multiplier;
      vec2 input_scales;
      float out_scale;
      float _1;
      ivec2 input_zero_points;
      int32_t out_zero_point;
      int32_t _2;
    } block{
        v_output.extents(),
        safe_downcast<int32_t>(
            div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(2))),
        v_input.extents(),
        0u,
        packed_v_weight.extents(),
        0u,
        packed_v_bias.extents(),
        0u,
        {
            alpha,
            beta,
        },
        {
            safe_downcast<float>(v_input.get_scale()),
            safe_downcast<float>(packed_v_weight.get_scale()),
        },
        safe_downcast<float>(output_scale),
        0.0f,
        {
            safe_downcast<int32_t>(v_input.get_zero_point()),
            safe_downcast<int32_t>(packed_v_weight.get_zero_point()),
        },
        safe_downcast<int32_t>(output_zero_point),
        0u,
    };
    params = api::UniformParamsBuffer(context, block);

    api::PipelineBarrier pipeline_barrier{};
    context->submit_compute_job(
        // shader descriptor
        compute_shader,
        // pipeline barrier
        pipeline_barrier,
        // global work group size
        {
            safe_downcast<uint32_t>(
                div_up(v_output.sizes()[Layout::Parameter::width], INT64_C(2))),
            safe_downcast<uint32_t>(div_up(
                v_output.sizes()[Layout::Parameter::height], INT64_C(2))),
            1,
        },
        // local work group size
        {8, 8, 1},
        // fence handle
        VK_NULL_HANDLE,
        // shader arguments
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::COMPUTE,
            api::MemoryAccessType::WRITE),
        v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        packed_v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        // params buffer
        params.buffer());

  } else { // no bias
    api::UniformParamsBuffer params;
    api::ShaderInfo compute_shader;
    const struct {
      uvec3 size;
      int32_t K;
      uvec3 um1_size;
      int32_t K1;
      uvec3 um2_size;
      int32_t K2;
      vec2 input_scales;
      float out_scale;
      float _1;
      ivec2 input_zero_points;
      int32_t out_zero_point;
      int32_t _2;
    } block_no_bias{
        v_output.extents(),
        safe_downcast<int32_t>(
            div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(2))),
        v_input.extents(),
        0u,
        packed_v_weight.extents(),
        0u,
        {
            safe_downcast<float>(v_input.get_scale()),
            safe_downcast<float>(packed_v_weight.get_scale()),
        },
        safe_downcast<float>(output_scale),
        0.0f,
        {
            safe_downcast<int32_t>(v_input.get_zero_point()),
            safe_downcast<int32_t>(packed_v_weight.get_zero_point()),
        },
        safe_downcast<int32_t>(output_zero_point),
        0u,
    };
    params = api::UniformParamsBuffer(context, block_no_bias);
    compute_shader = (kQInt8 == input_arg.scalar_type())
        ? VK_KERNEL(quantized_mm_qint8)
        : VK_KERNEL(quantized_mm_quint8);

    api::PipelineBarrier pipeline_barrier{};

    context->submit_compute_job(
        // shader descriptor
        compute_shader,
        // pipeline barrier
        pipeline_barrier,
        // global work group size
        {
            safe_downcast<uint32_t>(
                div_up(v_output.sizes()[Layout::Parameter::width], INT64_C(2))),
            safe_downcast<uint32_t>(div_up(
                v_output.sizes()[Layout::Parameter::height], INT64_C(2))),
            1,
        },
        // local work group size
        {8, 8, 1},
        // fence handle
        VK_NULL_HANDLE,
        // shader arguments
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::COMPUTE,
            api::MemoryAccessType::WRITE),
        v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        // params buffer
        params.buffer());
  }
  Tensor output = convert(v_output);
  if (input_arg.dim() == 2) {
    return output;
  } else {
    std::vector<int64_t> shape;
    for (const auto i : c10::irange(input_arg.dim() - 1)) {
      shape.emplace_back(input_arg.size(i));
    }
    shape.emplace_back(output.size(-1));
    return output.reshape(shape);
  }
}

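// Non-quantized addmm/mm: reshapes the input to 2D if needed, width packs the
// input, checks that the prepacked weight is height packed, and runs the mm
// shader with a step size of ceil(input_width / 4). alpha and beta are applied
// afterwards on the tensor level: output * alpha + bias * beta.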
Tensor run_addmm_context(
    const Tensor& input_arg,
    const float alpha,
    const float beta,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context,
    bool quantized,
    double output_scale,
    int64_t output_zero_point) {
  if (quantized) {
    return run_quantized_addmm_context(
        input_arg,
        alpha,
        beta,
        linear_context,
        output_scale,
        output_zero_point);
  }

  api::Context* const context = api::context();

  const Tensor input_arg_2d =
      input_arg.dim() == 2 ? input_arg : reshape_to_2d(input_arg);
  const Tensor input =
      input_arg_2d.is_vulkan() ? input_arg_2d : input_arg_2d.vulkan();
  const vTensor& v_input = pack_inputs_using_width_packing(input);

  const vTensor& packed_v_weight = convert(
      linear_context->get_val(LinearPackedContext::Packed::Weight).toTensor());
  const vTensor& packed_v_bias = convert(
      linear_context->get_val(LinearPackedContext::Packed::Bias).toTensor());
  const std::vector<int64_t> unpacked_weight_sizes =
      linear_context->get_val(LinearPackedContext::Packed::WeightSizes)
          .toIntVector();

  TORCH_CHECK(
      usable(input, unpacked_weight_sizes),
      "Vulkan Linear not usable! "
      "Reason: The provided input tensor is either invalid on its own, or its "
      "combination with the provided weight and bias tensors are unsupported by "
      "Vulkan impl.");

  TORCH_CHECK(
      v_input.gpu_memory_layout() == api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      "run_addmm_context must have width packed input");

  TORCH_CHECK(
      packed_v_weight.gpu_memory_layout() ==
          api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED,
      "run_addmm_context must have height packed weight");

  vTensor v_output{
      context,
      {
          input_arg_2d.sizes()[Layout::Parameter::height],
          unpacked_weight_sizes[Layout::Parameter::width],
      },
      v_input.dtype(),
  };

  api::UniformParamsBuffer params;
  api::ShaderInfo compute_shader;
  // Step size is the 2d input's w dimension / 4, rounded up.
  int step_size = div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(4));

  const struct {
    uvec3 shader_extents;
    uint32_t mm_step_size;
  } block_no_bias{
      v_output.extents(),
      safe_downcast<uint32_t>(step_size),
  };
  params = api::UniformParamsBuffer(context, block_no_bias);
  compute_shader = VK_KERNEL(mm);

  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      {
          safe_downcast<uint32_t>(
              div_up(v_output.sizes()[Layout::Parameter::width], INT64_C(4))),
          safe_downcast<uint32_t>(
              div_up(v_output.sizes()[Layout::Parameter::height], INT64_C(4))),
          1,
      },
      // local work group size
      {8, 8, 1},
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  Tensor output = convert(v_output);

  // addmm operation: multiply the matmul result by alpha and add the bias
  // scaled by beta.
  output = output.mul(alpha).add(convert(packed_v_bias).mul(beta));

  if (input_arg.dim() == 2) {
    return output;
  } else {
    std::vector<int64_t> shape;
    for (const auto i : c10::irange(input_arg.dim() - 1)) {
      shape.emplace_back(input_arg.size(i));
    }
    shape.emplace_back(output.size(-1));
    return output.reshape(shape);
  }
}

Tensor run_baddbmm_context(
    const Tensor& input_arg,
    const float alpha,
    const float beta,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context) {
  // TODO: Refactor run_baddbmm_context and run_addmm_context into one.
  api::Context* const context = api::context();

  TORCH_CHECK(
      input_arg.dim() == 3,
      "Vulkan Linear not usable! "
      "Reason: The input has the wrong dimension; the tensor of a batch of matrices should contain 3 dimensions: batch, height, width.");

  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  vTensor packed_v_input = pack_inputs_using_width_packing(input);

  const vTensor& packed_v_weight = convert(
      linear_context->get_val(LinearPackedContext::Packed::Weight).toTensor());
  const vTensor& packed_v_bias = convert(
      linear_context->get_val(LinearPackedContext::Packed::Bias).toTensor());
  const std::vector<int64_t> unpacked_weight_sizes =
      linear_context->get_val(LinearPackedContext::Packed::WeightSizes)
          .toIntVector();

  TORCH_CHECK(
      usable(input, unpacked_weight_sizes, true /*use batch*/),
      "Vulkan Linear not usable! "
      "Reason: The provided input tensor is either invalid on its own, or its "
      "combination with the provided weight and bias tensors are unsupported by "
      "Vulkan impl.");

  TORCH_CHECK(
      packed_v_input.gpu_memory_layout() ==
          api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      "run_baddbmm_context must have width packed input");

  TORCH_CHECK(
      packed_v_weight.gpu_memory_layout() ==
          api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED,
      "run_baddbmm_context must have height packed weight");

  // In the shader, each batch is computed in a separate invocation.
  // The result is stored in the .x position of the texel.
  // As the tensor by default is channel packed, the shader is effectively
  // producing 3 all-zero layers. We work around this issue by creating
  // a vTensor that is 4 times the batch size.
  // At the end of the computation, we run a "slice" with a step size of 4
  // to get back the original shape.
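  // For example, with input_batch = 2 the output is allocated with batch size
  // 8, and the final slice(batch, 0, 8, /*step=*/4) keeps batch indices 0 and
  // 4, i.e. one result layer per original batch.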

  int64_t input_batch = packed_v_input.sizes()[Layout::BatchMatrices::batch];

  // Step size is the input's w dimension / 4, rounded up.
  int64_t input_width = packed_v_input.sizes()[Layout::BatchMatrices::width];
  int64_t mm_step_size = div_up(input_width, INT64_C(4));

  vTensor v_output{
      context,
      {
          input_batch * 4,
          packed_v_input.sizes()[Layout::BatchMatrices::height],
          unpacked_weight_sizes.back(), // "w" dimension in weight matrix
      },
      packed_v_input.dtype(),
  };

  const struct {
    uvec3 shader_extents;
    uint32_t mm_step_size;
  } block_no_bias{
      v_output.extents(),
      safe_downcast<uint32_t>(mm_step_size),
  };

  api::UniformParamsBuffer params(context, block_no_bias);

  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(mm),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      {
          safe_downcast<uint32_t>(div_up(
              v_output.sizes()[Layout::BatchMatrices::width], INT64_C(4))),
          safe_downcast<uint32_t>(div_up(
              v_output.sizes()[Layout::BatchMatrices::height], INT64_C(4))),
          safe_downcast<uint32_t>(
              v_output.sizes()[Layout::BatchMatrices::batch]),
      },
      // local work group size
      {8, 8, 1},
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      packed_v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  // After computing the multiplication, slice with a step of 4 along the
  // batch dimension to get back the channel-packed layout.
  auto mm_output_unpacked = convert(v_output);
  int step = 4;
  auto mm_output = mm_output_unpacked.slice(
      Layout::BatchMatrices::batch, 0, input_batch * step, step);

  return mm_output.mul(alpha).add(convert(packed_v_bias).mul(beta));
}

Tensor addmm(
    const Tensor& bias,
    const Tensor& input,
    const Tensor& weight,
    const Scalar& beta,
    const Scalar& alpha) {
  return run_addmm_context(
      input,
      alpha.to<float>(),
      beta.to<float>(),
      c10::make_intrusive<LinearPackedContext>(
          LinearPackedContext(weight, bias)),
      false,
      0,
      0);
}

Tensor mm(const Tensor& mat1_arg, const Tensor& mat2_arg) {
  return run_addmm_context(
      mat1_arg,
      1.0f,
      1.0f,
      c10::make_intrusive<LinearPackedContext>(
          LinearPackedContext(mat2_arg, std::optional<Tensor>())),
      false,
      0,
      0);
}

Tensor bmm(const Tensor& mat1_arg, const Tensor& mat2_arg) {
  return run_baddbmm_context(
      mat1_arg,
      1.0f,
      1.0f,
      c10::make_intrusive<LinearPackedContext>(LinearPackedContext(
          mat2_arg, std::optional<Tensor>(), true /*use batch*/)));
}

Tensor baddbmm(
    const Tensor& bias,
    const Tensor& input,
    const Tensor& weight,
    const Scalar& beta,
    const Scalar& alpha) {
  return run_baddbmm_context(
      input,
      alpha.to<float>(),
      beta.to<float>(),
      c10::make_intrusive<LinearPackedContext>(
          LinearPackedContext(weight, bias, true /*use batch*/)));
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::addmm"), TORCH_FN(addmm));
  m.impl(TORCH_SELECTIVE_NAME("aten::mm"), TORCH_FN(mm));
  m.impl(TORCH_SELECTIVE_NAME("aten::bmm"), TORCH_FN(bmm));
  m.impl(TORCH_SELECTIVE_NAME("aten::baddbmm"), TORCH_FN(baddbmm));
}

#endif /* USE_VULKAN_API */

} // namespace

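// LinearPackedContext: prepacks the weight (and optional bias) for the Vulkan
// linear ops. packed_ holds the packed weight, the packed bias, the original
// weight sizes, and a flag recording whether a bias was provided; the unpacked
// tensors are kept only when releaseWeightsWhenPrepacking() is false. As a
// rough usage sketch (mirroring mm() above): wrap the second matrix in a
// LinearPackedContext, then call run_addmm_context with alpha = beta = 1.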
LinearPackedContext::LinearPackedContext(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const bool use_batch)
    : unpacked_{c10::AnyType::get()} {
  TORCH_CHECK(
      available(weight, bias, use_batch),
      "Vulkan Linear not available! "
      "Reason: The provided (weight, bias) parameters are either invalid "
      "individually or their combination is not supported by Vulkan Impl.");

  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(convert(pack_weights(weight, use_batch)));
  const auto& packed_biases = weight.is_quantized()
      ? pack_biases_quantized_weights(weight, bias, use_batch)
      : pack_biases(weight, bias, use_batch);
  packed_.emplace_back(convert(packed_biases));
  packed_.emplace_back(weight.sizes());
  packed_.emplace_back(bias && bias->defined());

  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
    unpacked_.reserve(Unpacked::NumArgs);
    unpacked_.emplace_back(weight);
    unpacked_.emplace_back(bias);
  }
}

LinearPackedContext LinearPackedContext::pack(c10::impl::GenericList unpacked) {
  return LinearPackedContext(
      unpacked.get(Unpacked::Weight).toTensor(),
      get_optional_tensor(unpacked, Unpacked::Bias));
}

c10::intrusive_ptr<LinearPackedContext> create_linear_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias) {
  return c10::make_intrusive<LinearPackedContext>(
      LinearPackedContext(weight, bias));
}

Tensor run_linear_context(
    const Tensor& input,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context) {
  return run_addmm_context(input, 1.0f, 1.0f, linear_context, false, 0, 0);
}

Tensor run_qlinear_context(
    const Tensor& input_arg,
    double output_scale,
    int64_t output_zero_point,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context) {
  return run_addmm_context(
      input_arg,
      1.0f,
      1.0f,
      linear_context,
      true,
      output_scale,
      output_zero_point);
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at