xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/quantized/cpu/conv_serialization.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/core/List.h>
5 #include <ATen/native/quantized/cpu/fbgemm_utils.h>
6 #include <ATen/native/quantized/cpu/QnnpackUtils.h>
7 #include <ATen/native/quantized/cpu/OnednnUtils.h>
8 #include <c10/util/irange.h>
9 #if !defined(__s390x__) && !defined(__powerpc__)
10 #include <cpuinfo.h>
11 #endif
12 
13 #ifndef AT_PER_OPERATOR_HEADERS
14 #include <ATen/Functions.h>
15 #else
16 #include <ATen/ops/from_blob.h>
17 #endif
18 
19 
20 #include <tuple>
21 
22 /* Convolution prepacked parameters serialization.
23  *
24  * Version 1
25  *
26  * - Fields:
27  *  1. weight
28  *  2. bias
29  *  3. stride x kSpatialDim
30  *  4. padding x kSpatialDim
31  *  5. dilation x kSpatialDim
32  *  6. groups
33  *
34  * Version 2
35  *
36  * - Fields:
37  *  0. version (string)
38  *  1. list of non-optional tensors
39  *    0: packed parameters (int16_t)
40  *      - kSpatialDim
41  *      - stride x kSpatialDim
42  *      - padding x kSpatialDim
43  *      - dilation x kSpatialDim
44  *      - output_padding x kSpatialDim
45  *      - groups
46  *      - transpose (0 or 1)
47  *    1: weight
48  *  2. list of optional tensors
49  *    0: bias
50  *
51  * Version 3
52  *
53  * - Fields:
54  *  0. version (int64_t)
55  *  1. list of int64_t configuration values
56  *    - kSpatialDim
57  *    - stride x kSpatialDim
58  *    - padding x kSpatialDim
59  *    - dilation x kSpatialDim
60  *    - output_padding x kSpatialDim
61  *    - groups
62  *    - flags (bitmask)
63  *      - (1 << 0) transpose (1 = yes)
64  *  2. list of optional tensors
65  *    0: None (helps with type inference)
66  *    1: weight (this must be present)
67  *    2: bias
68  */
69 
// Serialized layout for version 2; see the format description in the
// comment block at the top of this file.
using ConvParamsSerializationTypeV2 = std::tuple<
  // version, for versions 2 and up
  std::string,
  // non-optional tensors
  std::vector<at::Tensor>,
  // optional tensors
  std::vector<std::optional<at::Tensor>>>;

// Serialized layout for version 3 (the current format); see the format
// description in the comment block at the top of this file.
using ConvParamsSerializationTypeV3 = std::tuple<
  // version, int for versions 3 and up
  int64_t,
  // configuration values
  std::vector<int64_t>,
  // optional tensors
  std::vector<std::optional<at::Tensor>>>;
