xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/qnnpack/src/fc-prepack.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <pytorch_qnnpack.h>
2 #include <qnnpack/log.h>
3 #include <qnnpack/pack.h>
4 #include <qnnpack_func.h>
5 #include <cstdlib>
6 #include <cstring>
7 #include <cmath>
8 
9 namespace qnnpack {
10 // For runtime quantization packing.
PackBMatrix(const size_t input_channels,const size_t output_channels,const uint8_t * kernel_zero_points,const float * requantization_scales,const uint8_t * kernel,const int32_t * bias)11 PackBMatrix::PackBMatrix(
12     const size_t input_channels,
13     const size_t output_channels,
14     const uint8_t* kernel_zero_points,
15     const float* requantization_scales,
16     const uint8_t* kernel,
17     const int32_t* bias) {
18   for (size_t i = 0; i < output_channels; ++i) {
19     if (requantization_scales[i] <= 0.0f ||
20         !std::isnormal(requantization_scales[i])) {
21       pytorch_qnnp_log_error(
22           "failed to create fully connected operator with requant scale of "
23           "%.7g for output channel %d."
24           "Scale must be finite and positive",
25           requantization_scales[i], (int)i);
26       assert(false && "QNNPACK Runtime Error.");
27     }
28   }
29 
30   const uint32_t nr = pytorch_qnnp_params.q8conv.nr;
31   const uint32_t kr = pytorch_qnnp_params.q8conv.kr;
32 
33   const uint32_t n_stride = (output_channels + (nr - 1)) & -nr;
34   const uint32_t k_stride = (input_channels + (kr - 1)) & -kr;
35   input_channels_ = input_channels;
36   output_channels_ = output_channels;
37   packed_weights_ =
38       malloc(n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
39   if (packed_weights_ == nullptr) {
40     pytorch_qnnp_log_error(
41         "failed to allocate %zu bytes for packed weights",
42         n_stride * (k_stride * sizeof(uint8_t) + sizeof(int32_t)));
43     assert(false && "QNNPACK Runtime Error.");
44   }
45 
46   pytorch_pack_q8gemm_wrq(
47       output_channels,
48       input_channels,
49       nr,
50       nr,
51       kr,
52       kernel,
53       bias,
54       kernel_zero_points,
55       packed_weights_);
56 }
57 
58 } // namespace qnnpack
59