1 #pragma once
2
3 #include <ATen/core/Tensor.h>
4 #include <ATen/core/List.h>
5 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
6 #include <ATen/native/quantized/cpu/QnnpackUtils.h>
7 #include <ATen/native/quantized/cpu/OnednnUtils.h>
8 #include <c10/util/irange.h>
9 #if !defined(__s390x__) && !defined(__powerpc__)
10 #include <cpuinfo.h>
11 #endif
12
13 #ifndef AT_PER_OPERATOR_HEADERS
14 #include <ATen/Functions.h>
15 #else
16 #include <ATen/ops/from_blob.h>
17 #endif
18
19
20 #include <tuple>
21
22 /* Convolution prepacked parameters serialization.
23 *
24 * Version 1
25 *
26 * - Fields:
27 * 1. weight
28 * 2. bias
29 * 3. stride x kSpatialDim
30 * 4. padding x kSpatialDim
31 * 5. dilation x kSpatialDim
32 * 6. groups
33 *
34 * Version 2
35 *
36 * - Fields:
37 * 0. version (string)
38 * 1. list of non-optional tensors
39 * 0: packed parameters (int16_t)
40 * - kSpatialDim
41 * - stride x kSpatialDim
42 * - padding x kSpatialDim
43 * - dilation x kSpatialDim
44 * - output_padding x kSpatialDim
45 * - groups
46 * - transpose (0 or 1)
47 * 1: weight
48 * 2. list of optional tensors
49 * 0: bias
50 *
51 * Version 3
52 *
53 * - Fields:
54 * 0. version (int64_t)
55 * 1. list of int64_t configuration values
56 * - kSpatialDim
57 * - stride x kSpatialDim
58 * - padding x kSpatialDim
59 * - dilation x kSpatialDim
60 * - output_padding x kSpatialDim
61 * - groups
62 * - flags (bitmask)
63 * - (1 << 0) transpose (1 = yes)
64 * 2. list of optional tensors
65 * 0: None (helps with type inference)
66 * 1: weight (this must be present)
67 * 2: bias
68 */
69
// On-disk layout for version 2 (see the format description at the top
// of this file).
using ConvParamsSerializationTypeV2 = std::tuple<
  // version, for versions 2 and up
  std::string,
  // non-optional tensors: 0 = packed int16 config tensor, 1 = weight
  std::vector<at::Tensor>,
  // optional tensors: 0 = bias
  std::vector<std::optional<at::Tensor>>>;

// On-disk layout for version 3, the current format (see the format
// description at the top of this file).
using ConvParamsSerializationTypeV3 = std::tuple<
  // version, int for versions 3 and up
  int64_t,
  // configuration values: kSpatialDim, stride, padding, dilation,
  // output_padding (each x kSpatialDim), groups, flags bitmask
  std::vector<int64_t>,
  // optional tensors: 0 = None (type-inference helper), 1 = weight, 2 = bias
  std::vector<std::optional<at::Tensor>>>;
