// aten/src/ATen/native/ao_sparse/quantized/cpu/qlinear_deserialize.cpp
#include <ATen/ATen.h>

#ifdef USE_FBGEMM
#include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
#endif
#ifdef USE_PYTORCH_QNNPACK
#include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>
#endif

namespace ao {
namespace sparse {

namespace {
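// Positions of each field within the BCSRSerializationType tuple; these must
// match the field order used on the serialization side.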
constexpr int64_t serialization_version_index = 0;
constexpr int64_t bias_index = 1;
constexpr int64_t out_features_block_size_index = 2;
constexpr int64_t in_features_block_size_index = 3;
constexpr int64_t weight_scales_index = 4;
constexpr int64_t weight_zero_point_index = 5;
constexpr int64_t quantization_scheme_index = 6;
constexpr int64_t row_block_indices_index = 7;
constexpr int64_t col_block_indices_index = 8;
constexpr int64_t weight_values_index = 9;
constexpr int64_t num_output_channels_index = 10;
constexpr int64_t num_input_channels_index = 11;

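// Copies the contents of a 1-D tensor (read as TENSOR_DTYPE) into a
// std::vector<VEC_DTYPE>, converting element by element.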
template <typename TENSOR_DTYPE, typename VEC_DTYPE>
std::vector<VEC_DTYPE> unwrap_vector(at::Tensor tensor) {
  std::vector<VEC_DTYPE> vec(tensor.numel());
  TENSOR_DTYPE* tensor_data_ptr = tensor.data_ptr<TENSOR_DTYPE>();
  std::copy(tensor_data_ptr, tensor_data_ptr + tensor.numel(), vec.data());
  return vec;
}

#ifdef USE_FBGEMM
/**
 * Adapted from Fbgemm BCSRMatrix::unpack, but with non-zero zero points and
 * without tiling
 * https://github.com/pytorch/FBGEMM/blob/9d7c48a65419d0350f9e9e72f31e05bfe37e85a4/src/FbgemmSparseDense.cc#L154
 */
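// dst is the dense R x C row-major output; bcsr carries (values, row block
// pointers, column block indices) for RB x CB blocks. Cells not covered by a
// stored block keep the zero-point fill written below.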
void unpack_bcsr(
    int8_t* dst,
    ao::sparse::BCSR bcsr,
    const int64_t R,
    const int64_t C,
    const int64_t RB,
    const int64_t CB,
    const int8_t* zero_points,
    const bool qscheme_per_tensor) {
  const size_t ld = C;
  // zero out destination
  if (qscheme_per_tensor) {
    memset(dst, zero_points[0], R * C * sizeof(int8_t));
  } else {
    for (int64_t i = 0; i < R; i++) {
      memset(dst + i * C, zero_points[i], C * sizeof(int8_t));
    }
  }
  const std::vector<int8_t>& weight_values = std::get<0>(bcsr);
  const std::vector<int32_t>& row_indices = std::get<1>(bcsr);
  const std::vector<int32_t>& col_indices = std::get<2>(bcsr);
  int64_t rowBlocks = (R + RB - 1) / RB;
  for (int64_t i = 0; i < rowBlocks; ++i) {
    // For the current tile, rowBPtr starts from currentTileIdx
    for (int64_t r = row_indices[i]; r < row_indices[i + 1]; ++r) {
      int64_t curColIdx = col_indices[r];
      for (int64_t ib = 0; ib < RB; ++ib) {
        for (int64_t jb = 0; jb < CB; ++jb) {
          // Are we within bounds of destination matrix?
          if ((i * RB + ib) < R && (curColIdx * CB + jb) < C) {
            dst[(i * RB + ib) * ld + curColIdx * CB + jb] =
                weight_values[r * RB * CB + ib * CB + jb];
          }
        }
      }
    }
  }
}
#endif // USE_FBGEMM
} // namespace

#ifdef USE_FBGEMM

c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeight::deserialize(
    const BCSRSerializationType& serialized) {
  const int64_t out_features_block_size =
      std::get<out_features_block_size_index>(serialized);
  const int64_t in_features_block_size =
      std::get<in_features_block_size_index>(serialized);
  const c10::QScheme q_scheme = std::get<quantization_scheme_index>(serialized)
      ? c10::kPerTensorAffine
      : c10::kPerChannelAffine;
  const int64_t output_channels =
      std::get<num_output_channels_index>(serialized);
  const int64_t input_channels = std::get<num_input_channels_index>(serialized);
  // Unpack the untiled bcsr, then pack it in tiled form
  at::Tensor weight_origin;
  const at::Tensor weight_zero_points =
      std::get<weight_zero_point_index>(serialized);
  if (q_scheme == c10::kPerTensorAffine) {
    weight_origin = at::_empty_affine_quantized(
        {output_channels, input_channels},
        at::device(c10::kCPU).dtype(c10::kQInt8),
        std::get<weight_scales_index>(serialized).data_ptr<float>()[0],
        weight_zero_points.data_ptr<int8_t>()[0]);
  } else if (q_scheme == c10::kPerChannelAffine) {
    weight_origin = at::_empty_per_channel_affine_quantized(
        {output_channels, input_channels},
        std::get<weight_scales_index>(serialized),
        weight_zero_points,
        0, // The output channel axis is 0
        at::device(c10::kCPU).dtype(c10::kQInt8));
  }

  const at::Tensor loaded_weight_values =
      std::get<weight_values_index>(serialized);
  const uint8_t* loaded_weight_values_ptr =
      loaded_weight_values.data_ptr<uint8_t>();
  const int64_t loaded_weight_values_size = loaded_weight_values.numel();
  // Subtract 128 because we serialize with a +128 offset, which is best for
  // minimizing memory footprint for QNNPack
  std::vector<int8_t> weight_values(loaded_weight_values_size);
  std::transform(
      loaded_weight_values_ptr,
      loaded_weight_values_ptr + loaded_weight_values_size,
      weight_values.begin(),
      [](uint8_t v) {
        return static_cast<int8_t>(static_cast<int16_t>(v) - 128);
      });

  const at::Tensor row_block_indices =
      std::get<row_block_indices_index>(serialized);
  const at::Tensor col_block_indices =
      std::get<col_block_indices_index>(serialized);
  // Unpack as non-backend-specific untiled BCSR, then pack as Fbgemm tiled
  // BCSR, because untiled Fbgemm BCSR currently doesn't exist
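  // The dispatch below converts whichever integral index dtype was serialized
  // into the int32_t vectors stored in ao::sparse::BCSR.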
  unpack_bcsr(
      reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>()),
      AT_DISPATCH_INTEGRAL_TYPES(
          row_block_indices.scalar_type(),
          "packed_linear_weight_fbgemm_setup_bcsr",
          [&] {
            return ao::sparse::BCSR(
                std::move(weight_values),
                unwrap_vector<scalar_t, int32_t>(
                    std::get<row_block_indices_index>(serialized)),
                unwrap_vector<scalar_t, int32_t>(
                    std::get<col_block_indices_index>(serialized)));
          }),
      output_channels,
      input_channels,
      out_features_block_size,
      in_features_block_size,
      weight_zero_points.data_ptr<int8_t>(),
      q_scheme == c10::kPerTensorAffine);

  return PackedLinearWeight::prepack(
      weight_origin,
      std::get<bias_index>(serialized),
      out_features_block_size,
      in_features_block_size);
}

#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK

c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightQnnp::deserialize(
    const BCSRSerializationType& serialized) {
  return c10::make_intrusive<PackedLinearWeightQnnp>(serialized);
}

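// Maps each signed index dtype that may appear in the serialized tensors to
// the unsigned equivalent expected by qnnpack::generateBlockCSRMatrix; the
// primary template rejects any other type at compile time.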
template <typename INDICES_DTYPE>
struct UnsignedIndicesTypeTrait {
  static_assert(
      sizeof(INDICES_DTYPE) == 0,
      "Invalid dtype for UnsignedIndicesTypeTrait");
};

template <>
struct UnsignedIndicesTypeTrait<int32_t> {
  using t = uint32_t;
};

template <>
struct UnsignedIndicesTypeTrait<int16_t> {
  using t = uint16_t;
};

template <>
struct UnsignedIndicesTypeTrait<int8_t> {
  using t = uint8_t;
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PackedLinearWeightQnnp::PackedLinearWeightQnnp(
    const BCSRSerializationType& serialized)
    : LinearPackedParamsBase(
          std::get<out_features_block_size_index>(serialized),
          std::get<in_features_block_size_index>(serialized)),
      orig_bias_(std::get<bias_index>(serialized)),
      q_scheme_(
          std::get<quantization_scheme_index>(serialized)
              ? c10::kPerTensorAffine
              : c10::kPerChannelAffine),
      output_channels_(std::get<num_output_channels_index>(serialized)),
      input_channels_(std::get<num_input_channels_index>(serialized)) {
  const int64_t serialization_version =
      std::get<serialization_version_index>(serialized);
  TORCH_CHECK(
      serialization_version <= SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION,
      "Attempted to deserialize sparse qlinear packed params with an ",
      "incompatible serialization version (",
      serialization_version,
      " > ",
      SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION,
      ")");

  if (orig_bias_.has_value()) {
    bias_ = orig_bias_.value();

    TORCH_CHECK(
        (bias_.ndimension() == 1 && bias_.size(0) == output_channels_),
        "ao::sparse::qlinear_deserialize (qnnpack): Given weight of size ",
        "{",
        output_channels_,
        ", ",
        input_channels_,
        "}",
        ", expected bias to be 1-dimensional with ",
        output_channels_,
        " elements",
        ", but got bias of size ",
        bias_.sizes(),
        " instead");
  } else {
    bias_ = at::zeros(output_channels_, at::device(at::kCPU).dtype(at::kFloat));
  }

  // Pad amount (8) comes from make_zero_points_and_scales_tensor
  // https://github.com/pytorch/pytorch/blob/f8c1acea1e78573c04cd18893c4abff9eea64b03/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h#L468
  const int64_t output_channels_padded = output_channels_ + 8;

  w_scales_ = at::empty(
      {output_channels_padded}, at::device(at::kCPU).dtype(at::kFloat));
  float* w_scales_data_ptr = w_scales_.mutable_data_ptr<float>();
  std::fill_n(
      w_scales_data_ptr + output_channels_,
      output_channels_padded - output_channels_,
      1); // Pad with 1

  w_zero_points_ =
      std::vector<uint8_t>(output_channels_padded, 0); // Pad with 0

  const float* w_scales_orig_data_ptr =
      std::get<weight_scales_index>(serialized).data_ptr<float>();
  const int8_t* w_zp_orig_data_ptr =
      std::get<weight_zero_point_index>(serialized).data_ptr<int8_t>();

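  // Zero points are serialized as int8, but w_zero_points_ stores uint8, so
  // shift each value up by 128 to re-center it into the unsigned range.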
  const std::function<uint8_t(int8_t)> add_128 = [](int8_t v) {
    return static_cast<uint8_t>(static_cast<int16_t>(v) + 128);
  };

  if (q_scheme_ == at::kPerTensorAffine) {
    std::fill_n(w_scales_data_ptr, output_channels_, w_scales_orig_data_ptr[0]);
    std::fill_n(
        w_zero_points_.begin(), output_channels_, w_zp_orig_data_ptr[0] + 128);
  } else if (q_scheme_ == at::kPerChannelAffine) {
    std::copy(
        w_scales_orig_data_ptr,
        w_scales_orig_data_ptr + output_channels_,
        w_scales_data_ptr);
    std::transform(
        w_zp_orig_data_ptr,
        w_zp_orig_data_ptr + output_channels_,
        w_zero_points_.begin(),
        add_128);
  } else {
    TORCH_CHECK(false, "Unsupported quantization scheme.");
  }

  deserialized_bcsr_row_block_indices_ =
      std::get<row_block_indices_index>(serialized);
  deserialized_bcsr_col_block_indices_ =
      std::get<col_block_indices_index>(serialized);
  deserialized_bcsr_weight_values_ = std::get<weight_values_index>(serialized);

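  // Local dispatch macros covering the three index dtypes that can appear in
  // the serialized BCSR tensors: int8 (Char), int16 (Short), and int32 (Int).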
#define AT_DISPATCH_CASE_BCSR_INDICES_TYPES(...)      \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)  \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)

#define AT_DISPATCH_BCSR_INDICES_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                   \
      TYPE, NAME, AT_DISPATCH_CASE_BCSR_INDICES_TYPES(__VA_ARGS__))

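  // The serialized indices are expected to be non-negative, so reinterpreting
  // the signed storage as the same-width unsigned type preserves the values.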
  bcsr_matrix_ = AT_DISPATCH_BCSR_INDICES_TYPES(
      deserialized_bcsr_row_block_indices_.scalar_type(),
      "packed_linear_weight_qnnp_setup_bcsr",
      [&] {
        using unsigned_t = UnsignedIndicesTypeTrait<scalar_t>::t;
        return qnnpack::generateBlockCSRMatrix<unsigned_t>(
            reinterpret_cast<unsigned_t*>(
                deserialized_bcsr_col_block_indices_.data_ptr<scalar_t>()),
            reinterpret_cast<unsigned_t*>(
                deserialized_bcsr_row_block_indices_.data_ptr<scalar_t>()),
            deserialized_bcsr_weight_values_.data_ptr<uint8_t>(),
            deserialized_bcsr_col_block_indices_.numel(),
            deserialized_bcsr_row_block_indices_.numel(),
            deserialized_bcsr_weight_values_.numel(),
            out_features_block_size_,
            in_features_block_size_);
      });

#undef AT_DISPATCH_CASE_BCSR_INDICES_TYPES
#undef AT_DISPATCH_BCSR_INDICES_TYPES
}
#endif // USE_PYTORCH_QNNPACK

} // namespace sparse
} // namespace ao