#include <ATen/native/vulkan/ops/Mm.h>
#include <ATen/native/vulkan/ops/Utils.h>

#include <ATen/Context.h>
#include <ATen/Functions.h>
#include <ATen/native/vulkan/api/Tensor.h>
#include <ATen/native/vulkan/api/Types.h>
#include <ATen/native/vulkan/impl/Packing.h>
#include <c10/util/irange.h>

namespace at {
namespace native {
namespace vulkan {
namespace ops {
namespace {

using namespace api::utils;
using namespace at::native::vulkan::ops;
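// Moves the input to the Vulkan backend if necessary and repacks it into the
// TENSOR_WIDTH_PACKED GPU memory layout, so that each texel along the width
// axis holds 4 consecutive elements of the reduction (K) dimension.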
vTensor pack_inputs_using_width_packing(const Tensor& input_arg) {
  TORCH_INTERNAL_ASSERT(
      !input_arg.is_quantized(),
      "Vulkan Linear not usable! "
      "Reason: Input packing only supports non-quantized tensors.");
  TORCH_INTERNAL_ASSERT(
      input_arg.dim() == 2 || input_arg.dim() == 3,
      "Vulkan Linear not usable! "
      "Reason: Input packing only supports 2D or 3D tensors.");

  Tensor input = input_arg;
  if (input.is_cpu()) {
    input = input.vulkan();
  }

  TORCH_CHECK(input.is_vulkan(), "Input must be on Vulkan device!");

  vTensor v_input = convert(input);
  if (v_input.gpu_memory_layout() ==
      api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED) {
    v_input = packing::convert_image_channels_packed_to_width_packed(v_input);
  }

  TORCH_CHECK(
      v_input.gpu_memory_layout() == api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      "After packing, the v_input must be in TENSOR_WIDTH_PACKED format");

  return v_input;
}
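// Moves the weight to the Vulkan backend if necessary and repacks it into the
// TENSOR_HEIGHT_PACKED layout. Combined with a width-packed input, this packs
// the shared reduction (K) dimension of both operands 4 elements per texel
// for the mm shader.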
vTensor pack_weights_using_height_packing(const Tensor& weight_arg) {
  // Only non-batch, non-quantized tensors are supported
  TORCH_INTERNAL_ASSERT(
      !weight_arg.is_quantized(),
      "Vulkan Linear not usable! "
      "Reason: Weight packing only supports non-quantized tensors.");
  TORCH_INTERNAL_ASSERT(
      weight_arg.dim() == 2 || weight_arg.dim() == 3,
      "Vulkan Linear not usable! "
      "Reason: Weight packing only supports 2D or 3D tensors.");

  Tensor weight = weight_arg;

  if (weight.is_cpu()) {
    weight = weight.vulkan();
  }

  TORCH_CHECK(weight.is_vulkan(), "Weight must be on Vulkan device!");

  vTensor v_weight = convert(weight);
  if (v_weight.gpu_memory_layout() ==
      api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED) {
    v_weight =
        packing::convert_image_channels_packed_to_height_packed(v_weight);
  }

  TORCH_CHECK(
      v_weight.gpu_memory_layout() ==
          api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED,
      "After packing, the v_weight must be in TENSOR_HEIGHT_PACKED format");

  return v_weight;
}
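// Packs a weight matrix for the linear/matmul shaders. Floating-point weights
// are repacked on the GPU into the height-packed layout above. Quantized
// weights are instead staged on the CPU into a
// {batch, 4, ceil(H / 2), ceil(W / 2)} tensor, where each 2x2 block of the
// source matrix is spread across the 4 channels (the same scheme as the
// bias-packing loop further below). For example, a 5x3 qint8 weight is packed
// into a tensor of size {1, 4, 3, 2}.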
vTensor pack_weights(const Tensor& weight_arg, const bool use_batch = false) {
  if (!weight_arg.is_quantized()) {
    return pack_weights_using_height_packing(weight_arg);
  }

  TORCH_CHECK(
      weight_arg.is_quantized(), "Only quantized weights logic after here");

  // From here on, the weight is quantized (and possibly batched).

  api::Context* const context = api::context();

  const Tensor weight = weight_arg.contiguous();
  const IntArrayRef w_sizes = weight.sizes();
  if (use_batch) {
    TORCH_CHECK(
        w_sizes.size() == 3,
        "Vulkan Linear not usable! "
        "Reason: Unable to perform weight packing with batch; the input tensor of a batch of matrices should contain 3 dimensions: batch, height, width.");
  }
  /* Source */
  int64_t src_kb_sz = 0;
  int64_t src_kw_sz = 0;
  int64_t src_kh_sz = 0;
  /* Destination */
  int64_t dst_kb_sz = 0;
  int64_t dst_kw_sz = 0;
  int64_t dst_kh_sz = 0;
  std::vector<int64_t> dst_vtensor_sizes;
  /* Source */
  src_kb_sz = use_batch ? w_sizes[Layout::BatchMatrices::batch] : 1;
  src_kw_sz = use_batch ? w_sizes[Layout::BatchMatrices::width]
                        : w_sizes[Layout::Parameter::width];
  src_kh_sz = use_batch ? w_sizes[Layout::BatchMatrices::height]
                        : w_sizes[Layout::Parameter::height];

  /* Destination */
  dst_kb_sz = src_kb_sz;
  dst_kw_sz = div_up(src_kw_sz, INT64_C(2));
  dst_kh_sz = div_up(src_kh_sz, INT64_C(2));
  dst_vtensor_sizes = {
      dst_kb_sz,
      4,
      dst_kh_sz,
      dst_kw_sz,
  };

  vTensor v_weight{
      context, dst_vtensor_sizes, convert_dtype(weight_arg.scalar_type())};

  v_weight.set_is_quantized();
  v_weight.set_scale(weight_arg.q_scale());
  v_weight.set_zero_point(weight_arg.q_zero_point());

  stage_pack_weights<int8_t>(
      context,
      v_weight,
      weight,
      src_kb_sz,
      src_kh_sz,
      src_kw_sz,
      dst_kh_sz,
      dst_kw_sz);
  return v_weight;
}
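// Packs the bias for the non-quantized path. The bias is only moved to the
// Vulkan backend here; it is applied after the mm shader runs, as a regular
// elementwise mul/add (see run_addmm_context and run_baddbmm_context).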
vTensor pack_biases(
    const Tensor& weight_arg,
    const std::optional<Tensor>& bias_arg,
    const bool use_batch = false) {
  if (bias_arg) {
    Tensor bias = *bias_arg;
    if (bias.is_cpu()) {
      bias = bias.vulkan();
    }
    return convert(bias);
  } else {
    return convert(at::zeros({}, at::device(at::kVulkan).dtype(at::kFloat)));
  }
}

// Older version of pack_biases, kept because it fixes issues with
// quantization; to be removed in the future.
vTensor pack_biases_quantized_weights(
    const Tensor& weight_arg,
    const std::optional<Tensor>& bias_arg,
    const bool use_batch = false) {
  TORCH_CHECK(
      weight_arg.is_quantized(),
      "pack_biases_quantized to be used only when using quantized linear ops");

  if (bias_arg && bias_arg->is_vulkan()) {
    return convert(*bias_arg);
  }

  api::Context* const context = api::context();

  if (bias_arg) {
    const Tensor bias = bias_arg->contiguous();
    const IntArrayRef b_sizes = bias.sizes();
    const float* const src_bias_ptr = bias.const_data_ptr<float>();

    /* Source */
    int64_t src_kb_sz = 0;
    int64_t src_kw_sz = 0;
    int64_t src_kh_sz = 0;
    if (use_batch) {
      if (bias.sizes().size() == 3) {
        src_kb_sz = b_sizes[Layout::BatchMatrices::batch];
        src_kw_sz = b_sizes[Layout::BatchMatrices::width];
        src_kh_sz = b_sizes[Layout::BatchMatrices::height];
      } else if (bias.sizes().size() == 2) {
        // skip batch dim for broadcasting; index -1
        src_kb_sz = 1;
        src_kw_sz = b_sizes[Layout::BatchMatrices::height];
        src_kh_sz = b_sizes[Layout::BatchMatrices::batch];
      } else {
        // skip batch & height dim for broadcasting; index -2
        src_kb_sz = 1;
        src_kw_sz = b_sizes[Layout::BatchMatrices::batch];
        src_kh_sz = 1;
      }
    } else {
      src_kb_sz = 1;
      if (bias.sizes().size() == 2) {
        src_kw_sz = b_sizes[Layout::Parameter::width];
        src_kh_sz = b_sizes[Layout::Parameter::height];
      } else {
        src_kw_sz = b_sizes[Layout::Parameter::height];
        src_kh_sz = 1;
      }
    }
    const int64_t src_matrix_sz = src_kw_sz * src_kh_sz;

    /* Destination */
    const int64_t dst_kw_sz = div_up(src_kw_sz, INT64_C(2));
    const int64_t dst_kh_sz = div_up(src_kh_sz, INT64_C(2));
    const int64_t dst_plane_sz = dst_kw_sz * dst_kh_sz;
    const int64_t dst_matrix_sz = dst_plane_sz * 4;

    vTensor v_bias{
        context,
        {
            src_kb_sz,
            4,
            dst_kh_sz,
            dst_kw_sz,
        },
        convert_dtype(bias_arg->scalar_type()),
    };

    api::StorageBuffer staging(
        context, api::ScalarType::Float, v_bias.gpu_numel());
    {
      api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);

      float* dst_bias_ptr = mapping.template data<float>();

      memset(dst_bias_ptr, 0, v_bias.nbytes());
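      // Repack the bias into the same 2x2 -> 4-channel scheme as the packed
      // weight: element (src_h, src_w) lands in plane 2 * (src_h % 2) +
      // (src_w % 2) at offset (src_h / 2) * dst_kw_sz + (src_w / 2). Size-1
      // height/width dimensions are broadcast by writing the same value into
      // both halves of the 2-element block.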
      for (const auto src_b : c10::irange(src_kb_sz)) {
        for (const auto src_h : c10::irange(src_kh_sz == 1 ? 2 : src_kh_sz)) {
          for (const auto src_w :
               c10::irange((use_batch && src_kw_sz == 1) ? 2 : src_kw_sz)) {
            int64_t dst_plane = 2 * (src_h % 2) + (src_w % 2);
            int64_t dst_index = (src_h / 2) * dst_kw_sz + (src_w / 2);
            memcpy(
                dst_bias_ptr + src_b * dst_matrix_sz +
                    dst_plane * dst_plane_sz + dst_index,
                src_bias_ptr + src_b * src_matrix_sz +
                    (src_kh_sz == 1 ? 0 : src_h * src_kw_sz) +
                    ((use_batch && src_kw_sz == 1) ? 0 : src_w),
                sizeof(float));
          }
        }
      }
    }
    utils::pack_staging_to_vtensor(staging.buffer(), v_bias);

    return v_bias;
  } else {
    vTensor v_bias{
        api::context(),
        {1},
        convert_dtype(weight_arg.scalar_type()),
    };

    api::StorageBuffer staging(
        context, api::ScalarType::Float, v_bias.gpu_numel());
    {
      api::MemoryMap mapping(staging.buffer(), api::MemoryAccessType::WRITE);

      float* data_ptr = mapping.template data<float>();

      memset(
          data_ptr,
          // Two's complement integers and IEEE-754 floating point numbers
          // both have identical bit representations for 0, so memset (which
          // writes byte values) can be used to zero the buffer.
          0,
          v_bias.nbytes());
    }
    utils::pack_staging_to_vtensor(staging.buffer(), v_bias);

    return v_bias;
  }
}
bool available_check_with_batch(
    const Tensor& weight,
    const std::optional<Tensor>& bias) {
  const bool weight_available = (3 == weight.ndimension()) &&
      (weight.size(Layout::BatchMatrices::batch) > 0) &&
      (weight.size(Layout::BatchMatrices::height) > 0) &&
      (weight.size(Layout::BatchMatrices::width) > 0) &&
      ((weight.device().is_cpu()) ||
       (c10::DeviceType::Vulkan == weight.device().type())) &&
      (kFloat == weight.scalar_type()) && !weight.requires_grad();
  if (!weight_available) {
    return false;
  }

  if (!bias || !bias->defined()) {
    // no need to check bias since it is not used.
    return true;
  }

  bool bias_available = true;
  bias_available &= (bias->ndimension() > 0);
  bias_available &=
      ((bias->device().is_cpu()) ||
       (c10::DeviceType::Vulkan == bias->device().type()));
  bias_available &= (kFloat == bias->scalar_type());
  // Only check the consistency of the batch and width dimensions. The height
  // dimension is left unchecked because the 2nd input, which determines the
  // height, is not passed into LinearPackedContext.
  if (bias->ndimension() == 3) {
    bias_available &=
        (bias->size(Layout::BatchMatrices::width) ==
             weight.size(Layout::BatchMatrices::width) ||
         bias->size(Layout::BatchMatrices::width) == 1);
    bias_available &=
        (bias->size(Layout::BatchMatrices::batch) ==
             weight.size(Layout::BatchMatrices::batch) ||
         bias->size(Layout::BatchMatrices::batch) == 1);
  } else if (bias->ndimension() == 2) {
    // skip batch dim for broadcasting; index -1
    bias_available &=
        (bias->size(Layout::BatchMatrices::height) ==
             weight.size(Layout::BatchMatrices::width) ||
         bias->size(Layout::BatchMatrices::height) == 1);
  } else {
    // skip batch & height dim for broadcasting; index -2
    bias_available &=
        (bias->size(Layout::BatchMatrices::batch) ==
             weight.size(Layout::BatchMatrices::width) ||
         bias->size(Layout::BatchMatrices::batch) == 1);
  }
  bias_available &= !bias->requires_grad();
  return bias_available;
}
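// Validates the (weight, bias) pair before packing. Float weights are
// accepted in both the 2D and the batched (3D) path; quantized (qint8)
// weights are only accepted in the non-batched path.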
bool available(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const bool use_batch = false) {
  if (!api::available()) {
    return false;
  }

  if (use_batch) {
    return available_check_with_batch(weight, bias);
  }

  const bool weight_available = (2 == weight.ndimension()) &&
      (weight.size(Layout::Parameter::height) > 0) &&
      (weight.size(Layout::Parameter::width) > 0) &&
      ((weight.device().is_cpu()) ||
       (c10::DeviceType::Vulkan == weight.device().type())) &&
      (kFloat == weight.scalar_type() || kQInt8 == weight.scalar_type()) &&
      !weight.requires_grad();
  if (!weight_available) {
    return false;
  }

  const bool bias_available =
      ((bias && bias.has_value() && bias->defined())
           ? ((bias->ndimension() > 0) &&
              ((bias->device().is_cpu()) ||
               (c10::DeviceType::Vulkan == bias->device().type())) &&
              (kFloat == bias->scalar_type()) &&
              ((bias->ndimension() > 1)
                   ? (bias->size(Layout::Parameter::width) ==
                      weight.size(Layout::Parameter::width))
                   : true) &&
              !bias->requires_grad())
           : true);
  return bias_available;
}

bool usable_check_with_batch(
    const Tensor& input,
    const IntArrayRef unpacked_weight_sizes) {
  return (3 == input.ndimension()) &&
      (c10::DeviceType::Vulkan == input.device().type()) &&
      (kFloat == input.scalar_type()) &&
      (input.size(Layout::BatchMatrices::width) ==
       unpacked_weight_sizes[Layout::BatchMatrices::height]) &&
      (input.size(Layout::BatchMatrices::batch) ==
       unpacked_weight_sizes[Layout::BatchMatrices::batch]) &&
      !input.requires_grad() && true;
}

bool usable(
    const Tensor& input,
    const IntArrayRef unpacked_weight_sizes,
    const bool use_batch = false) {
  if (use_batch) {
    return usable_check_with_batch(input, unpacked_weight_sizes);
  }
  const auto v_input = convert(input);
  return (2 == input.ndimension()) &&
      (c10::DeviceType::Vulkan == input.device().type()) &&
      ((kFloat == input.scalar_type()) ||
       (v_input.is_quantized() &&
        (kQUInt8 == input.scalar_type() || kQInt8 == input.scalar_type()))) &&
      (input.size(Layout::Parameter::width) ==
       unpacked_weight_sizes[Layout::Parameter::height]) &&
      !input.requires_grad() && true;
}
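// Collapses all leading dimensions of an N-D input into one so that the 2D mm
// path can be used, e.g. a {B, S, K} input becomes {B * S, K}; a 1-D input of
// size {K} becomes {1, K}.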
static Tensor reshape_to_2d(const Tensor& input_arg) {
  TORCH_CHECK(
      input_arg.dim() >= 1,
      "Vulkan Linear op only supports input tensor with dim >= 1");

  if (input_arg.dim() == 1) {
    return input_arg.unsqueeze(0);
  }
  const IntArrayRef input_sizes = input_arg.sizes();
  const auto d =
      c10::multiply_integers(input_sizes.cbegin(), input_sizes.end() - 1);
  return input_arg.reshape({d, input_arg.size(-1)});
}
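// Quantized addmm/mm path. Reads the packed weight and bias from the
// LinearPackedContext, dispatches the quantized_addmm_* shader when a bias is
// present (or quantized_mm_* otherwise), and requantizes the result with the
// provided output_scale / output_zero_point. The global work group is sized
// at div_up(W, 2) x div_up(H, 2), matching the 2x2 packing of the quantized
// weight.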
Tensor run_quantized_addmm_context(
    const Tensor& input_arg,
    const float alpha,
    const float beta,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context,
    double output_scale,
    int64_t output_zero_point) {
  api::Context* const context = api::context();

  const Tensor input_arg_2d =
      input_arg.dim() == 2 ? input_arg : reshape_to_2d(input_arg);
  const Tensor input =
      input_arg_2d.is_vulkan() ? input_arg_2d : input_arg_2d.vulkan();
  const vTensor& v_input = convert(input);
  const vTensor& packed_v_weight = convert(
      linear_context->get_val(LinearPackedContext::Packed::Weight).toTensor());
  const vTensor& packed_v_bias = convert(
      linear_context->get_val(LinearPackedContext::Packed::Bias).toTensor());
  const std::vector<int64_t> unpacked_weight_sizes =
      linear_context->get_val(LinearPackedContext::Packed::WeightSizes)
          .toIntVector();
  const bool bias_defined =
      linear_context->get_val(LinearPackedContext::Packed::BiasDefined)
          .toBool();

  TORCH_CHECK(
      usable(input, unpacked_weight_sizes),
      "Vulkan Linear not usable! "
      "Reason: The provided input tensor is either invalid on its own, or its "
      "combination with the provided weight and bias tensors is unsupported by "
      "Vulkan impl.");

  TORCH_CHECK(
      (packed_v_weight.is_quantized() && v_input.is_quantized()),
      "run_quantized_addmm_context called for quantized version with unquantized input");

  vTensor v_output{
      context,
      {
          input_arg_2d.sizes()[Layout::Parameter::height],
          unpacked_weight_sizes[Layout::Parameter::width],
      },
      v_input.dtype(),
  };

  v_output.set_is_quantized();
  v_output.set_scale(output_scale);
  v_output.set_zero_point(output_zero_point);

  if (bias_defined) {
    api::UniformParamsBuffer params;
    api::ShaderInfo compute_shader;
    compute_shader = (kQInt8 == input_arg.scalar_type())
        ? VK_KERNEL(quantized_addmm_qint8)
        : VK_KERNEL(quantized_addmm_quint8);
    const struct {
      uvec3 size;
      int32_t K;
      uvec3 um1_size;
      int32_t K1;
      uvec3 um2_size;
      int32_t K2;
      uvec3 ut_size;
      int32_t K3;
      vec2 multiplier;
      vec2 input_scales;
      float out_scale;
      float _1;
      ivec2 input_zero_points;
      int32_t out_zero_point;
      int32_t _2;
    } block{
        v_output.extents(),
        safe_downcast<int32_t>(
            div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(2))),
        v_input.extents(),
        0u,
        packed_v_weight.extents(),
        0u,
        packed_v_bias.extents(),
        0u,
        {
            alpha,
            beta,
        },
        {
            safe_downcast<float>(v_input.get_scale()),
            safe_downcast<float>(packed_v_weight.get_scale()),
        },
        safe_downcast<float>(output_scale),
        0.0f,
        {
            safe_downcast<int32_t>(v_input.get_zero_point()),
            safe_downcast<int32_t>(packed_v_weight.get_zero_point()),
        },
        safe_downcast<int32_t>(output_zero_point),
        0u,
    };
    params = api::UniformParamsBuffer(context, block);

    api::PipelineBarrier pipeline_barrier{};
    context->submit_compute_job(
        // shader descriptor
        compute_shader,
        // pipeline barrier
        pipeline_barrier,
        // global work group size
        {
            safe_downcast<uint32_t>(
                div_up(v_output.sizes()[Layout::Parameter::width], INT64_C(2))),
            safe_downcast<uint32_t>(div_up(
                v_output.sizes()[Layout::Parameter::height], INT64_C(2))),
            1,
        },
        // local work group size
        {8, 8, 1},
        // fence handle
        VK_NULL_HANDLE,
        // shader arguments
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::COMPUTE,
            api::MemoryAccessType::WRITE),
        v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        packed_v_bias.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        // params buffer
        params.buffer());

  } else { // no bias
    api::UniformParamsBuffer params;
    api::ShaderInfo compute_shader;
    const struct {
      uvec3 size;
      int32_t K;
      uvec3 um1_size;
      int32_t K1;
      uvec3 um2_size;
      int32_t K2;
      vec2 input_scales;
      float out_scale;
      float _1;
      ivec2 input_zero_points;
      int32_t out_zero_point;
      int32_t _2;
    } block_no_bias{
        v_output.extents(),
        safe_downcast<int32_t>(
            div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(2))),
        v_input.extents(),
        0u,
        packed_v_weight.extents(),
        0u,
        {
            safe_downcast<float>(v_input.get_scale()),
            safe_downcast<float>(packed_v_weight.get_scale()),
        },
        safe_downcast<float>(output_scale),
        0.0f,
        {
            safe_downcast<int32_t>(v_input.get_zero_point()),
            safe_downcast<int32_t>(packed_v_weight.get_zero_point()),
        },
        safe_downcast<int32_t>(output_zero_point),
        0u,
    };
    params = api::UniformParamsBuffer(context, block_no_bias);
    compute_shader = (kQInt8 == input_arg.scalar_type())
        ? VK_KERNEL(quantized_mm_qint8)
        : VK_KERNEL(quantized_mm_quint8);

    api::PipelineBarrier pipeline_barrier{};

    context->submit_compute_job(
        // shader descriptor
        compute_shader,
        // pipeline barrier
        pipeline_barrier,
        // global work group size
        {
            safe_downcast<uint32_t>(
                div_up(v_output.sizes()[Layout::Parameter::width], INT64_C(2))),
            safe_downcast<uint32_t>(div_up(
                v_output.sizes()[Layout::Parameter::height], INT64_C(2))),
            1,
        },
        // local work group size
        {8, 8, 1},
        // fence handle
        VK_NULL_HANDLE,
        // shader arguments
        v_output.image(
            pipeline_barrier,
            api::PipelineStage::COMPUTE,
            api::MemoryAccessType::WRITE),
        v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
        // params buffer
        params.buffer());
  }
  Tensor output = convert(v_output);
  if (input_arg.dim() == 2) {
    return output;
  } else {
    std::vector<int64_t> shape;
    for (const auto i : c10::irange(input_arg.dim() - 1)) {
      shape.emplace_back(input_arg.size(i));
    }
    shape.emplace_back(output.size(-1));
    return output.reshape(shape);
  }
}
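// Non-quantized addmm/mm path. The input is repacked to width-packed and
// multiplied against the height-packed weight by the mm shader; alpha and the
// beta-scaled bias are then applied as ordinary elementwise ops on the result.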
Tensor run_addmm_context(
    const Tensor& input_arg,
    const float alpha,
    const float beta,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context,
    bool quantized,
    double output_scale,
    int64_t output_zero_point) {
  if (quantized) {
    return run_quantized_addmm_context(
        input_arg,
        alpha,
        beta,
        linear_context,
        output_scale,
        output_zero_point);
  }

  api::Context* const context = api::context();

  const Tensor input_arg_2d =
      input_arg.dim() == 2 ? input_arg : reshape_to_2d(input_arg);
  const Tensor input =
      input_arg_2d.is_vulkan() ? input_arg_2d : input_arg_2d.vulkan();
  const vTensor& v_input = pack_inputs_using_width_packing(input);

  const vTensor& packed_v_weight = convert(
      linear_context->get_val(LinearPackedContext::Packed::Weight).toTensor());
  const vTensor& packed_v_bias = convert(
      linear_context->get_val(LinearPackedContext::Packed::Bias).toTensor());
  const std::vector<int64_t> unpacked_weight_sizes =
      linear_context->get_val(LinearPackedContext::Packed::WeightSizes)
          .toIntVector();

  TORCH_CHECK(
      usable(input, unpacked_weight_sizes),
      "Vulkan Linear not usable! "
      "Reason: The provided input tensor is either invalid on its own, or its "
      "combination with the provided weight and bias tensors is unsupported by "
      "Vulkan impl.");

  TORCH_CHECK(
      v_input.gpu_memory_layout() == api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      "run_addmm_context must have width packed input");

  TORCH_CHECK(
      packed_v_weight.gpu_memory_layout() ==
          api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED,
      "run_addmm_context must have height packed weight");

  vTensor v_output{
      context,
      {
          input_arg_2d.sizes()[Layout::Parameter::height],
          unpacked_weight_sizes[Layout::Parameter::width],
      },
      v_input.dtype(),
  };

  api::UniformParamsBuffer params;
  api::ShaderInfo compute_shader;
  // Step size is the 2D input's width dimension divided by 4 (rounded up).
  int step_size = div_up(v_input.sizes()[Layout::Parameter::width], INT64_C(4));

  const struct {
    uvec3 shader_extents;
    uint32_t mm_step_size;
  } block_no_bias{
      v_output.extents(),
      safe_downcast<uint32_t>(step_size),
  };
  params = api::UniformParamsBuffer(context, block_no_bias);
  compute_shader = VK_KERNEL(mm);

  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      compute_shader,
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      {
          safe_downcast<uint32_t>(
              div_up(v_output.sizes()[Layout::Parameter::width], INT64_C(4))),
          safe_downcast<uint32_t>(
              div_up(v_output.sizes()[Layout::Parameter::height], INT64_C(4))),
          1,
      },
      // local work group size
      {8, 8, 1},
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  Tensor output = convert(v_output);

  // Finish the addmm: multiply by alpha and add the bias scaled by beta.
  output = output.mul(alpha).add(convert(packed_v_bias).mul(beta));

  if (input_arg.dim() == 2) {
    return output;
  } else {
    std::vector<int64_t> shape;
    for (const auto i : c10::irange(input_arg.dim() - 1)) {
      shape.emplace_back(input_arg.size(i));
    }
    shape.emplace_back(output.size(-1));
    return output.reshape(shape);
  }
}
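// Batched (baddbmm/bmm) variant of run_addmm_context; see the comment inside
// for the channel-packing workaround that allocates the output with 4x the
// batch size and slices it back down afterwards.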
Tensor run_baddbmm_context(
    const Tensor& input_arg,
    const float alpha,
    const float beta,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context) {
  // TODO: Refactor run_baddbmm_context and run_addmm_context into one.
  api::Context* const context = api::context();

  TORCH_CHECK(
      input_arg.dim() == 3,
      "Vulkan Linear not usable! "
      "Reason: The input has the wrong dimension; the tensor of a batch of matrices should contain 3 dimensions: batch, height, width.");

  const Tensor input = input_arg.is_vulkan() ? input_arg : input_arg.vulkan();
  vTensor packed_v_input = pack_inputs_using_width_packing(input);

  const vTensor& packed_v_weight = convert(
      linear_context->get_val(LinearPackedContext::Packed::Weight).toTensor());
  const vTensor& packed_v_bias = convert(
      linear_context->get_val(LinearPackedContext::Packed::Bias).toTensor());
  const std::vector<int64_t> unpacked_weight_sizes =
      linear_context->get_val(LinearPackedContext::Packed::WeightSizes)
          .toIntVector();

  TORCH_CHECK(
      usable(input, unpacked_weight_sizes, true /*use batch*/),
      "Vulkan Linear not usable! "
      "Reason: The provided input tensor is either invalid on its own, or its "
      "combination with the provided weight and bias tensors is unsupported by "
      "Vulkan impl.");

  TORCH_CHECK(
      packed_v_input.gpu_memory_layout() ==
          api::GPUMemoryLayout::TENSOR_WIDTH_PACKED,
      "run_baddbmm_context must have width packed input");

  TORCH_CHECK(
      packed_v_weight.gpu_memory_layout() ==
          api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED,
      "run_baddbmm_context must have height packed weight");

  // In the shader, each batch is computed in a separate invocation.
  // The result is stored in the .x position of the texel.
  // Since the tensor is channel packed by default, the shader would
  // effectively produce 3 all-zero layers. We work around this issue by
  // creating a vTensor that is 4 times the batch size.
  // At the end of the computation, we run a "slice" with a step size of 4
  // to get back the original shape.

  int64_t input_batch = packed_v_input.sizes()[Layout::BatchMatrices::batch];

  // Step size is the input's width dimension divided by 4 (rounded up).
  int64_t input_width = packed_v_input.sizes()[Layout::BatchMatrices::width];
  int64_t mm_step_size = div_up(input_width, INT64_C(4));

  vTensor v_output{
      context,
      {
          input_batch * 4,
          packed_v_input.sizes()[Layout::BatchMatrices::height],
          unpacked_weight_sizes.back(), // "w" dimension in weight matrix
      },
      packed_v_input.dtype(),
  };

  const struct {
    uvec3 shader_extents;
    uint32_t mm_step_size;
  } block_no_bias{
      v_output.extents(),
      safe_downcast<uint32_t>(mm_step_size),
  };

  api::UniformParamsBuffer params(context, block_no_bias);

  api::PipelineBarrier pipeline_barrier{};

  context->submit_compute_job(
      // shader descriptor
      VK_KERNEL(mm),
      // pipeline barrier
      pipeline_barrier,
      // global work group size
      {
          safe_downcast<uint32_t>(div_up(
              v_output.sizes()[Layout::BatchMatrices::width], INT64_C(4))),
          safe_downcast<uint32_t>(div_up(
              v_output.sizes()[Layout::BatchMatrices::height], INT64_C(4))),
          safe_downcast<uint32_t>(
              v_output.sizes()[Layout::BatchMatrices::batch]),
      },
      // local work group size
      {8, 8, 1},
      // fence handle
      VK_NULL_HANDLE,
      // shader arguments
      v_output.image(
          pipeline_barrier,
          api::PipelineStage::COMPUTE,
          api::MemoryAccessType::WRITE),
      packed_v_input.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      packed_v_weight.image(pipeline_barrier, api::PipelineStage::COMPUTE),
      // params buffer
      params.buffer());

  // After computing the multiplication, slice the batch dimension with a step
  // of 4 to recover the channel-packed layout.
  auto mm_output_unpacked = convert(v_output);
  int step = 4;
  auto mm_output = mm_output_unpacked.slice(
      Layout::BatchMatrices::batch, 0, input_batch * step, step);

  return mm_output.mul(alpha).add(convert(packed_v_bias).mul(beta));
}
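// ATen operator entry points registered with the Vulkan dispatch key below.
// Each builds a LinearPackedContext on the fly from the unprepacked weight
// (and bias, where the op takes one).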
Tensor addmm(
    const Tensor& bias,
    const Tensor& input,
    const Tensor& weight,
    const Scalar& beta,
    const Scalar& alpha) {
  return run_addmm_context(
      input,
      alpha.to<float>(),
      beta.to<float>(),
      c10::make_intrusive<LinearPackedContext>(
          LinearPackedContext(weight, bias)),
      false,
      0,
      0);
}

Tensor mm(const Tensor& mat1_arg, const Tensor& mat2_arg) {
  return run_addmm_context(
      mat1_arg,
      1.0f,
      1.0f,
      c10::make_intrusive<LinearPackedContext>(
          LinearPackedContext(mat2_arg, std::optional<Tensor>())),
      false,
      0,
      0);
}

Tensor bmm(const Tensor& mat1_arg, const Tensor& mat2_arg) {
  return run_baddbmm_context(
      mat1_arg,
      1.0f,
      1.0f,
      c10::make_intrusive<LinearPackedContext>(LinearPackedContext(
          mat2_arg, std::optional<Tensor>(), true /*use batch*/)));
}

Tensor baddbmm(
    const Tensor& bias,
    const Tensor& input,
    const Tensor& weight,
    const Scalar& beta,
    const Scalar& alpha) {
  return run_baddbmm_context(
      input,
      alpha.to<float>(),
      beta.to<float>(),
      c10::make_intrusive<LinearPackedContext>(
          LinearPackedContext(weight, bias, true /*use batch*/)));
}

#ifdef USE_VULKAN_API

TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::addmm"), TORCH_FN(addmm));
  m.impl(TORCH_SELECTIVE_NAME("aten::mm"), TORCH_FN(mm));
  m.impl(TORCH_SELECTIVE_NAME("aten::bmm"), TORCH_FN(bmm));
  m.impl(TORCH_SELECTIVE_NAME("aten::baddbmm"), TORCH_FN(baddbmm));
}

#endif /* USE_VULKAN_API */

} // namespace
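// Packs the weight (and bias, if any) into GPU-friendly layouts up front so
// that repeated linear/matmul calls can reuse them. The unpacked tensors are
// retained only when releaseWeightsWhenPrepacking() is disabled, which keeps
// the original tensors available for later unpacking.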
LinearPackedContext::LinearPackedContext(
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const bool use_batch)
    : unpacked_{c10::AnyType::get()} {
  TORCH_CHECK(
      available(weight, bias, use_batch),
      "Vulkan Linear not available! "
      "Reason: The provided (weight, bias) parameters are either invalid "
      "individually or their combination is not supported by Vulkan Impl.");

  packed_.reserve(Packed::NumArgs);
  packed_.emplace_back(convert(pack_weights(weight, use_batch)));
  const auto& packed_biases = weight.is_quantized()
      ? pack_biases_quantized_weights(weight, bias, use_batch)
      : pack_biases(weight, bias, use_batch);
  packed_.emplace_back(convert(packed_biases));
  packed_.emplace_back(weight.sizes());
  packed_.emplace_back(bias && bias->defined());

  if (!at::globalContext().releaseWeightsWhenPrepacking()) {
    unpacked_.reserve(Unpacked::NumArgs);
    unpacked_.emplace_back(weight);
    unpacked_.emplace_back(bias);
  }
}

LinearPackedContext LinearPackedContext::pack(c10::impl::GenericList unpacked) {
  return LinearPackedContext(
      unpacked.get(Unpacked::Weight).toTensor(),
      get_optional_tensor(unpacked, Unpacked::Bias));
}

c10::intrusive_ptr<LinearPackedContext> create_linear_context(
    Tensor&& weight,
    std::optional<Tensor>&& bias) {
  return c10::make_intrusive<LinearPackedContext>(
      LinearPackedContext(weight, bias));
}

Tensor run_linear_context(
    const Tensor& input,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context) {
  return run_addmm_context(input, 1.0f, 1.0f, linear_context, false, 0, 0);
}

Tensor run_qlinear_context(
    const Tensor& input_arg,
    double output_scale,
    int64_t output_zero_point,
    const c10::intrusive_ptr<LinearPackedContext>& linear_context) {
  return run_addmm_context(
      input_arg,
      1.0f,
      1.0f,
      linear_context,
      true,
      output_scale,
      output_zero_point);
}

} // namespace ops
} // namespace vulkan
} // namespace native
} // namespace at