xref: /aosp_15_r20/external/executorch/backends/vulkan/runtime/graph/ops/impl/utils/QPackUtils.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/QPackUtils.h>
10 
11 namespace vkcompute {
12 
// Packs pairs of 4-bit values into single bytes, row by row.
//
// w_ptr: N rows of K bytes each (row-major); each byte holds one 4-bit value.
//        Only the low nibble of each input byte is used.
// b_ptr: output, N rows of K / 2 bytes each; must hold N * (K / 2) bytes.
// N, K:  number of rows and elements per row of w_ptr.
//
// Within a row, element 2*i goes to the low nibble of output byte i and
// element 2*i + 1 goes to the high nibble. If K is odd, the last element of
// each row is not packed.
void pack4(const uint8_t* w_ptr, uint8_t* b_ptr, uint32_t N, uint32_t K) {
  const uint64_t row_bytes = K / 2;
  for (uint32_t n = 0; n < N; n++) {
    // 64-bit offsets: n * K can exceed 32 bits for large weight tensors.
    const uint8_t* src_row = w_ptr + static_cast<uint64_t>(n) * K;
    uint8_t* dst_row = b_ptr + static_cast<uint64_t>(n) * row_bytes;
    for (uint64_t k2 = 0; k2 < row_bytes; k2++) {
      // Mask so stray high bits in an input byte cannot leak into the
      // neighbouring nibble.
      const uint8_t lo = src_row[2 * k2] & 0x0F;
      const uint8_t hi = src_row[2 * k2 + 1] & 0x0F;
      dst_row[k2] = static_cast<uint8_t>((hi << 4) | lo);
    }
  }
}
22 
int4mm_pack_weights(const std::vector<int64_t> & W_sizes,const uint8_t * w_ptr)23 std::vector<uint8_t> int4mm_pack_weights(
24     const std::vector<int64_t>& W_sizes,
25     const uint8_t* w_ptr) {
26   const int32_t N = utils::val_at(-1, W_sizes);
27   const int32_t K = utils::val_at(-2, W_sizes);
28 
29   const auto numel = K * N;
30   std::vector<uint8_t> w_ptr_T(numel);
31   std::vector<uint8_t> b_ptr(utils::div_up(numel, 2));
32 
33   // Transpose the weights
34   for (int32_t k = 0; k < K; k++) {
35     for (int32_t n = 0; n < N; n++) {
36       w_ptr_T[n * K + k] = w_ptr[k * N + n];
37     }
38   }
39 
40   // Pack two int4s into each int8
41   pack4(w_ptr_T.data(), b_ptr.data(), N, K);
42 
43   return b_ptr;
44 }
45 
int4mm_dequantize_weights(const std::vector<int64_t> & W_sizes,const uint8_t * w_ptr,const uint32_t group_size,const float * scales_and_zeros)46 std::vector<float> int4mm_dequantize_weights(
47     const std::vector<int64_t>& W_sizes,
48     const uint8_t* w_ptr,
49     const uint32_t group_size,
50     const float* scales_and_zeros) {
51   const int64_t N = utils::val_at(-1, W_sizes);
52   const int64_t K = utils::val_at(-2, W_sizes);
53 
54   std::vector<float> w_ptr_deq(K * N);
55   const int k_groups = K / group_size;
56   const int zeros_stride = k_groups * N;
57 
58   for (int k = 0; k < K; k++) {
59     for (int n = 0; n < N; n++) {
60       const int kb = k / group_size;
61       const int scale_idx = k_groups * n + kb;
62       const float scale = scales_and_zeros[scale_idx];
63       const float zero =
64           scales_and_zeros[scale_idx + zeros_stride] - scale * 8.0;
65       w_ptr_deq[k * N + n] = w_ptr[k * N + n] * scale + zero;
66     }
67   }
68 
69   return w_ptr_deq;
70 }
71 
72 } // namespace vkcompute
73