#include <ATen/ATen.h>

#ifdef USE_FBGEMM
#include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
#endif
#ifdef USE_PYTORCH_QNNPACK
#include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>

#include <utility>
#endif

namespace ao {
namespace sparse {

namespace {
/**
  - Wrap a vector in a Tensor, copying data into its own data pointer.
  - The type of vec is T& (not vector<T>&) so this works with any vector-like
    data structure that has .data() and .size()
 */
template <typename UNDERLYING_DTYPE, typename T>
at::Tensor wrap_vector(T& vec, c10::ScalarType dtype) {
  at::Tensor t = at::empty(
      {static_cast<long>(vec.size())}, at::device(c10::kCPU).dtype(dtype));
  std::copy(
      vec.data(),
      vec.data() + vec.size(),
      t.mutable_data_ptr<UNDERLYING_DTYPE>());
  return t;
}
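
// Illustrative usage (not part of the build): wrapping a vector of scales as
// a float tensor. UNDERLYING_DTYPE must correspond to `dtype`;
// mutable_data_ptr<UNDERLYING_DTYPE>() checks the match at runtime.
//   std::vector<float> scales = {0.1f, 0.2f};
//   at::Tensor t = wrap_vector<float>(scales, c10::kFloat);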

#ifdef USE_FBGEMM
/**
 * Adapted from Fbgemm BCSRMatrix::pack, but with zero points, without tiling,
 * and without determining row_offsets
 * https://github.com/pytorch/FBGEMM/blob/9d7c48a65419d0350f9e9e72f31e05bfe37e85a4/src/FbgemmSparseDense.cc#L84
 */
ao::sparse::BCSR pack_bcsr(
    const int8_t* src,
    const int64_t R,
    const int64_t C,
    const int64_t RB,
    const int64_t CB,
    const int8_t* zero_points,
    const bool qscheme_per_tensor) {
  const size_t ld = C;
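  // BCSR bookkeeping: rowBPtr[i] holds the number of nonzero blocks seen
  // before block-row i (so rowBPtr ends with rowBlocks + 1 entries), colBIdx
  // holds the block-column index of every nonzero block, and values holds
  // each nonzero block's RB * CB entries in row-major order, zero-filled at
  // the right/bottom edges.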
  std::vector<int32_t> rowBPtr;
  std::vector<int32_t> colBIdx;
  std::vector<int8_t> values;
  rowBPtr.push_back(0);
  int64_t nnzb = 0;
  int64_t rowBlocks = (R + RB - 1) / RB;
  for (int64_t i = 0; i < rowBlocks; ++i) {
    int64_t curCols = C;
    int64_t curColBlocks = (curCols + CB - 1) / CB;
    for (int64_t j = 0; j < curColBlocks; ++j) {
      // is the whole block zero?
      bool isCurrentBlockNonZero = false;
      for (int64_t ib = 0; ib < RB; ++ib) {
        // break if already found a non-zero element or
        // out of bounds
        if (isCurrentBlockNonZero || (i * RB + ib) >= R) {
          break;
        }
        const int64_t curr_row = i * RB + ib;
        const int8_t curr_row_zero_point =
            qscheme_per_tensor ? zero_points[0] : zero_points[curr_row];
        for (int64_t jb = 0; jb < CB; ++jb) {
          // within bound?
          if ((j * CB + jb) >= C) {
            continue;
          } else {
            if (src[curr_row * ld + j * CB + jb] != curr_row_zero_point) {
              isCurrentBlockNonZero = true;
              break;
            }
          }
        }
      }
      if (isCurrentBlockNonZero) {
        for (int64_t ib = 0; ib < RB; ++ib) {
          for (int64_t jb = 0; jb < CB; ++jb) {
            if ((i * RB + ib) >= R || (j * CB + jb) >= C) {
              // zero fill
              values.push_back(0);
            } else {
              int8_t val = src[(i * RB + ib) * ld + j * CB + jb];
              values.push_back(val);
            }
          }
        }
        colBIdx.push_back(static_cast<int32_t>(j));
        nnzb++;
      }
    }
    rowBPtr.push_back(static_cast<int32_t>(nnzb));
  }
  return ao::sparse::BCSR(
      std::move(values), std::move(rowBPtr), std::move(colBIdx));
}
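
// Worked example (illustrative only): packing the 4x4 matrix
//   {1, 0, 0, 0,
//    0, 2, 0, 0,
//    0, 0, 0, 0,
//    0, 0, 0, 3}
// with RB = CB = 2 and a per-tensor zero point of 0 keeps only the two
// nonzero 2x2 blocks and yields
//   values  = {1, 0, 0, 2, 0, 0, 0, 3}
//   rowBPtr = {0, 1, 2}
//   colBIdx = {0, 1}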
#endif // USE_FBGEMM
} // namespace

#ifdef USE_FBGEMM

BCSRSerializationType PackedLinearWeight::serialize() {
  // Get weights, row indices, and col indices in untiled form;
  // unpack the tiled bcsr then pack it in untiled form
  std::vector<int8_t> dense_weight_values = std::vector<int8_t>(w->R * w->C);
  w->unpack(dense_weight_values.data());
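  // dense_weight_values now holds the dense, row-major R x C int8 weight
  // matrix reconstructed from the tiled BCSR representation.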

  const bool qscheme_per_tensor = (q_scheme == c10::kPerTensorAffine);
  at::Tensor zero_points = wrap_vector<int8_t>(w_zp, c10::kChar);

  ao::sparse::BCSR untiled_bcsr = pack_bcsr(
      dense_weight_values.data(),
      w->R,
      w->C,
      w->RB,
      w->CB,
      zero_points.data_ptr<int8_t>(),
      qscheme_per_tensor);

  std::vector<int8_t>& packed_weight_values = std::get<0>(untiled_bcsr);

  // Add 128 to each weight value (e.g. -128 becomes 0 and 127 becomes 255).
  // This serialization format is best for minimizing the memory footprint for
  // QNNPACK.
  at::Tensor weight_values = at::empty(
      {static_cast<long>(packed_weight_values.size())},
      at::device(c10::kCPU).dtype(c10::kByte));
  std::transform(
      packed_weight_values.begin(),
      packed_weight_values.end(),
      weight_values.mutable_data_ptr<uint8_t>(),
      [](int8_t v) {
        return static_cast<uint8_t>(static_cast<int16_t>(v) + 128);
      });

  return BCSRSerializationType(
      SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION,
      bias_,
      out_features_block_size_,
      in_features_block_size_,
      wrap_vector<float>(w_scale, c10::kFloat),
      // Narrowing from int32_t to int8_t; this is okay because qint8 zero
      // points are restricted to fit within the bounds of int8_t
      std::move(zero_points),
      qscheme_per_tensor,
      wrap_vector<int>(
          std::get<1>(untiled_bcsr), c10::kInt), // Row block indices
      wrap_vector<int>(
          std::get<2>(untiled_bcsr), c10::kInt), // Col block indices
      std::move(weight_values),
      w->R,
      w->C);
}

#endif // USE_FBGEMM

#ifdef USE_PYTORCH_QNNPACK

BCSRSerializationType PackedLinearWeightQnnp::serialize() {
  at::Tensor w_scales_compact;
  at::Tensor w_zero_points_compact;
  const float* w_scales_data_ptr = w_scales_.const_data_ptr<float>();
  std::function<int8_t(uint8_t)> subtract_128 = [](uint8_t v) {
    return static_cast<int8_t>(static_cast<int16_t>(v) - 128);
  };
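  // Prepacking adds 128 to each zero point so QNNPACK can treat it as uint8;
  // subtract_128 undoes that offset so the serialized zero points are back in
  // signed int8 form.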

  if (q_scheme_ == at::kPerTensorAffine) {
    w_scales_compact = at::empty({1}, at::device(c10::kCPU).dtype(c10::kFloat));
    w_zero_points_compact =
        at::empty({1}, at::device(c10::kCPU).dtype(c10::kChar));

    w_scales_compact.mutable_data_ptr<float>()[0] = w_scales_data_ptr[0];
    w_zero_points_compact.mutable_data_ptr<int8_t>()[0] =
        static_cast<int8_t>(static_cast<int16_t>(w_zero_points_[0]) - 128);
  } else if (q_scheme_ == at::kPerChannelAffine) {
    w_scales_compact =
        at::empty({output_channels_}, at::device(c10::kCPU).dtype(c10::kFloat));
    w_zero_points_compact =
        at::empty({output_channels_}, at::device(c10::kCPU).dtype(c10::kChar));

    std::copy(
        w_scales_data_ptr,
        w_scales_data_ptr +
            output_channels_, // Don't go to the end because of padding
        w_scales_compact.mutable_data_ptr<float>());

    // Subtract 128 from each zero point, to reverse addition done during
    // prepacking
    std::transform(
        w_zero_points_.begin(),
        w_zero_points_.begin() +
            output_channels_, // Don't go to the end because of padding
        w_zero_points_compact.mutable_data_ptr<int8_t>(),
        std::move(subtract_128));
  } else {
    TORCH_CHECK(false, "Unsupported quantization scheme.");
  }

  at::Tensor wrapped_row_values;
  at::Tensor wrapped_col_indices;

  const uint32_t max_index = bcsr_matrix_->max_index();

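  // Serialize the row/col indices in the narrowest dtype that can represent
  // max_index. The unsigned index values are effectively reinterpreted in the
  // same-width signed ATen dtype here and are expected to be converted back
  // when the serialized form is loaded.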
  if (max_index <= std::numeric_limits<uint8_t>::max()) {
    // Cast from uint8_t range to int8_t
    wrapped_row_values = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int8_t>(typed_bcsr->row_values, c10::kChar); });
    wrapped_col_indices = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int8_t>(typed_bcsr->col_indices, c10::kChar); });
  } else if (max_index <= std::numeric_limits<uint16_t>::max()) {
    // Cast from uint16_t range to int16_t
    wrapped_row_values = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int16_t>(typed_bcsr->row_values, c10::kShort); });
    wrapped_col_indices = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int16_t>(typed_bcsr->col_indices, c10::kShort); });
  } else {
    // Cast from uint32_t range to int32_t
    wrapped_row_values = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int>(typed_bcsr->row_values, c10::kInt); });
    wrapped_col_indices = QNNPACK_BCSRMATRIX_DISPATCH_INDICES_DTYPE(
        bcsr_matrix_,
        { return wrap_vector<int>(typed_bcsr->col_indices, c10::kInt); });
  }

  return BCSRSerializationType(
      SPARSE_LINEAR_PACKED_PARAM_SERIALIZATION_VERSION,
      orig_bias_,
      out_features_block_size_,
      in_features_block_size_,
      std::move(w_scales_compact),
      std::move(w_zero_points_compact),
      (q_scheme_ == c10::kPerTensorAffine),
      wrapped_row_values,
      wrapped_col_indices,
      wrap_vector<uint8_t>(bcsr_matrix_->values, c10::kByte),
      output_channels_,
      input_channels_);
}

#endif // USE_PYTORCH_QNNPACK

} // namespace sparse
} // namespace ao