#include <ATen/ATen.h>

#ifdef USE_FBGEMM
#include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
#endif
#ifdef USE_PYTORCH_QNNPACK
#include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>
#endif

namespace ao {
namespace sparse {

namespace {
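// Positions of each field within the BCSRSerializationType tuple, as accessed
// via std::get below.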
constexpr int64_t serialization_version_index = 0;
constexpr int64_t bias_index = 1;
constexpr int64_t out_features_block_size_index = 2;
constexpr int64_t in_features_block_size_index = 3;
constexpr int64_t weight_scales_index = 4;
constexpr int64_t weight_zero_point_index = 5;
constexpr int64_t quantization_scheme_index = 6;
constexpr int64_t row_block_indices_index = 7;
constexpr int64_t col_block_indices_index = 8;
constexpr int64_t weight_values_index = 9;
constexpr int64_t num_output_channels_index = 10;
constexpr int64_t num_input_channels_index = 11;

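// Copies the contents of a 1-D tensor with element type TENSOR_DTYPE into a
// std::vector<VEC_DTYPE>, converting element by element.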
template <typename TENSOR_DTYPE, typename VEC_DTYPE>
std::vector<VEC_DTYPE> unwrap_vector(at::Tensor tensor) {
  std::vector<VEC_DTYPE> vec(tensor.numel());
  TENSOR_DTYPE* tensor_data_ptr = tensor.data_ptr<TENSOR_DTYPE>();
  std::copy(tensor_data_ptr, tensor_data_ptr + tensor.numel(), vec.data());
  return vec;
}

#ifdef USE_FBGEMM
/**
 * Adapted from Fbgemm BCSRMatrix::unpack, but with non-zero zero points and
 * without tiling
 * https://github.com/pytorch/FBGEMM/blob/9d7c48a65419d0350f9e9e72f31e05bfe37e85a4/src/FbgemmSparseDense.cc#L154
 */
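// Expected BCSR layout: bcsr holds (weight_values, row_block_indices,
// col_block_indices). row_block_indices has one entry per row of blocks plus
// one; the blocks stored for block-row i are row_block_indices[i] ..
// row_block_indices[i + 1] - 1. col_block_indices[r] is the block-column of
// stored block r, whose RB x CB values sit contiguously (row-major) in
// weight_values starting at r * RB * CB.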
void unpack_bcsr(
    int8_t* dst,
    ao::sparse::BCSR bcsr,
    const int64_t R,
    const int64_t C,
    const int64_t RB,
    const int64_t CB,
    const int8_t* zero_points,
    const bool qscheme_per_tensor) {
  const size_t ld = C;
  // Zero out destination
  if (qscheme_per_tensor) {
    memset(dst, zero_points[0], R * C * sizeof(int8_t));
  } else {
    for (int64_t i = 0; i < R; i++) {
      memset(dst + i * C, zero_points[i], C * sizeof(int8_t));
    }
  }
  const std::vector<int8_t>& weight_values = std::get<0>(bcsr);
  const std::vector<int32_t>& row_indices = std::get<1>(bcsr);
  const std::vector<int32_t>& col_indices = std::get<2>(bcsr);
  int64_t rowBlocks = (R + RB - 1) / RB;
  for (int64_t i = 0; i < rowBlocks; ++i) {
    // Iterate over the blocks stored for block-row i
    for (int64_t r = row_indices[i]; r < row_indices[i + 1]; ++r) {
      int64_t curColIdx = col_indices[r];
      for (int64_t ib = 0; ib < RB; ++ib) {
        for (int64_t jb = 0; jb < CB; ++jb) {
          // Are we within bounds of destination matrix?
          if ((i * RB + ib) < R && (curColIdx * CB + jb) < C) {
            dst[(i * RB + ib) * ld + curColIdx * CB + jb] =
                weight_values[r * RB * CB + ib * CB + jb];
          }
        }
      }
    }
  }
}
#endif // USE_FBGEMM
} // namespace

#ifdef USE_FBGEMM

c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeight::deserialize(
    const BCSRSerializationType& serialized) {
  const int64_t out_features_block_size =
      std::get<out_features_block_size_index>(serialized);
  const int64_t in_features_block_size =
      std::get<in_features_block_size_index>(serialized);
  const c10::QScheme q_scheme = std::get<quantization_scheme_index>(serialized)
      ? c10::kPerTensorAffine
      : c10::kPerChannelAffine;
  const int64_t output_channels =
      std::get<num_output_channels_index>(serialized);
  const int64_t input_channels = std::get<num_input_channels_index>(serialized);
  // Unpack the untiled bcsr, then pack it in tiled form
  at::Tensor weight_origin;
  const at::Tensor weight_zero_points =
      std::get<weight_zero_point_index>(serialized);
  if (q_scheme == c10::kPerTensorAffine) {
    weight_origin = at::_empty_affine_quantized(
        {output_channels, input_channels},
        at::device(c10::kCPU).dtype(c10::kQInt8),
        std::get<weight_scales_index>(serialized).data_ptr<float>()[0],
        weight_zero_points.data_ptr<int8_t>()[0]);
  } else if (q_scheme == c10::kPerChannelAffine) {
    weight_origin = at::_empty_per_channel_affine_quantized(
        {output_channels, input_channels},
        std::get<weight_scales_index>(serialized),
        weight_zero_points,
        0, // The output channel axis is 0
        device(c10::kCPU).dtype(c10::kQInt8));
  }

  const at::Tensor loaded_weight_values =
      std::get<weight_values_index>(serialized);
  const uint8_t* loaded_weight_values_ptr =
      loaded_weight_values.data_ptr<uint8_t>();
  const int64_t loaded_weight_values_size = loaded_weight_values.numel();
  // Subtract 128 because we serialize as +128, which is best for
  // minimizing memory footprint for QNNPack
  std::vector<int8_t> weight_values(loaded_weight_values_size);
  std::transform(
      loaded_weight_values_ptr,
      loaded_weight_values_ptr + loaded_weight_values_size,
      weight_values.begin(),
      [](uint8_t v) {
        return static_cast<int8_t>(static_cast<int16_t>(v) - 128);
      });

  const at::Tensor row_block_indices =
      std::get<row_block_indices_index>(serialized);
  const at::Tensor col_block_indices =
      std::get<col_block_indices_index>(serialized);
  // Unpack as non-backend-specific untiled BCSR, then pack as Fbgemm tiled
  // BCSR, because untiled Fbgemm BCSR currently doesn't exist
  unpack_bcsr(
      reinterpret_cast<int8_t*>(weight_origin.data_ptr<c10::qint8>()),
      AT_DISPATCH_INTEGRAL_TYPES(
          row_block_indices.scalar_type(),
          "packed_linear_weight_fbgemm_setup_bcsr",
          [&] {
            return ao::sparse::BCSR(
                std::move(weight_values),
                unwrap_vector<scalar_t, int32_t>(
                    std::get<row_block_indices_index>(serialized)),
                unwrap_vector<scalar_t, int32_t>(
                    std::get<col_block_indices_index>(serialized)));
          }),
      output_channels,
      input_channels,
      out_features_block_size,
      in_features_block_size,
      weight_zero_points.data_ptr<int8_t>(),
      q_scheme == c10::kPerTensorAffine);

  return PackedLinearWeight::prepack(
      weight_origin,
      std::get<bias_index>(serialized),
      out_features_block_size,
      in_features_block_size);
}

#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK

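// For QNNPACK, deserialization delegates to the PackedLinearWeightQnnp
// constructor below, which repacks the serialized BCSR data into QNNPACK's
// own block-CSR format.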
c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightQnnp::deserialize(
    const BCSRSerializationType& serialized) {
  return c10::make_intrusive<PackedLinearWeightQnnp>(serialized);
}

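// Maps the signed integer types in which block indices are serialized to the
// unsigned types expected by qnnpack::generateBlockCSRMatrix. The
// unspecialized template triggers a compile-time error for any other dtype.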
template <typename INDICES_DTYPE>
struct UnsignedIndicesTypeTrait {
  static_assert(
      sizeof(INDICES_DTYPE) == 0,
      "Invalid dtype for UnsignedIndicesTypeTrait");
};

template <>
struct UnsignedIndicesTypeTrait<int32_t> {
  using t = uint32_t;
};

template <>
struct UnsignedIndicesTypeTrait<int16_t> {
  using t = uint16_t;
};

template <>
struct UnsignedIndicesTypeTrait<int8_t> {
  using t = uint8_t;
};

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PackedLinearWeightQnnp::PackedLinearWeightQnnp(
    const BCSRSerializationType& serialized)
    : LinearPackedParamsBase(
          std::get<out_features_block_size_index>(serialized),
          std::get<in_features_block_size_index>(serialized)),
      orig_bias_(std::get<bias_index>(serialized)),
      q_scheme_(
          std::get<quantization_scheme_index>(serialized)
              ? c10::kPerTensorAffine
              : c10::kPerChannelAffine),
      output_channels_(std::get<num_output_channels_index>(serialized)),
      input_channels_(std::get<num_input_channels_index>(serialized)) {
  const int64_t serialization_version =
      std::get<serialization_version_index>(serialized);
  TORCH_CHECK(
      serialization_version <= SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION,
      "Attempted to deserialize sparse qlinear packed params with an ",
      "incompatible serialization version (",
      serialization_version,
      " > ",
      SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION,
      ")");

  if (orig_bias_.has_value()) {
    bias_ = orig_bias_.value();

    TORCH_CHECK(
        (bias_.ndimension() == 1 && bias_.size(0) == output_channels_),
        "ao::sparse::qlinear_deserialize (qnnpack): Given weight of size ",
        "{",
        output_channels_,
        ", ",
        input_channels_,
        "}",
        ", expected bias to be 1-dimensional with ",
        output_channels_,
        " elements",
        ", but got bias of size ",
        bias_.sizes(),
        " instead");
  } else {
    bias_ = at::zeros(output_channels_, at::device(at::kCPU).dtype(at::kFloat));
  }

  // Pad amount (8) comes from make_zero_points_and_scales_tensor
  // https://github.com/pytorch/pytorch/blob/f8c1acea1e78573c04cd18893c4abff9eea64b03/aten/src/ATen/native/quantized/cpu/qnnpack_utils.h#L468
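  // The extra entries let vectorized kernels read past the last output
  // channel; they are filled with neutral values below (scale 1, zero point
  // 0) so they should not affect results.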
  const int64_t output_channels_padded = output_channels_ + 8;

  w_scales_ = at::empty(
      {output_channels_padded}, at::device(at::kCPU).dtype(at::kFloat));
  float* w_scales_data_ptr = w_scales_.mutable_data_ptr<float>();
  std::fill_n(
      w_scales_data_ptr + output_channels_,
      output_channels_padded - output_channels_,
      1); // Pad with 1

  w_zero_points_ =
      std::vector<uint8_t>(output_channels_padded, 0); // Pad with 0

  const float* w_scales_orig_data_ptr =
      std::get<weight_scales_index>(serialized).data_ptr<float>();
  const int8_t* w_zp_orig_data_ptr =
      std::get<weight_zero_point_index>(serialized).data_ptr<int8_t>();

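  // QNNPACK works with uint8 zero points in the same +128-shifted domain as
  // its uint8 weights, while zero points are serialized as int8; add_128
  // converts them back.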
  const std::function<uint8_t(int8_t)> add_128 = [](int8_t v) {
    return static_cast<uint8_t>(static_cast<int16_t>(v) + 128);
  };

  if (q_scheme_ == at::kPerTensorAffine) {
    std::fill_n(w_scales_data_ptr, output_channels_, w_scales_orig_data_ptr[0]);
    std::fill_n(
        w_zero_points_.begin(), output_channels_, w_zp_orig_data_ptr[0] + 128);
  } else if (q_scheme_ == at::kPerChannelAffine) {
    std::copy(
        w_scales_orig_data_ptr,
        w_scales_orig_data_ptr + output_channels_,
        w_scales_data_ptr);
    std::transform(
        w_zp_orig_data_ptr,
        w_zp_orig_data_ptr + output_channels_,
        w_zero_points_.begin(),
        add_128);
  } else {
    TORCH_CHECK(false, "Unsupported quantization scheme.");
  }

  deserialized_bcsr_row_block_indices_ =
      std::get<row_block_indices_index>(serialized);
  deserialized_bcsr_col_block_indices_ =
      std::get<col_block_indices_index>(serialized);
  deserialized_bcsr_weight_values_ = std::get<weight_values_index>(serialized);

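// The serialized block indices can arrive as 8-, 16-, or 32-bit signed
// integers, so dispatch over Char, Short, and Int to cover all three.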
#define AT_DISPATCH_CASE_BCSR_INDICES_TYPES(...)        \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)    \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)

#define AT_DISPATCH_BCSR_INDICES_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH(                                   \
      TYPE, NAME, AT_DISPATCH_CASE_BCSR_INDICES_TYPES(__VA_ARGS__))

  bcsr_matrix_ = AT_DISPATCH_BCSR_INDICES_TYPES(
      deserialized_bcsr_row_block_indices_.scalar_type(),
      "packed_linear_weight_qnnp_setup_bcsr",
      [&] {
        using unsigned_t = UnsignedIndicesTypeTrait<scalar_t>::t;
        return qnnpack::generateBlockCSRMatrix<unsigned_t>(
            reinterpret_cast<unsigned_t*>(
                deserialized_bcsr_col_block_indices_.data_ptr<scalar_t>()),
            reinterpret_cast<unsigned_t*>(
                deserialized_bcsr_row_block_indices_.data_ptr<scalar_t>()),
            deserialized_bcsr_weight_values_.data_ptr<uint8_t>(),
            deserialized_bcsr_col_block_indices_.numel(),
            deserialized_bcsr_row_block_indices_.numel(),
            deserialized_bcsr_weight_values_.numel(),
            out_features_block_size_,
            in_features_block_size_);
      });

#undef AT_DISPATCH_CASE_BCSR_INDICES_TYPES
#undef AT_DISPATCH_BCSR_INDICES_TYPES
}
#endif // USE_PYTORCH_QNNPACK

} // namespace sparse
} // namespace ao