85 
86 // Parses any historical conv packed params format into
87 // the current format.
88 template <uint32_t kSpatialDim>
parse_conv_serialized_state(c10::IValue v)89 ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) {
90 
91   // determine the version based on IValue contents
92   int version = -1;
93   if (v.isTuple()) {
94     const auto& elements = v.toTupleRef().elements();
95     if (!elements.empty()) {
96       auto firstElement = elements[0];
97       if (firstElement.isTensor()) {
98         version = 1;
99       } else if (firstElement.isString()) {
100         const std::string& version_str = firstElement.toStringRef();
101         // note: not parsing the string to automatically handle bad
102         // inputs
103         if (version_str == "2") {
104           version = 2;
105         }
106       } else if (firstElement.isInt()) {
107         auto raw_version = firstElement.toInt();
108         if (raw_version == 3) {
109           version = 3;
110         }
111       }
112     }
113   }
114   TORCH_INTERNAL_ASSERT(version != -1, "Unable to parse serialization version");
115 
116   if (version == 1) {
117     // version 1 - convert to version 3 manually
118 
119     const auto& elements = v.toTupleRef().elements();
120 
121     at::Tensor weight = elements[0].toTensor();
122     std::optional<at::Tensor> bias = elements[1].toOptional<at::Tensor>();
123     torch::List<at::Tensor> stride_x_kSpatialDim = elements[2].toTensorList();
124     torch::List<at::Tensor> padding_x_kSpatialDim = elements[3].toTensorList();
125     torch::List<at::Tensor> dilation_x_kSpatialDim = elements[4].toTensorList();
126     at::Tensor groups = elements[5].toTensor();
127 
128     std::vector<int64_t> config_vals;
129     config_vals.reserve(
130         stride_x_kSpatialDim.size() + padding_x_kSpatialDim.size() +
131         dilation_x_kSpatialDim.size() + kSpatialDim + 3);
132     config_vals.push_back(kSpatialDim);
133     for (const auto i : c10::irange(stride_x_kSpatialDim.size())) {
134       auto stride = stride_x_kSpatialDim.get(i);
135       config_vals.push_back(stride[0].item<int16_t>());
136     }
137     for (const auto i : c10::irange(padding_x_kSpatialDim.size())) {
138       auto padding = padding_x_kSpatialDim.get(i);
139       config_vals.push_back(padding[0].item<int16_t>());
140     }
141     for (const auto i : c10::irange(dilation_x_kSpatialDim.size())) {
142       auto dilation = dilation_x_kSpatialDim.get(i);
143       config_vals.push_back(dilation[0].item<int16_t>());
144     }
145     // output_padding does not exist in v1, so we fill in a default value
146     for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
147       config_vals.push_back(0);
148     }
149     config_vals.push_back(groups[0].item<int16_t>());
150     // transpose does not exist in v1, so we fill in a default value
151     config_vals.push_back(0);
152 
153     std::vector<std::optional<at::Tensor>> tensors;
154     tensors.emplace_back();
155     tensors.emplace_back(weight);
156     tensors.emplace_back(bias);
157 
158     int64_t version = 3;
159     return std::tie(version, config_vals, tensors);
160   } else if (version == 2) {
161     // version 2
162     const auto& elements = v.toTupleRef().elements();
163     std::vector<at::Tensor> non_optional = elements[1].toTensorList().vec();
164     std::vector<std::optional<at::Tensor>> optional;
165 
166     if (elements[2].isTensorList()) {
167       for (const auto& elem : elements[2].toTensorList()) {
168         optional.emplace_back(static_cast<at::Tensor>(elem));
169       }
170     } else {
171       for (const auto& elem : elements[2].toList()) {
172         optional.emplace_back(static_cast<c10::IValue>(elem).toOptional<at::Tensor>());
173       }
174     }
175     // create default optional value for bias
176     if (optional.empty()) {
177       optional.emplace_back();
178     }
179 
180     auto config_a = non_optional[0].accessor<int16_t, 1>();
181     std::vector<int64_t> config_vals;
182     config_vals.reserve(config_a.size(0));
183     for (const auto i : c10::irange(config_a.size(0))) {
184       config_vals.emplace_back(config_a[i]);
185     }
186 
187     auto weight = non_optional[1];
188     auto bias = optional[0];
189 
190     std::vector<std::optional<at::Tensor>> tensors;
191     tensors.emplace_back();
192     tensors.emplace_back(weight);
193     tensors.emplace_back(bias);
194 
195     int64_t version = 3;
196     return std::tie(version, config_vals, tensors);
197   } else if (version == 3) {
198     return v.to<ConvParamsSerializationTypeV3>();
199   } else {
200     TORCH_INTERNAL_ASSERT(false, "Unexpected serialized qconv version: ",
201         version);
202   }
203 }
204 
205 #define QCONV_SERIALIZATION_VERSION 2
206 
207 #if QCONV_SERIALIZATION_VERSION == 2
208 using ConvParamsSerializationType = ConvParamsSerializationTypeV2;
209 
210 template <uint32_t kSpatialDim>
serialize_conv(const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> & params)211 ConvParamsSerializationTypeV2 serialize_conv(
212     const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
213 
214   std::string version = "2";
215   std::vector<at::Tensor> non_optional;
216   std::vector<std::optional<at::Tensor>> optional;
217 
218   // create a packed int8_t tensor for conv params
219   std::vector<int16_t> params_vec;
220   params_vec.push_back(kSpatialDim);
221   auto stride = params->stride().vec();
222   params_vec.insert(params_vec.end(), stride.begin(), stride.end());
223   auto padding = params->padding().vec();
224   params_vec.insert(params_vec.end(), padding.begin(), padding.end());
225   auto dilation = params->dilation().vec();
226   params_vec.insert(params_vec.end(), dilation.begin(), dilation.end());
227   auto output_padding = params->output_padding().vec();
228   params_vec.insert(params_vec.end(), output_padding.begin(),
229                     output_padding.end());
230   params_vec.push_back(params->groups());
231   params_vec.push_back(params->transpose());
232   int64_t vec_size = params_vec.size();
233   at::Tensor params_tensor = at::from_blob(
234       params_vec.data(), {vec_size},
235       at::TensorOptions().dtype(at::kShort))
236     // clone to retain ownership of the data
237     .clone();
238 
239   auto [weight, bias] = params->unpack();
240 
241   non_optional.emplace_back(std::move(params_tensor));
242   non_optional.emplace_back(std::move(weight));
243   optional.emplace_back(std::move(bias));
244 
245   return std::tie(version, non_optional, optional);
246 }
247 
248 #elif QCONV_SERIALIZATION_VERSION == 3
249 using ConvParamsSerializationType = ConvParamsSerializationTypeV3;
250 
251 template <uint32_t kSpatialDim>
serialize_conv(const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> & params)252 ConvParamsSerializationTypeV3 serialize_conv(
253     const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& params) {
254   std::vector<int64_t> config_vals;
255   config_vals.push_back(kSpatialDim);
256   auto stride = params->stride().vec();
257   config_vals.insert(config_vals.end(), stride.begin(), stride.end());
258   auto padding = params->padding().vec();
259   config_vals.insert(config_vals.end(), padding.begin(), padding.end());
260   auto dilation = params->dilation().vec();
261   config_vals.insert(config_vals.end(), dilation.begin(), dilation.end());
262   auto output_padding = params->output_padding().vec();
263   config_vals.insert(config_vals.end(), output_padding.begin(),
264                     output_padding.end());
265   config_vals.push_back(params->groups());
266   config_vals.push_back(params->transpose());
267 
268   auto [weight, bias] = params->unpack();
269 
270   std::vector<std::optional<at::Tensor>> tensors;
271   tensors.emplace_back();
272   tensors.emplace_back(weight);
273   tensors.emplace_back(bias);
274 
275   int64_t version = 3;
276   return std::tie(version, config_vals, tensors);
277 }
278 
279 #else
280 #error "Invalid qconv serialization version."
281 #endif
282 
// Rebuilds prepacked conv params from a v3 serialized state by
// re-prepacking the unpacked weight/bias for the currently active
// quantization engine (X86 / FBGEMM / QNNPACK / ONEDNN, depending on
// what this build enables).
template <uint32_t kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> deserialize_conv(
    ConvParamsSerializationTypeV3 state) {
  auto [version, config_vals, tensors] = state;
  TORCH_INTERNAL_ASSERT(version == 3, "Unexpected serialized qconv version: ", version);

  // tensors layout (v3): [0] = None (type-inference helper),
  // [1] = weight (required), [2] = bias (optional)
  TORCH_CHECK(tensors.size() == 3, "Wrong number of tensors", tensors.size());
  std::optional<at::Tensor> weight = tensors[1];
  std::optional<at::Tensor> bias = tensors[2];
  TORCH_INTERNAL_ASSERT(weight, "Weight should always be present in serialized qconv.");

  // Walk config_vals in its fixed order: kSpatialDim, then kSpatialDim
  // entries each of stride / padding / dilation / output_padding,
  // then groups, then the flags bitmask.
  torch::List<int64_t> stride, padding, output_padding, dilation;
  // skip kSpatialDim
  int idx = 1;
  for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
    stride.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
    padding.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
    dilation.emplace_back(config_vals.at(idx));
    idx++;
  }
  for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) {
    TORCH_INTERNAL_ASSERT(idx < static_cast<int64_t>(config_vals.size()),
        "Unexpected index = ", idx, " for config_vals of size ",
        config_vals.size());
    output_padding.emplace_back(config_vals.at(idx));
    idx++;
  }
  int64_t groups = config_vals.at(idx);
  idx++;
  int64_t flags = config_vals.at(idx);
  idx++;
  // All config values must be consumed; anything left over means a
  // format mismatch.
  TORCH_INTERNAL_ASSERT(idx == static_cast<int64_t>(config_vals.size()),
      "Unexpected length of config_vals, expected ",
      idx,
      " got ",
      config_vals.size());

  // bit 0 of flags is the transpose flag; no other bits are defined yet
  bool transpose = flags & (1 << 0);

  int64_t other_flags = flags & ~(1 << 0);
  TORCH_INTERNAL_ASSERT(other_flags == 0, "Unexpected flags set in ", flags, ".");

  auto& ctx = at::globalContext();

