#pragma once

#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/Tensor.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ideep.hpp>
#include <cpuinfo.h>

#include <c10/util/CallOnce.h>

using PrimitiveCacheKey = std::tuple<
    double, // input_scale
    int64_t, // input_zero_point
    std::vector<int64_t>, // input_shape
    double, // output_scale
    int64_t, // output_zero_point
    int64_t, // OMP_number_of_threads
    double, // accum_scale
    int64_t>; // accum_zero_point

enum CacheKeyIndex {
  InputScale,
  InputZeroPoint,
  InputShape,
  OutputScale,
  OutputZeroPoint,
  NumOfThreads,
};

// Base class of primitive cache
struct PrimitiveCache {
  PrimitiveCacheKey key;

  bool hit(const PrimitiveCacheKey& key) {
    return this->key == key;
  }
};
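
// Illustrative usage of the cache-hit check (a sketch, not an API defined in
// this header; `cache`, `make_key` and `create_params` are hypothetical):
//
//   PrimitiveCacheKey key = make_key(input, output_scale, output_zero_point);
//   if (!cache.hit(key)) {
//     // Rebuild the oneDNN primitive params only on a cache miss.
//     cache = LinearPrimitiveCache(key, create_params(/*...*/));
//   }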

using LinearParams = ideep::matmul_forward_params;
using Conv = dnnl::convolution_forward;
using ConvDesc = dnnl::convolution_forward::primitive_desc;
using ConvParams = ideep::convolution_forward_params;
using Deconv = dnnl::deconvolution_forward;
using DeconvDesc = dnnl::deconvolution_forward::primitive_desc;
using DeconvParams = ideep::deconv_forward_params;

struct LinearPrimitiveCache : PrimitiveCache {
  LinearPrimitiveCache() {}

  LinearPrimitiveCache(
      const PrimitiveCacheKey& key,
      const LinearParams& param) {
    this->key = key;
    this->param = param;
  }

  LinearParams param;

  // For dynamic qlinear, scale and zero point
  // are set at execution time, so we only need to compare
  // the rest of the key.
  bool hit_dynamic(const PrimitiveCacheKey& new_key) {
    auto cached_input_shape = std::get<InputShape>(this->key);
    auto new_input_shape = std::get<InputShape>(new_key);
    return (
        cached_input_shape == new_input_shape &&
        std::get<NumOfThreads>(this->key) == std::get<NumOfThreads>(new_key));
  }

  LinearParams& get_param() {
    return param;
  }
};
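
// On the dynamic qlinear path, scales and zero points are only known at run
// time, so reuse is decided by input shape and thread count alone. A sketch
// (with a hypothetical `create_params` helper):
//
//   if (!cache.hit_dynamic(new_key)) {
//     cache = LinearPrimitiveCache(new_key, create_params(/*...*/));
//   }
//   LinearParams& p = cache.get_param();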

struct ConvPrimitiveCache : PrimitiveCache {
  ConvPrimitiveCache() {}

  ConvPrimitiveCache(
      const PrimitiveCacheKey& key,
      const ConvParams& params) {
    this->key = key;
    this->params = params;
  }

  ConvParams params;

  ConvParams& get_params() {
    return params;
  }
};

struct DeconvPrimitiveCache : PrimitiveCache {
  DeconvPrimitiveCache() {}

  DeconvPrimitiveCache(
      const PrimitiveCacheKey& key,
      const DeconvParams& params) {
    this->key = key;
    this->params = params;
  }

  DeconvParams params;

  DeconvParams& get_params() {
    return params;
  }
};

enum PostOps {
  NoPostOp,
  Relu,
  LeakyRelu,
  Tanh,
  Gelu
};


struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
  PackedLinearWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      std::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      std::optional<at::Tensor> orig_bias)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)) {
    cache_initialized_flag = std::make_unique<c10::once_flag>();
  }
  std::unique_ptr<ideep::tensor> weight_;
  std::optional<ideep::tensor> bias_;
  at::Tensor orig_weight_;
  std::optional<at::Tensor> orig_bias_;

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;

  at::Tensor apply_leaky_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point,
      double negative_slope);

  at::Tensor apply_tanh(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point);

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  std::optional<at::Tensor> bias() override {
    return orig_bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias);

 private:
  LinearPrimitiveCache prim_cache;
  std::unique_ptr<c10::once_flag> cache_initialized_flag;

  template <PostOps post_op>
  at::Tensor apply_impl(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point,
      torch::List<at::Scalar> post_op_args = torch::List<at::Scalar>());

  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);

  LinearPrimitiveCache& get_cache() {
    return prim_cache;
  }
};
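
// Illustrative flow for the packed linear params (a sketch; `qweight`,
// `bias_opt`, `qx`, `x_fp32`, `out_scale` and `out_zp` are hypothetical
// values, not names defined in this header):
//
//   auto packed = PackedLinearWeightsOnednn::prepack(qweight, bias_opt);
//   at::Tensor y = packed->apply(qx, out_scale, out_zp);     // static quant
//   at::Tensor yd = packed->apply_dynamic(x_fp32);           // dynamic quant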

template <int kSpatialDim = 2>
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      std::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      std::optional<at::Tensor> orig_bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      uint8_t transpose)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose) {
    cache_initialized_flag = std::make_unique<c10::once_flag>();
  }

  std::unique_ptr<ideep::tensor> weight_;
  std::optional<ideep::tensor> bias_;
  at::Tensor orig_weight_;
  std::optional<at::Tensor> orig_bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  uint8_t transpose_;

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) override;

  at::Tensor apply_add(
      const at::Tensor& input,
      const at::Tensor& accum,
      double output_scale,
      int64_t output_zero_point);

  at::Tensor apply_add_relu(
      const at::Tensor& input,
      const at::Tensor& accum,
      double output_scale,
      int64_t output_zero_point);

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return (bool)transpose_;
  }

 private:
  ConvPrimitiveCache conv_prim_cache;
  DeconvPrimitiveCache deconv_prim_cache;
  std::unique_ptr<c10::once_flag> cache_initialized_flag;

  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      const std::optional<at::Tensor>& accum,
      double output_scale,
      int64_t output_zero_point);

  ConvPrimitiveCache& get_conv_cache() {
    assert(!transpose());
    return conv_prim_cache;
  }

  DeconvPrimitiveCache& get_deconv_cache() {
    assert(transpose());
    return deconv_prim_cache;
  }
};
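
// Illustrative flow for the packed conv params (a sketch; the argument values
// such as `qweight`, `bias_opt`, `qx` and `accum` are hypothetical):
//
//   auto packed = PackedConvWeightsOnednn<2>::prepack(
//       qweight, bias_opt, stride, padding, output_padding, dilation,
//       groups, /*transpose=*/false);
//   at::Tensor y = packed->apply(qx, out_scale, out_zp);
//   // Fused residual add + ReLU on the conv output:
//   at::Tensor y_add = packed->apply_add_relu(qx, accum, out_scale, out_zp);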

namespace onednn_utils {

inline ideep::attr_t create_attr_by_post_op(
    const c10::string_view& binary_post_op,
    double binary_alpha,
    double input1_scale,
    int64_t input1_zero_point,
    const ideep::tensor::desc& input1_desc,
    const c10::string_view& unary_post_op,
    const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
    const c10::string_view& unary_post_op_algorithm) {
  using ideep::tensor;
  if (binary_post_op == "none") {
    if (unary_post_op == "relu") {
      return ideep::attr_t::fuse_relu();
    } else if (unary_post_op == "leaky_relu") {
      TORCH_CHECK(
          unary_post_op_args.size() == 1,
          "onednn qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args");
      auto alpha = unary_post_op_args[0].value().to<float>();
      return ideep::attr_t::fuse_relu_v2(alpha);
    } else if (unary_post_op == "tanh") {
      return ideep::attr_t::fuse_tanh();
    } else if (unary_post_op == "gelu") {
      TORCH_CHECK(
          unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh",
          "onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm);
      auto post_algorithm = unary_post_op_algorithm == "none" ?
        dnnl::algorithm::eltwise_gelu_erf :
        dnnl::algorithm::eltwise_gelu_tanh;
      return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm);
    } else if (unary_post_op == "hardtanh") {
      TORCH_CHECK(
          unary_post_op_args.size() == 2 &&
              unary_post_op_args[0].has_value() &&
              unary_post_op_args[1].has_value(),
          "hardtanh is expected to have two scalar input: min_val and max_val");
      auto lower_bound_value =
          unary_post_op_args[0].value().to<float>();
      auto upper_bound_value =
          unary_post_op_args[1].value().to<float>();
      return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value);
    } else if (unary_post_op == "hardswish") {
      return ideep::attr_t::fuse_hardswish();
    } else if (unary_post_op == "swish") {
      return ideep::attr_t::fuse_swish();
    } else {
      TORCH_CHECK(
          unary_post_op == "none",
          "onednn qlinear: unsupported unary post op ", unary_post_op);
    }
  } else if (binary_post_op == "sum") {
    if (unary_post_op == "none") {
      return ideep::attr_t::fuse_sum(input1_scale, input1_zero_point);
    } else if (unary_post_op == "relu") {
      return ideep::attr_t::residual_with_sum_zero_point(input1_scale, input1_zero_point);
    } else {
      TORCH_CHECK(
          false,
          "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum");
    }
  } else if (binary_post_op == "add") {
    if (unary_post_op == "none") {
      return ideep::attr_t::fuse_binary(ideep::algorithm::binary_add, input1_desc);
    } else if (unary_post_op == "relu") {
      ideep::post_ops po;
      po.append_binary(ideep::algorithm::binary_add, input1_desc);
      po.append_eltwise(ideep::algorithm::eltwise_relu, 0, 0);
      return ideep::attr_t::attr_post_ops(po);
    } else {
      TORCH_CHECK(
          false,
          "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add");
    }
  } else {
    TORCH_CHECK(
        false,
        "onednn qlinear: unsupported binary post op ", binary_post_op);
  }
  return ideep::attr_t();
}
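
// Illustrative call (a sketch; the argument values, including `other_desc`,
// are hypothetical):
//
//   // Fuse a binary add with a following ReLU into the primitive attributes.
//   ideep::attr_t attr = create_attr_by_post_op(
//       /*binary_post_op=*/"add", /*binary_alpha=*/1.0,
//       /*input1_scale=*/1.0, /*input1_zero_point=*/0, other_desc,
//       /*unary_post_op=*/"relu", /*unary_post_op_args=*/{},
//       /*unary_post_op_algorithm=*/"none");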

// ONEDNN requires symmetric quantization of weights.
// Use this util function to check.
inline bool is_weight_symmetric_quant(
      const at::Tensor& weight,
      bool is_transposed_conv) {
  bool is_symmetric = true;
  const auto qtype = weight.qscheme();
  if (qtype == c10::kPerTensorAffine) {
    is_symmetric &= (weight.q_zero_point() == 0);
  } else if (qtype == c10::kPerChannelAffine) {
    if (is_transposed_conv) {
      // This case is currently not supported in PyTorch
      // but we do not want to raise an error in this util function.
      is_symmetric = false;
    } else {
      auto output_channels = weight.size(0);
      for (int i = 0; i < output_channels; ++i) {
        auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
        is_symmetric &= (zp == 0);
      }
    }
  } else {
    // This case is currently not supported in PyTorch
    // but we do not want to raise an error in this util function.
    is_symmetric = false;
  }
  return is_symmetric;
}

// When qengine is x86, use this util function to check whether the onednn
// kernel is preferred over fbgemm's for better performance.
inline bool should_use_onednn_quant(
    const at::Tensor& weight,
    bool is_transposed_conv,
    int groups,
    torch::List<int64_t> output_padding) {
  // Performance of onednn is only validated on Linux right now.
  // Also, the heuristics for dispatching are based on perf data on Linux.
  // So, for x86 qengine, we always use fbgemm kernels if the OS is not Linux.
  // TODO Support more OSs.
#if !defined(__linux__)
  return false;
#else
  bool vnni_available = cpuinfo_has_x86_avx512vnni();
  bool w_sym_quant =
      is_weight_symmetric_quant(weight, is_transposed_conv);
  bool opad_all_zero =
      std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i==0; });
  return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
#endif
}
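
// Illustrative use on the x86 qengine dispatch path (a sketch; `qweight`,
// `groups` and `output_padding` are hypothetical values):
//
//   bool use_onednn = should_use_onednn_quant(
//       qweight, /*is_transposed_conv=*/false, groups, output_padding);
//   // If false, dispatch to the fbgemm kernel instead.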

} // onednn_utils

at::Tensor _qconv_prepack_onednn(
    at::Tensor weight, // from CPU backend instead of QuantizedCPU
    at::Tensor weight_scales, // Weight zero points must be 0 for onednn
    double input_scale,
    int64_t input_zero_point,
    torch::List<int64_t> stride,
    torch::List<int64_t> padding,
    torch::List<int64_t> dilation,
    int64_t groups,
    std::optional<torch::List<int64_t>> input_shape=std::nullopt);

#endif // #if AT_MKLDNN_ENABLED()