85
86 // Parses any historical conv packed params format into
87 // the current format.
88 template <uint32_t kSpatialDim>
parse_conv_serialized_state(c10::IValue v)89 ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
90
91 // determine the version based on IValue contents
92 int version = -1;
93 if (v.isTuple()) {
94 const auto& elements = v.toTupleRef().elements();
95 if (!elements.empty()) {
96 auto firstElement = elements[0];
97 if (firstElement.isTensor()) {
98 version = 1;
99 } else if (firstElement.isString()) {
100 const std::string& version_str = firstElement.toStringRef();
101 // note: not parsing the string to automatically handle bad
102 // inputs
103 if (version_str == "2") {
104 version = 2;
105 }
106 } else if (firstElement.isInt()) {
107 auto raw_version = firstElement.toInt();
108 if (raw_version == 3) {
109 version = 3;
110 }
111 }
112 }
113 }
114 TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");
115
116 if (version == 1) {
117 // version 1 - convert to version 3 manually
118
119 const auto& elements = v.toTupleRef().elements();
120
121 at::Tensor weight = elements[0].toTensor();
122 std::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
123 torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
124 torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
125 torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
126 at::Tensor groups = elements[5].toTensor();
127
128 std::vector<int64_t> config_vals;
129 config_vals.reserve(
130 stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() +
131 dilation_x_kSpatialDim.size() + kSpatialDim + 3);
132 config_vals.push_back(kSpatialDim);
133 for (const auto i : c10::irange(stride_x_kSpatialDim.size())) {
134 auto stride = stride_x_kSpatialDim.get(i);
135 config_vals.push_back(stride[0].item<int16_t>());
136 }
137 for (const auto i : c10::irange(padding_x_kSpatialDim.size())) {
138 auto padding = padding_x_kSpatialDim.get(i);
139 config_vals.push_back(padding[0].item<int16_t>());
140 }
141 for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) {
142 auto dilation = dilation_x_kSpatialDim.get(i);
143 config_vals.push_back(dilation[0].item<int16_t>());
144 }
145 // output_padding does not exist in v1, so we fill in a default value
146 for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
147 config_vals.push_back(0);
148 }
149 config_vals.push_back(groups[0].item<int16_t>());
150 // transpose does not exist in v1, so we fill in a default value
151 config_vals.push_back(0);
152
153 std::vector<std::optional<at::Tensor>> tensors;
154 tensors.emplace_back();
155 tensors.emplace_back(weight);
156 tensors.emplace_back(bias);
157
158 int64_t version = 3;
159 return std::tie(version, config_vals, tensors);
160 } else if (version == 2) {
161 // version 2
162 const auto& elements = v.toTupleRef().elements();
163 std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
164 std::vector<std::optional<at::Tensor>> optional;
165
166 if (elements[2].isTensorList()) {
167 for (const auto& elem : elements[2].toTensorList()) {
168 optional.emplace_back(static_cast<at::Tensor>(elem));
169 }
170 } else {
171 for (const auto& elem : elements[2].toList()) {
172 optional.emplace_back(static_cast<c10::IValue>(elem).toOptional<at::Tensor>());
173 }
174 }
175 // create default optional value for bias
176 if (optional.empty()) {
177 optional.emplace_back();
178 }
179
180 auto config_a = non_optional[0].accessor<int16_t, 1>();
181 std::vector<int64_t> config_vals;
182 config_vals.reserve(config_a.size(0));
183 for (const auto i : c10::irange(config_a.size(0))) {
184 config_vals.emplace_back(config_a[i]);
185 }
186
187 auto weight = non_optional[1];
188 auto bias = optional[0];
189
190 std::vector<std::optional<at::Tensor>> tensors;
191 tensors.emplace_back();
192 tensors.emplace_back(weight);
193 tensors.emplace_back(bias);
194
195 int64_t version = 3;
196 return std::tie(version, config_vals, tensors);
197 } else if (version == 3) {
198 return v.to<ConvParamsSerializationTypeV3>();
199 } else {
200 TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
201 version);
202 }
203 }
204
205 #define QCONV_SERIALIZATION_VERSION 2
206
207 #if QCONV_SERIALIZATION_VERSION == 2
208 using ConvParamsSerializationType = ConvParamsSerializationTypeV2;
209
210 template <uint32_t kSpatialDim>
serialize_conv(const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> & params)211 ConvParamsSerializationTypeV2 serialize_conv(
212 const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
213
214 std::string version = "2";
215 std::vector<at::Tensor> non_optional;
216 std::vector<std::optional<at::Tensor>> optional;
217
218 // create a packed int8_t tensor for conv params
219 std::vector<int16_t> params_vec;
220 params_vec.push_back(kSpatialDim);
221 auto stride = params->stride().vec();
222 params_vec.insert(params_vec.end(), stride.begin(), stride.end());
223 auto padding = params->padding().vec();
224 params_vec.insert(params_vec.end(), padding.begin(), padding.end());
225 auto dilation = params->dilation().vec();
226 params_vec.insert(params_vec.end(), dilation.begin(), dilation.end());
227 auto output_padding = params->output_padding().vec();
228 params_vec.insert(params_vec.end(), output_padding.begin(),
229 output_padding.end());
230 params_vec.push_back(params->groups());
231 params_vec.push_back(params->transpose());
232 int64_t vec_size = params_vec.size();
233 at::Tensor params_tensor = at::from_blob(
234 params_vec.data(), {vec_size},
235 at::TensorOptions().dtype(at::kShort))
236 // clone to retain ownership of the data
237 .clone();
238
239 auto [weight, bias] = params->unpack();
240
241 non_optional.emplace_back(std::move(params_tensor));
242 non_optional.emplace_back(std::move(weight));
243 optional.emplace_back(std::move(bias));
244
245 return std::tie(version, non_optional, optional);
246 }
247
248 #elif QCONV_SERIALIZATION_VERSION == 3
249 using ConvParamsSerializationType = ConvParamsSerializationTypeV3;
250
251 template <uint32_t kSpatialDim>
serialize_conv(const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> & params)252 ConvParamsSerializationTypeV3 serialize_conv(
253 const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
254 std::vector<int64_t> config_vals;
255 config_vals.push_back(kSpatialDim);
256 auto stride = params->stride().vec();
257 config_vals.insert(config_vals.end(), stride.begin(), stride.end());
258 auto padding = params->padding().vec();
259 config_vals.insert(config_vals.end(), padding.begin(), padding.end());
260 auto dilation = params->dilation().vec();
261 config_vals.insert(config_vals.end(), dilation.begin(), dilation.end());
262 auto output_padding = params->output_padding().vec();
263 config_vals.insert(config_vals.end(), output_padding.begin(),
264 output_padding.end());
265 config_vals.push_back(params->groups());
266 config_vals.push_back(params->transpose());
267
268 auto [weight, bias] = params->unpack();
269
270 std::vector<std::optional<at::Tensor>> tensors;
271 tensors.emplace_back();
272 tensors.emplace_back(weight);
273 tensors.emplace_back(bias);
274
275 int64_t version = 3;
276 return std::tie(version, config_vals, tensors);
277 }
278
279 #else
280 #error "Invalid qconv serialization version."
281 #endif
282
283 template <uint32_t kSpatialDim>
deserialize_conv(ConvParamsSerializationTypeV3 state)284 c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
285 ConvParamsSerializationTypeV3 state) {
286 auto [version, config_vals, tensors] = state;
287 TORCH_INTERNAL_ASSERT(version == 3, "Unexpected serialized qconv version: ", version);
288
289 TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors", tensors.size());
290 std::optional<at::Tensor> weight = tensors[1];
291 std::optional<at::Tensor> bias = tensors[2];
292 TORCH_INTERNAL_ASSERT(weight, "Weight should always be present in serialized qconv.");
293
294 torch::List<int64_t> stride, padding, output_padding, dilation;
295 // skip kSpatialDim
296 int idx = 1;
297 for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
298 stride.emplace_back(config_vals.at(idx));
299 idx++;
300 }
301 for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
302 padding.emplace_back(config_vals.at(idx));
303 idx++;
304 }
305 for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
306 dilation.emplace_back(config_vals.at(idx));
307 idx++;
308 }
309 for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
310 TORCH_INTERNAL_ASSERT(idx < static_cast<int64_t>(config_vals.size()),
311 "Unexpected index = ", idx, " for config_vals of size ",
312 config_vals.size());
313 output_padding.emplace_back(config_vals.at(idx));
314 idx++;
315 }
316 int64_t groups = config_vals.at(idx);
317 idx++;
318 int64_t flags = config_vals.at(idx);
319 idx++;
320 TORCH_INTERNAL_ASSERT(idx == static_cast<int64_t>(config_vals.size()),
321 "Unexpected length of config_vals, expected ",
322 idx,
323 " got ",
324 config_vals.size());
325
326 bool transpose = flags & (1 << 0);
327
328 int64_t other_flags = flags & ~(1 << 0);
329 TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, ".");
330
331 auto& ctx = at::globalContext();
332
333 #ifdef USE_FBGEMM
334 if (ctx.qEngine() == at::QEngine::X86) {
335 #if AT_MKLDNN_ENABLED()
336 bool use_onednn = onednn_utils::should_use_onednn_quant(
337 weight.value(), transpose, groups, output_padding);
338 if (use_onednn) {
339 return PackedConvWeightsOnednn<kSpatialDim>::prepack(
340 weight.value(),
341 bias,
342 stride,
343 padding,
344 output_padding,
345 dilation,
346 groups,
347 transpose
348 );
349 }
350 #endif
351 return PackedConvWeight<kSpatialDim>::prepack(
352 weight.value(),
353 bias,
354 stride,
355 padding,
356 output_padding,
357 dilation,
358 groups,
359 transpose
360 );
361 } // x86
362 #endif
363
364 #ifdef USE_FBGEMM
365 if (ctx.qEngine() == at::QEngine::FBGEMM) {
366 return PackedConvWeight<kSpatialDim>::prepack(
367 weight.value(),
368 bias,
369 stride,
370 padding,
371 output_padding,
372 dilation,
373 groups,
374 transpose
375 );
376 }
377 #endif // USE_FBGEMM
378 #ifdef USE_PYTORCH_QNNPACK
379 if (ctx.qEngine() == at::QEngine::QNNPACK) {
380 TORCH_CHECK(
381 kSpatialDim == 2,
382 "prepack/__setstate__: QNNPACK only supports Conv2d "
383 "now.");
384 return PackedConvWeightsQnnp<kSpatialDim>::prepack(
385 weight.value(),
386 bias,
387 stride,
388 padding,
389 output_padding,
390 dilation,
391 groups,
392 transpose
393 );
394 }
395 #endif // USE_PYTORCH_QNNPACK
396 #if AT_MKLDNN_ENABLED()
397 if (ctx.qEngine() == at::QEngine::ONEDNN) {
398 return PackedConvWeightsOnednn<kSpatialDim>::prepack(
399 weight.value(),
400 bias,
401 stride,
402 padding,
403 output_padding,
404 dilation,
405 groups,
406 transpose
407 );
408 }
409 #endif // AT_MKLDNN_ENABLED()
410 TORCH_CHECK(
411 false,
412 "Didn't find engine for when deserializing ConvPackedParams: ",
413 toString(ctx.qEngine()));
414 }
415