#ifdef USE_FBGEMM
  // X86 engine: prefer oneDNN when the heuristic says it will be faster
  // for this weight/config, otherwise fall back to FBGEMM packing.
  if (ctx.qEngine() == at::QEngine::X86) {
#if AT_MKLDNN_ENABLED()
    bool use_onednn = onednn_utils::should_use_onednn_quant(
        weight.value(), transpose, groups, output_padding);
    if (use_onednn) {
      return PackedConvWeightsOnednn<kSpatialDim>::prepack(
        weight.value(),
        bias,
        stride,
        padding,
        output_padding,
        dilation,
        groups,
        transpose
      );
    }
#endif
    return PackedConvWeight<kSpatialDim>::prepack(
      weight.value(),
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose
    );
  } // x86
#endif

#ifdef USE_FBGEMM
  if (ctx.qEngine() == at::QEngine::FBGEMM) {
    return PackedConvWeight<kSpatialDim>::prepack(
      weight.value(),
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose
    );
  }
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
  if (ctx.qEngine() == at::QEngine::QNNPACK) {
    TORCH_CHECK(
        kSpatialDim == 2,
        "prepack/__setstate__: QNNPACK only supports Conv2d "
        "now.");
    return PackedConvWeightsQnnp<kSpatialDim>::prepack(
      weight.value(),
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose
    );
  }
#endif // USE_PYTORCH_QNNPACK
#if AT_MKLDNN_ENABLED()
  if (ctx.qEngine() == at::QEngine::ONEDNN) {
    return PackedConvWeightsOnednn<kSpatialDim>::prepack(
      weight.value(),
      bias,
      stride,
      padding,
      output_padding,
      dilation,
      groups,
      transpose
    );
  }
#endif // AT_MKLDNN_ENABLED()
// No enabled engine matched the active qEngine; this is unrecoverable.
TORCH_CHECK(
  false,
  "Didn't find engine for when deserializing ConvPackedParams: ",
  toString(ctx.qEngine()));
}
415