1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/QPackUtils.h>
10
11 namespace vkcompute {
12
/*
 * Packs a row-major N x K matrix of 4-bit values (one value per input byte)
 * into N rows of K / 2 bytes, two values per output byte.
 *
 * For row n and pair index k2, the value at column 2 * k2 is stored in the low
 * nibble and the value at column 2 * k2 + 1 in the high nibble of
 * b_ptr[n * (K / 2) + k2].
 *
 * Only the low 4 bits of each input byte are used, so stray high bits in one
 * input value cannot corrupt the neighbouring value's nibble.
 *
 * Preconditions (not checked here):
 *   - w_ptr points to at least N * K readable bytes.
 *   - b_ptr points to at least N * (K / 2) writable bytes.
 *   - K is even; for odd K the last column of every row is not packed.
 */
void pack4(const uint8_t* w_ptr, uint8_t* b_ptr, uint32_t N, uint32_t K) {
  const size_t half_K = K / 2;
  for (uint32_t n = 0; n < N; ++n) {
    // size_t row offsets avoid 32-bit wrap-around of n * K on large inputs.
    const uint8_t* src_row = w_ptr + static_cast<size_t>(n) * K;
    uint8_t* dst_row = b_ptr + static_cast<size_t>(n) * half_K;
    for (size_t k2 = 0; k2 < half_K; ++k2) {
      const uint8_t lo = src_row[2 * k2] & 0x0F;
      const uint8_t hi = src_row[2 * k2 + 1] & 0x0F;
      dst_row[k2] = static_cast<uint8_t>((hi << 4) | lo);
    }
  }
}
22
int4mm_pack_weights(const std::vector<int64_t> & W_sizes,const uint8_t * w_ptr)23 std::vector<uint8_t> int4mm_pack_weights(
24 const std::vector<int64_t>& W_sizes,
25 const uint8_t* w_ptr) {
26 const int32_t N = utils::val_at(-1, W_sizes);
27 const int32_t K = utils::val_at(-2, W_sizes);
28
29 const auto numel = K * N;
30 std::vector<uint8_t> w_ptr_T(numel);
31 std::vector<uint8_t> b_ptr(utils::div_up(numel, 2));
32
33 // Transpose the weights
34 for (int32_t k = 0; k < K; k++) {
35 for (int32_t n = 0; n < N; n++) {
36 w_ptr_T[n * K + k] = w_ptr[k * N + n];
37 }
38 }
39
40 // Pack two int4s into each int8
41 pack4(w_ptr_T.data(), b_ptr.data(), N, K);
42
43 return b_ptr;
44 }
45
int4mm_dequantize_weights(const std::vector<int64_t> & W_sizes,const uint8_t * w_ptr,const uint32_t group_size,const float * scales_and_zeros)46 std::vector<float> int4mm_dequantize_weights(
47 const std::vector<int64_t>& W_sizes,
48 const uint8_t* w_ptr,
49 const uint32_t group_size,
50 const float* scales_and_zeros) {
51 const int64_t N = utils::val_at(-1, W_sizes);
52 const int64_t K = utils::val_at(-2, W_sizes);
53
54 std::vector<float> w_ptr_deq(K * N);
55 const int k_groups = K / group_size;
56 const int zeros_stride = k_groups * N;
57
58 for (int k = 0; k < K; k++) {
59 for (int n = 0; n < N; n++) {
60 const int kb = k / group_size;
61 const int scale_idx = k_groups * n + kb;
62 const float scale = scales_and_zeros[scale_idx];
63 const float zero =
64 scales_and_zeros[scale_idx + zeros_stride] - scale * 8.0;
65 w_ptr_deq[k * N + n] = w_ptr[k * N + n] * scale + zero;
66 }
67 }
68
69 return w_ptr_deq;
70 }
71
72 } // namespace vkcompute
73