#pragma once

#include <ATen/Config.h>
#if AT_MKLDNN_ENABLED()
#include <ATen/Tensor.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ideep.hpp>
#include <cpuinfo.h>

#include <c10/util/CallOnce.h>

using PrimitiveCacheKey = std::tuple<
    double, // input_scale
    int64_t, // input_zero_point
    std::vector<int64_t>, // input_shape
    double, // output_scale
    int64_t, // output_zero_point
    int64_t, // OMP_number_of_threads
    double, // accum_scale
    int64_t>; // accum_zero_point

enum CacheKeyIndex {
  InputScale,
  InputZeroPoint,
  InputShape,
  OutputScale,
  OutputZeroPoint,
  NumOfThreads,
};

// Base class of primitive cache
struct PrimitiveCache {
  PrimitiveCacheKey key;

  bool hit(const PrimitiveCacheKey& key) {
    return this->key == key;
  }
};
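
// A hypothetical sketch of how a cache key is built and checked (the names
// `act_scale`, `act_zp`, `out_scale`, `out_zp`, and `input` are illustrative,
// not part of this header):
//
//   PrimitiveCacheKey key = std::make_tuple(
//       act_scale, act_zp, input.sizes().vec(),
//       out_scale, out_zp, at::get_num_threads(),
//       /*accum_scale=*/1.0, /*accum_zero_point=*/0);
//   if (cache.hit(key)) {
//     // reuse the cached primitive parameters instead of recreating them
//   }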

using LinearParams = ideep::matmul_forward_params;
using Conv = dnnl::convolution_forward;
using ConvDesc = dnnl::convolution_forward::primitive_desc;
using ConvParams = ideep::convolution_forward_params;
using Deconv = dnnl::deconvolution_forward;
using DeconvDesc = dnnl::deconvolution_forward::primitive_desc;
using DeconvParams = ideep::deconv_forward_params;

struct LinearPrimitiveCache : PrimitiveCache {
  LinearPrimitiveCache() {}

  LinearPrimitiveCache(
      const PrimitiveCacheKey& key,
      const LinearParams& param) {
    this->key = key;
    this->param = param;
  }

  LinearParams param;

  // For dynamic qlinear, the scale and zero point are set at execution time,
  // so we only need to compare the remaining parts of the key.
  bool hit_dynamic(const PrimitiveCacheKey& new_key) {
    auto cached_input_shape = std::get<InputShape>(this->key);
    auto new_input_shape = std::get<InputShape>(new_key);
    return (
        cached_input_shape == new_input_shape &&
        std::get<NumOfThreads>(this->key) == std::get<NumOfThreads>(new_key));
  }

  LinearParams& get_param() {
    return param;
  }
};
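
// A rough usage sketch of the caching pattern (assumptions: `cache` is a
// LinearPrimitiveCache owned by a packed-weight object, `flag` is its
// associated c10::once_flag, and `key`/`params` were just computed):
//
//   c10::call_once(*flag, [&]() {
//     cache = LinearPrimitiveCache(key, params);
//   });
//   if (cache.hit(key)) {
//     auto& cached_params = cache.get_param();
//     // ... execute the primitive with cached_params ...
//   }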

struct ConvPrimitiveCache : PrimitiveCache {
  ConvPrimitiveCache() {}

  ConvPrimitiveCache(
      const PrimitiveCacheKey& key,
      const ConvParams& params) {
    this->key = key;
    this->params = params;
  }

  ConvParams params;

  ConvParams& get_params() {
    return params;
  }
};

struct DeconvPrimitiveCache : PrimitiveCache {
  DeconvPrimitiveCache() {}

  DeconvPrimitiveCache(
      const PrimitiveCacheKey& key,
      const DeconvParams& params) {
    this->key = key;
    this->params = params;
  }

  DeconvParams params;

  DeconvParams& get_params() {
    return params;
  }
};

enum PostOps {
  NoPostOp,
  Relu,
  LeakyRelu,
  Tanh,
  Gelu
};

struct PackedLinearWeightsOnednn : public LinearPackedParamsBase {
  PackedLinearWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      std::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      std::optional<at::Tensor> orig_bias)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)) {
    cache_initialized_flag = std::make_unique<c10::once_flag>();
  }
  std::unique_ptr<ideep::tensor> weight_;
  std::optional<ideep::tensor> bias_;
  at::Tensor orig_weight_;
  std::optional<at::Tensor> orig_bias_;

  at::Tensor apply(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;
  at::Tensor apply_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(at::Tensor input, bool reduce_range=false) override;
  at::Tensor apply_dynamic_relu(at::Tensor input, bool reduce_range=false) override;

  at::Tensor apply_leaky_relu(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point,
      double negative_slope);

  at::Tensor apply_tanh(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point);

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  std::optional<at::Tensor> bias() override {
    return orig_bias_;
  }

  static c10::intrusive_ptr<LinearPackedParamsBase> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias);

 private:
  LinearPrimitiveCache prim_cache;
  std::unique_ptr<c10::once_flag> cache_initialized_flag;

  template <PostOps post_op>
  at::Tensor apply_impl(
      at::Tensor input,
      double output_scale,
      int64_t output_zero_point,
      torch::List<at::Scalar> post_op_args = torch::List<at::Scalar>());

  template <bool ReluFused>
  at::Tensor apply_dynamic_impl(at::Tensor input, bool reduce_range=false);

  LinearPrimitiveCache& get_cache() {
    return prim_cache;
  }
};
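
// A minimal usage sketch (assumptions: `qweight` is a quantized qint8 weight,
// `qx` is a quantized quint8 activation, and `x_fp32` is a float activation;
// variable names and the scale/zero-point literals are illustrative only):
//
//   auto packed = PackedLinearWeightsOnednn::prepack(qweight, bias);
//   at::Tensor qy = packed->apply(qx, /*output_scale=*/0.1, /*output_zero_point=*/0);
//   at::Tensor y = packed->apply_dynamic(x_fp32); // dynamic quantization path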

template <int kSpatialDim = 2>
struct PackedConvWeightsOnednn : public ConvPackedParamsBase<kSpatialDim> {
  PackedConvWeightsOnednn(
      std::unique_ptr<ideep::tensor> weight,
      std::optional<ideep::tensor> bias,
      at::Tensor orig_weight,
      std::optional<at::Tensor> orig_bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      uint8_t transpose)
      : weight_(std::move(weight)),
        bias_(std::move(bias)),
        orig_weight_(std::move(orig_weight)),
        orig_bias_(std::move(orig_bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
        transpose_(transpose) {
    cache_initialized_flag = std::make_unique<c10::once_flag>();
  }

  std::unique_ptr<ideep::tensor> weight_;
  std::optional<ideep::tensor> bias_;
  at::Tensor orig_weight_;
  std::optional<at::Tensor> orig_bias_;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
  uint8_t transpose_;

  at::Tensor apply(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_relu(
      const at::Tensor& input,
      double output_scale,
      int64_t output_zero_point) override;

  at::Tensor apply_dynamic(
      const at::Tensor& input,
      bool reduce_range) override;

  at::Tensor apply_add(
      const at::Tensor& input,
      const at::Tensor& accum,
      double output_scale,
      int64_t output_zero_point);

  at::Tensor apply_add_relu(
      const at::Tensor& input,
      const at::Tensor& accum,
      double output_scale,
      int64_t output_zero_point);

  std::tuple<at::Tensor, std::optional<at::Tensor>> unpack() override;

  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
      at::Tensor weight,
      std::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose);

  torch::List<int64_t> stride() const override {
    return stride_;
  }

  torch::List<int64_t> padding() const override {
    return padding_;
  }

  torch::List<int64_t> output_padding() const override {
    return output_padding_;
  }

  torch::List<int64_t> dilation() const override {
    return dilation_;
  }

  int64_t groups() const override {
    return groups_;
  }

  bool transpose() const override {
    return (bool)transpose_;
  }

 private:
  ConvPrimitiveCache conv_prim_cache;
  DeconvPrimitiveCache deconv_prim_cache;
  std::unique_ptr<c10::once_flag> cache_initialized_flag;

  template <bool ReluFused>
  at::Tensor apply_impl(
      const at::Tensor& input,
      const std::optional<at::Tensor>& accum,
      double output_scale,
      int64_t output_zero_point);

  ConvPrimitiveCache& get_conv_cache() {
    assert(!transpose());
    return conv_prim_cache;
  }

  DeconvPrimitiveCache& get_deconv_cache() {
    assert(transpose());
    return deconv_prim_cache;
  }
};
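
// A minimal usage sketch for 2D conv (assumptions: `qweight` is a quantized
// OIHW weight tensor and `qx` is a quantized NCHW activation; the stride,
// padding, dilation, and scale values below are illustrative only):
//
//   torch::List<int64_t> stride({1, 1}), padding({0, 0}),
//       output_padding({0, 0}), dilation({1, 1});
//   auto packed = PackedConvWeightsOnednn<2>::prepack(
//       qweight, bias, stride, padding, output_padding, dilation,
//       /*groups=*/1, /*transpose=*/false);
//   at::Tensor qy = packed->apply(qx, /*output_scale=*/0.1, /*output_zero_point=*/0);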

namespace onednn_utils {

inline ideep::attr_t create_attr_by_post_op(
    const c10::string_view& binary_post_op,
    double binary_alpha,
    double input1_scale,
    int64_t input1_zero_point,
    const ideep::tensor::desc& input1_desc,
    const c10::string_view& unary_post_op,
    const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
    const c10::string_view& unary_post_op_algorithm) {
  using ideep::tensor;
  if (binary_post_op == "none") {
    if (unary_post_op == "relu") {
      return ideep::attr_t::fuse_relu();
    } else if (unary_post_op == "leaky_relu") {
      TORCH_CHECK(
          unary_post_op_args.size() == 1,
          "onednn qlinear: expect one argument for post op leaky_relu but got ",
          unary_post_op_args.size(), " args");
      auto alpha = unary_post_op_args[0].value().to<float>();
      return ideep::attr_t::fuse_relu_v2(alpha);
    } else if (unary_post_op == "tanh") {
      return ideep::attr_t::fuse_tanh();
    } else if (unary_post_op == "gelu") {
      TORCH_CHECK(
          unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh",
          "onednn qlinear: algorithm for post op gelu must be none or tanh but got ",
          unary_post_op_algorithm);
      auto post_algorithm = unary_post_op_algorithm == "none" ?
          dnnl::algorithm::eltwise_gelu_erf :
          dnnl::algorithm::eltwise_gelu_tanh;
      return ideep::attr_t::fuse_gelu_v2(0.f, 0.f, post_algorithm);
    } else if (unary_post_op == "hardtanh") {
      TORCH_CHECK(
          unary_post_op_args.size() == 2 &&
              unary_post_op_args[0].has_value() &&
              unary_post_op_args[1].has_value(),
          "hardtanh is expected to have two scalar input: min_val and max_val");
      auto lower_bound_value =
          unary_post_op_args[0].value().to<float>();
      auto upper_bound_value =
          unary_post_op_args[1].value().to<float>();
      return ideep::attr_t::fuse_clamp(lower_bound_value, upper_bound_value);
    } else if (unary_post_op == "hardswish") {
      return ideep::attr_t::fuse_hardswish();
    } else if (unary_post_op == "swish") {
      return ideep::attr_t::fuse_swish();
    } else {
      TORCH_CHECK(
          unary_post_op == "none",
          "onednn qlinear: unsupported unary post op ", unary_post_op);
    }
  } else if (binary_post_op == "sum") {
    if (unary_post_op == "none") {
      return ideep::attr_t::fuse_sum(input1_scale, input1_zero_point);
    } else if (unary_post_op == "relu") {
      return ideep::attr_t::residual_with_sum_zero_point(input1_scale, input1_zero_point);
    } else {
      TORCH_CHECK(
          false,
          "onednn qlinear: unsupported unary post op ", unary_post_op,
          " with binary post op sum");
    }
  } else if (binary_post_op == "add") {
    if (unary_post_op == "none") {
      return ideep::attr_t::fuse_binary(ideep::algorithm::binary_add, input1_desc);
    } else if (unary_post_op == "relu") {
      ideep::post_ops po;
      po.append_binary(ideep::algorithm::binary_add, input1_desc);
      po.append_eltwise(ideep::algorithm::eltwise_relu, 0, 0);
      return ideep::attr_t::attr_post_ops(po);
    } else {
      TORCH_CHECK(
          false,
          "onednn qlinear: unsupported unary post op ", unary_post_op,
          " with binary post op add");
    }
  } else {
    TORCH_CHECK(
        false,
        "onednn qlinear: unsupported binary post op ", binary_post_op);
  }
  return ideep::attr_t();
}
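
// A hypothetical call sketch: build an attr that fuses a ReLU after the op.
// With no binary post op, the scale/zero-point/desc arguments are unused, so
// placeholder values are passed:
//
//   torch::List<std::optional<at::Scalar>> no_args;
//   ideep::attr_t attr = onednn_utils::create_attr_by_post_op(
//       /*binary_post_op=*/"none", /*binary_alpha=*/1.0,
//       /*input1_scale=*/1.0, /*input1_zero_point=*/0,
//       /*input1_desc=*/ideep::tensor::desc(),
//       /*unary_post_op=*/"relu", /*unary_post_op_args=*/no_args,
//       /*unary_post_op_algorithm=*/"none");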

// oneDNN requires symmetric quantization of weights.
// Use this util function to check.
inline bool is_weight_symmetric_quant(
    const at::Tensor& weight,
    bool is_transposed_conv) {
  bool is_symmetric = true;
  const auto qtype = weight.qscheme();
  if (qtype == c10::kPerTensorAffine) {
    is_symmetric &= (weight.q_zero_point() == 0);
  } else if (qtype == c10::kPerChannelAffine) {
    if (is_transposed_conv) {
      // This case is currently not supported in PyTorch,
      // but we do not want to raise an error in this util function.
      is_symmetric = false;
    } else {
      auto output_channels = weight.size(0);
      for (int i = 0; i < output_channels; ++i) {
        auto zp = weight.q_per_channel_zero_points()[i].item<int32_t>();
        is_symmetric &= (zp == 0);
      }
    }
  } else {
    // This case is currently not supported in PyTorch,
    // but we do not want to raise an error in this util function.
    is_symmetric = false;
  }
  return is_symmetric;
}
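
// A brief usage sketch (assumes `qweight` is a quantized weight tensor; the
// variable name is illustrative only):
//
//   bool sym = onednn_utils::is_weight_symmetric_quant(
//       qweight, /*is_transposed_conv=*/false);
//   // onednn kernels are only eligible when `sym` is true.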

// When the qengine is x86, use this util function to check whether the onednn
// kernel is preferred over fbgemm's for better performance.
inline bool should_use_onednn_quant(
    const at::Tensor& weight,
    bool is_transposed_conv,
    int groups,
    torch::List<int64_t> output_padding) {
  // Performance of onednn is only validated on Linux right now.
  // Also, the heuristics for dispatching are based on perf data on Linux.
  // So, for x86 qengine, we always use fbgemm kernels if the OS is not Linux.
  // TODO Support more OSs.
#if !defined(__linux__)
  return false;
#else
  bool vnni_available = cpuinfo_has_x86_avx512vnni();
  bool w_sym_quant =
      is_weight_symmetric_quant(weight, is_transposed_conv);
  bool opad_all_zero =
      std::all_of(output_padding.begin(), output_padding.end(), [](int i) { return i == 0; });
  return vnni_available && (groups <= 100) && w_sym_quant && opad_all_zero;
#endif
}
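
// A rough dispatch sketch (assumes this is called from x86 qengine prepack
// code; `qweight`, `groups`, and `output_padding` are illustrative names):
//
//   if (onednn_utils::should_use_onednn_quant(
//           qweight, /*is_transposed_conv=*/false, groups, output_padding)) {
//     // take the onednn path
//   } else {
//     // fall back to fbgemm
//   }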

} // namespace onednn_utils

at::Tensor _qconv_prepack_onednn(
    at::Tensor weight, // from CPU backend instead of QuantizedCPU
    at::Tensor weight_scales, // Weight zero points must be 0 for onednn
    double input_scale,
    int64_t input_zero_point,
    torch::List<int64_t> stride,
    torch::List<int64_t> padding,
    torch::List<int64_t> dilation,
    int64_t groups,
    std::optional<torch::List<int64_t>> input_shape=std::nullopt);

#endif // #if AT_MKLDNN_ENABLED()