xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-unpack.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <pytorch_qnnpack.h>
2 #include <qnnpack/log.h>
3 #include <qnnpack/pack.h>
4 #include <qnnpack_func.h>
5 #include <cstdlib>
6 #include <cstring>
7 #include <cmath>
8 
9 namespace qnnpack {
10 // For runtime quantization unpacking.
unpackWeights(const uint8_t * kernel_zero_points,int8_t * kernel) const11 void PackBMatrix::unpackWeights(
12   const uint8_t* kernel_zero_points,
13   int8_t* kernel
14 ) const {
15   union {
16     void* const as_void_ptr;
17     uint8_t* as_uint8_ptr;
18     int32_t* as_int32_ptr;
19   } packed = {packed_weights_};
20 
21   // C = A * B
22   // A = M*K
23   // B = K*N
24   const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
25   const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
26 
27   // Convert prepacked weight to original weight / bias.
28   for (size_t nr_block_start = 0; nr_block_start < output_channels_; nr_block_start += nr) {
29     const size_t nr_block_size = min(output_channels_ - nr_block_start, nr);
30     for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
31          nr_block_offset++) {
32       packed.as_int32_ptr++;
33     }
34     packed.as_int32_ptr += (nr - nr_block_size);
35     for (size_t kr_block_start = 0; kr_block_start < input_channels_; kr_block_start += kr) {
36       const size_t kr_block_size = min(input_channels_ - kr_block_start, kr);
37       for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size;
38            nr_block_offset++) {
39         for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size;
40              kr_block_offset++) {
41           kernel[(nr_block_start + nr_block_offset) * input_channels_ +
42           (kr_block_start + kr_block_offset)] = *(packed.as_uint8_ptr++);
43         }
44         if (kernel_zero_points != 0) {
45           for (size_t kr_block_offset = 0; kr_block_offset < (kr - kr_block_size);
46                kr_block_offset++) {
47             packed.as_uint8_ptr++;
48           }
49         } else {
50           packed.as_uint8_ptr += (kr - kr_block_size);
51         }
52       }
53       if (kernel_zero_points != 0) {
54         size_t remaining_nr_blocks = ((nr - nr_block_size) & (nr - 1));
55         for (size_t nr_block_offset = 0; nr_block_offset < remaining_nr_blocks;
56              nr_block_offset++) {
57           for (size_t kr_block_offset = 0; kr_block_offset < kr;
58                kr_block_offset++) {
59             packed.as_uint8_ptr++;
60           }
61         }
62       } else {
63         packed.as_uint8_ptr += ((nr - nr_block_size) & (nr - 1)) * kr;
64       }
65     }
66   }
67 
68 }
69 
70 } // namespace qnnpack
71