// /aosp_15_r20/external/pytorch/aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_serialize.cpp
#include <ATen/ATen.h>

#include <algorithm>
#include <functional>
#include <limits>
#include <vector>

#ifdef USE_FBGEMM
#include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
#endif
#ifdef USE_PYTORCH_QNNPACK
#include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>

#include <utility>
#endif

namespace ao {
namespace sparse {

namespace {
/**
  - Wrap a vector in a Tensor, copying data into its own data pointer.
  - The type of vec is T& (not vector<T>&) so this works with any vector-like
    data structure which has .data() and .size().
 */
template <typename UNDERLYING_DTYPE, typename T>
at::Tensor wrap_vector(T& vec, c10::ScalarType dtype) {
  at::Tensor t = at::empty(
      {static_cast<long>(vec.size())}, at::device(c10::kCPU).dtype(dtype));
  std::copy(
      vec.data(), vec.data() + vec.size(), t.mutable_data_ptr<UNDERLYING_DTYPE>());
  return t;
}
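
// Illustrative use of wrap_vector (a sketch; the variable names below are
// hypothetical and this snippet is not part of the file's logic):
//
//   std::vector<float> scales = {0.1f, 0.2f, 0.4f};
//   at::Tensor t = wrap_vector<float>(scales, c10::kFloat);
//   // t is a size-3 CPU kFloat tensor holding its own copy of the data.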

#ifdef USE_FBGEMM
/**
 * Adapted from Fbgemm BCSRMatrix::pack, but with zero points, without tiling,
 * and without determining row_offsets
 * https://github.com/pytorch/FBGEMM/blob/9d7c48a65419d0350f9e9e72f31e05bfe37e85a4/src/FbgemmSparseDense.cc#L84
 */
ao::sparse::BCSR pack_bcsr(
    const int8_t* src,
    const int64_t R,
    const int64_t C,
    const int64_t RB,
    const int64_t CB,
    const int8_t* zero_points,
    const bool qscheme_per_tensor) {
  const size_t ld = C;
  std::vector<int32_t> rowBPtr;
  std::vector<int32_t> colBIdx;
  std::vector<int8_t> values;
  rowBPtr.push_back(0);
  int64_t nnzb = 0;
  int64_t rowBlocks = (R + RB - 1) / RB;
  for (int64_t i = 0; i < rowBlocks; ++i) {
    int64_t curCols = C;
    int64_t curColBlocks = (curCols + CB - 1) / CB;
    for (int64_t j = 0; j < curColBlocks; ++j) {
      // is the whole block zero?
      bool isCurrentBlockNonZero = false;
      for (int64_t ib = 0; ib < RB; ++ib) {
        // break if already found a non-zero element or
        // out of bounds
        if (isCurrentBlockNonZero || (i * RB + ib) >= R) {
          break;
        }
        const int64_t curr_row = i * RB + ib;
        const int8_t curr_row_zero_point =
            qscheme_per_tensor ? zero_points[0] : zero_points[curr_row];
        for (int64_t jb = 0; jb < CB; ++jb) {
          // within bound?
          if ((j * CB + jb) >= C) {
            continue;
          } else {
            if (src[curr_row * ld + j * CB + jb] != curr_row_zero_point) {
              isCurrentBlockNonZero = true;
              break;
            }
          }
        }
      }
      if (isCurrentBlockNonZero) {
        for (int64_t ib = 0; ib < RB; ++ib) {
          for (int64_t jb = 0; jb < CB; ++jb) {
            if ((i * RB + ib) >= R || (j * CB + jb) >= C) {
              // zero fill
              values.push_back(0);
            } else {
              int8_t val = src[(i * RB + ib) * ld + j * CB + jb];
              values.push_back(val);
            }
          }
        }
        colBIdx.push_back(static_cast<int32_t>(j));
        nnzb++;
      }
    }
    rowBPtr.push_back(static_cast<int32_t>(nnzb));
  }
  return ao::sparse::BCSR(
      std::move(values), std::move(rowBPtr), std::move(colBIdx));
}
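
/**
 * Worked example for pack_bcsr above (a sketch with made-up values, assuming a
 * single per-tensor zero point of 0): with R = 2, C = 4, RB = 1, CB = 2 and the
 * dense int8 matrix
 *
 *     1 0 0 0
 *     0 0 2 3
 *
 * only row block 0 / col block 0 ({1, 0}) and row block 1 / col block 1
 * ({2, 3}) contain a value different from the zero point, so the result is
 *
 *     rowBPtr = {0, 1, 2}    // cumulative count of stored blocks per row block
 *     colBIdx = {0, 1}       // column block index of each stored block
 *     values  = {1, 0, 2, 3} // stored blocks, flattened row-major
 */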
#endif // USE_FBGEMM
} // namespace

#ifdef USE_FBGEMM

BCSRSerializationType PackedLinearWeight::serialize() {
  // Get weights, row indices, and col indices in untiled form:
  // unpack the tiled BCSR, then repack it in untiled form
  std::vector<int8_t> dense_weight_values = std::vector<int8_t>(w->R * w->C);
  w->unpack(dense_weight_values.data());

  const bool qscheme_per_tensor = (q_scheme == c10::kPerTensorAffine);
  at::Tensor zero_points = wrap_vector<int8_t>(w_zp, c10::kChar);

  ao::sparse::BCSR untiled_bcsr = pack_bcsr(
      dense_weight_values.data(),
      w->R,
      w->C,
      w->RB,
      w->CB,
      zero_points.data_ptr<int8_t>(),
      qscheme_per_tensor);

  std::vector<int8_t>& packed_weight_values = std::get<0>(untiled_bcsr);
  // Add 128 to each weight value. This serialization format is best for
  // minimizing memory footprint for QNNPACK.

  at::Tensor weight_values = at::empty(
      {static_cast<long>(packed_weight_values.size())},
      at::device(c10::kCPU).dtype(c10::kByte));
  std::transform(
      packed_weight_values.begin(),
      packed_weight_values.end(),
      weight_values.mutable_data_ptr<uint8_t>(),
      [](int8_t v) {
        return static_cast<uint8_t>(static_cast<int16_t>(v) + 128);
      });
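  // Illustrative effect of the +128 shift above: int8_t -128 -> uint8_t 0,
  // 0 -> 128, 127 -> 255. The matching deserialization path (outside this
  // file) is assumed to subtract 128 to recover the signed values.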

  return BCSRSerializationType(
      SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION,
      bias_,
      out_features_block_size_,
      in_features_block_size_,
      wrap_vector<float>(w_scale, c10::kFloat),
      // Narrowing from int32_t to int8_t; this is okay because qint8 zero
      // points are restricted to fit in the bounds of int8_t
      std::move(zero_points),
      qscheme_per_tensor,
      wrap_vector<int>(
          std::get<1>(untiled_bcsr), c10::kInt), // Row block indices
      wrap_vector<int>(
          std::get<2>(untiled_bcsr), c10::kInt), // Col block indices
      std::move(weight_values),
      w->R,
      w->C);
}

#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK

BCSRSerializationType PackedLinearWeightQnnp::serialize() {
  at::Tensor w_scales_compact;
  at::Tensor w_zero_points_compact;
  const float* w_scales_data_ptr = w_scales_.const_data_ptr<float>();
  std::function<int8_t(uint8_t)> subtract_128 = [](uint8_t v) {
    return static_cast<int8_t>(static_cast<int16_t>(v) - 128);
  };

  if (q_scheme_ == at::kPerTensorAffine) {
    w_scales_compact = at::empty({1}, at::device(c10::kCPU).dtype(c10::kFloat));
    w_zero_points_compact =
        at::empty({1}, at::device(c10::kCPU).dtype(c10::kChar));

    w_scales_compact.mutable_data_ptr<float>()[0] = w_scales_data_ptr[0];
    w_zero_points_compact.mutable_data_ptr<int8_t>()[0] =
        static_cast<int8_t>(static_cast<int16_t>(w_zero_points_[0]) - 128);
  } else if (q_scheme_ == at::kPerChannelAffine) {
    w_scales_compact =
        at::empty({output_channels_}, at::device(c10::kCPU).dtype(c10::kFloat));
    w_zero_points_compact =
        at::empty({output_channels_}, at::device(c10::kCPU).dtype(c10::kChar));

    std::copy(
        w_scales_data_ptr,
        w_scales_data_ptr +
            output_channels_, // Don't go to the end because of padding
        w_scales_compact.mutable_data_ptr<float>());

    // Subtract 128 from each zero point, to reverse addition done during
    // prepacking
    std::transform(
        w_zero_points_.begin(),
        w_zero_points_.begin() +
            output_channels_, // Don't go to the end because of padding
        w_zero_points_compact.mutable_data_ptr<int8_t>(),
        std::move(subtract_128));
  } else {
    TORCH_CHECK(false, "Unsupported quantization scheme.");
  }
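  // Summary of the branches above: per-tensor quantization yields single-element
  // scale/zero-point tensors, per-channel yields output_channels_ elements with
  // QNNPACK's trailing padding dropped; zero points are shifted back to int8_t
  // in both cases.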

  at::Tensor wrapped_row_values;
  at::Tensor wrapped_col_indices;

  const uint32_t max_index = bcsr_matrix_->max_index();

  if (max_index <= std::numeric_limits<uint8_t>::max()) {
    // Cast from uint8_t range to int8_t
    wrapped_row_values = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int8_t>(typed_bcsr->row_values, c10::kChar); });
    wrapped_col_indices = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int8_t>(typed_bcsr->col_indices, c10::kChar); });
  } else if (max_index <= std::numeric_limits<uint16_t>::max()) {
    // Cast from uint16_t range to int16_t
    wrapped_row_values = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int16_t>(typed_bcsr->row_values, c10::kShort); });
    wrapped_col_indices = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int16_t>(typed_bcsr->col_indices, c10::kShort); });
  } else {
    // Cast from uint32_t range to int32_t
    wrapped_row_values = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int>(typed_bcsr->row_values, c10::kInt); });
    wrapped_col_indices = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int>(typed_bcsr->col_indices, c10::kInt); });
  }
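  // Example of the index-width selection above (illustrative numbers): if
  // max_index is 300, it exceeds the uint8_t maximum (255) but fits in
  // uint16_t, so both index tensors are serialized as kShort; values up to 255
  // use kChar, and anything above 65535 uses kInt.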

  return BCSRSerializationType(
      SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION,
      orig_bias_,
      out_features_block_size_,
      in_features_block_size_,
      std::move(w_scales_compact),
      std::move(w_zero_points_compact),
      (q_scheme_ == c10::kPerTensorAffine),
      wrapped_row_values,
      wrapped_col_indices,
      wrap_vector<uint8_t>(bcsr_matrix_->values, c10::kByte),
      output_channels_,
      input_channels_);
}

#endif // USE_PYTORCH_QNNPACK

} // namespace sparse
} // namespace ao