xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_qmatmul_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the matmul operation with bias,
// relu and requantization fusion support utilizing the oneDNN u8s8s32 inner
// product API. Right now, this version supports:
//   - Input: quantized as uint8 via either MIN_FIRST or SCALE mode.
//            SCALE mode is selected when the input is guaranteed to be non-
//            negative, e.g., when the MatMul is fed by a Relu. Otherwise,
//            MIN_FIRST is selected.
//   - Weight: quantized to int8 via SCALE mode.
//   - Bias: float32/int32. For int32, it is quantized according to input and
//           filter min-max values.
// Other input combinations are not supported yet.
// When the input is quantized to uint8 via MIN_FIRST, the bias needs
// compensation. The detailed algorithm is illustrated below:
//
// Af32 is the original fp32 activation 2D tensor.
// Min(Af32) is the minimum scalar value of Af32.
// Max(Af32) is the maximum scalar value of Af32.
// Qa is the quantization scale for the activation.
// Au8 is the quantized unsigned int8 activation tensor.
// With SCALE quantization (used for non-negative Af32), Qa and Au8 can be
// calculated as below:
//    Qa = 255.0 / Max(Af32)
//    Au8 = round(Qa * Af32).
// With MIN_FIRST quantization, Q'a and A'u8 can be calculated as below:
//    Q'a = 255.0 / (Max(Af32) - Min(Af32))
//    A'u8 = round(Q'a * (Af32 - Min(Af32) * ones(Af32))),
// where ones(.) is a tensor of all 1s with the same shape as its argument and
// round(.) rounds a number to its nearest integer.
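//
// As a small numeric illustration (the values are chosen arbitrarily and are
// not taken from this kernel): for a non-negative activation with
// Max(Af32) = 5.0, SCALE mode gives Qa = 255.0 / 5.0 = 51.0, so the fp32
// value 2.0 maps to round(51.0 * 2.0) = 102. For an activation range
// [-1.0, 3.0], MIN_FIRST mode gives Q'a = 255.0 / (3.0 - (-1.0)) = 63.75, so
// the same value 2.0 maps to round(63.75 * (2.0 - (-1.0))) = 191.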
//
// Wf32 is the original fp32 2D weight tensor.
// MaxAbs(Wf32) is the maximum absolute scalar value of Wf32.
// Qw is the quantization scale of weight.
// Ws8 is the quantized signed int8 weight tensor.
// Qw and Ws8 can be calculated as below:
//    Qw = 127.0 / MaxAbs(Wf32)
//    Ws8 = round(Qw * Wf32).
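//
// For instance (again with arbitrary illustrative numbers), if
// MaxAbs(Wf32) = 0.5 then Qw = 127.0 / 0.5 = 254.0, and a weight value of
// -0.3 maps to round(254.0 * -0.3) = round(-76.2) = -76.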
//
// Bf32 is the original fp32 1D bias tensor matching the innermost dim of
// Wf32.
// With SCALE quantization of the activation, the scaled bias, Bs32, is
// calculated as below:
//      Bs32 = Qa * Qw * Bf32.
// With MIN_FIRST quantization of the activation, the scaled bias tensor with
// compensation, B's32, is calculated as below:
//      B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * Wf32
//            = Q'a * Qw * Bf32 + Q'a * Min(Af32) * 1 * Ws8,
// where 1 denotes a row vector matching the outermost dim of Wf32.
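//
// Continuing the illustrative numbers above (activation range [-1.0, 3.0],
// MaxAbs(Wf32) = 0.5), the compensation term for one output column j is
// Q'a * Min(Af32) * sum_i(Ws8[i][j]); e.g., if that column sum is 100, the
// term is 63.75 * (-1.0) * 100 = -6375, which gets added to
// Q'a * Qw * Bf32[j].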
//
// The QuantizedMatMulWithBias op calculates the 32-bit integer output as
// below:
//  - with SCALE activation quantization:
//    Xs32 = Au8 * Ws8 + 1' * Bs32
//         = Qa * Qw * Af32 * Wf32 + Qa * Qw * 1' * Bf32
//         = Qa * Qw * (Af32 * Wf32 + 1' * Bf32) = Qa * Qw * Xf32,
//    where 1' denotes a column vector matching the outermost dim of Af32 and
//    Xf32 represents the output of the original fp32 MatMul with BiasAdd
//    fusion.
//
//  - with MIN_FIRST activation quantization:
//    Xs32 = A'u8 * Ws8 + 1' * B's32
//         = Q'a * (Af32 - Min(Af32) * ones(Af32)) * Qw * Wf32 +
//           Q'a * Qw * 1' * Bf32 + Q'a * Qw * Min(Af32) * 1' * 1 * Wf32
//         = Q'a * Qw * (Af32 * Wf32 + 1' * Bf32)
//         = Q'a * Qw * Xf32.
//    Note that 1' * 1 = ones(Af32).
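//
// Numerically, with the illustrative scales above (Qa = 51.0, Qw = 254.0), an
// fp32 result Xf32 = 0.1 corresponds to an int32 accumulator value of roughly
// Qa * Qw * 0.1 = 1295, up to the rounding error introduced in Au8 and Ws8.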
//
// The QuantizedMatMulWithBiasAndRelu op does the same calculation as above,
// except that it applies the relu function to the 32-bit integer output.
//
// The QuantizedMatMulWithBiasAndReluAndRequantize op does one more
// requantization step on top of the above. Since the fusion ends with a Relu,
// the activation Xf32 at the Relu in the original fp32 graph is guaranteed to
// be non-negative. The requantize scale Qr is obtained from offline
// calibration:
//    Qr = 255 / Max(Xf32)
//    Xu8 = Qr * Xf32.
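//
// As an illustration (the calibration value is arbitrary): if offline
// calibration determined Max(Xf32) = 20.0, then Qr = 255 / 20.0 = 12.75 and
// the fp32 result 0.1 from the example above requantizes to
// round(12.75 * 0.1) = 1. In practice the kernel does not dequantize to fp32
// first; it folds (roughly) Qr / (Qa * Qw) into a single "output_scale"
// post-op, as done in ExtendMklDnnMatMulFwdParams below.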
//
// More information on this implementation can be found at:
// https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training
#ifdef INTEL_MKL

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_threadpool.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/work_sharder.h"

namespace {
enum {
  QUANTIZE_MODE_MIN_FIRST,
  QUANTIZE_MODE_SCALED,
};
}  // namespace

namespace tensorflow {

template <typename Device, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput, bool native_format = false>
class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
 public:
  virtual ~MklDnnQuantizedMatMulOp() {
    if (this->input_bias_ != nullptr) {
      delete this->input_bias_;
      input_bias_ = nullptr;
    }
    if (this->scaled_bias_ != nullptr) {
      delete this->scaled_bias_;
      scaled_bias_ = nullptr;
    }
    if (this->comp_bias_ != nullptr) {
      delete[] this->comp_bias_;
      comp_bias_ = nullptr;
    }
  }

  float* GetCompBiasBuffer(int size) {
    if (comp_bias_ == nullptr) {
      comp_bias_ = new float[size];
    }
    return comp_bias_;
  }

  explicit MklDnnQuantizedMatMulOp(OpKernelConstruction* context)
      : MklDnnMatMulOpBase<Tweight, Toutput>(context) {
    string mode_string;
    OP_REQUIRES_OK(context, context->GetAttr("input_quant_mode", &mode_string));
    if (mode_string == "MIN_FIRST") {
      mode_ = QUANTIZE_MODE_MIN_FIRST;
    } else if (mode_string == "SCALED") {
      mode_ = QUANTIZE_MODE_SCALED;
    } else {
      context->CtxFailure(errors::InvalidArgument(
          "Quantization mode must be either MIN_FIRST or SCALED, but received ",
          mode_string));
    }
    this->is_weight_const_ = false;
    if (context->HasAttr("is_weight_const")) {
      OP_REQUIRES_OK(context, context->GetAttr("is_weight_const",
                                               &(this->is_weight_const_)));
    }
  }

  void Compute(OpKernelContext* context) override {
    try {
      // Input tensors
      const Tensor& src_tensor = MklGetInput(context, this->kInputIndexSrc);
      const Tensor& weight_tensor =
          MklGetInput(context, this->kInputIndexWeight);
      const Tensor& bias_tensor = MklGetInput(context, this->kInputIndexBias);

      MklDnnShape src_mkl_shape, weight_mkl_shape;
      GetMklShape(context, this->kInputIndexSrc, &src_mkl_shape, native_format);
      GetMklShape(context, this->kInputIndexWeight, &weight_mkl_shape,
                  native_format);
      OP_REQUIRES(context, !weight_mkl_shape.IsMklTensor(),
                  errors::InvalidArgument("Weight should not be in "
                                          "MKL Layout"));

      MklDnnData<Tinput> src(&(this->cpu_engine_));
      MklDnnData<Tweight> weight(&(this->cpu_engine_));

      memory::dims src_dims, weight_dims;
      memory::dims dst_dims_tf_order, dst_dims_mkl_order;

      // Get shapes of input tensors in oneDNN order
      auto src_tf_shape = src_mkl_shape.IsMklTensor()
                              ? src_mkl_shape.GetTfShape()
                              : src_tensor.shape();
      auto weight_tf_shape = weight_mkl_shape.IsMklTensor()
                                 ? weight_mkl_shape.GetTfShape()
                                 : weight_tensor.shape();

      src_dims = TFShapeToMklDnnDims(src_tf_shape);
      weight_dims = TFShapeToMklDnnDims(weight_tf_shape);
      dst_dims_mkl_order = {static_cast<int>(src_tf_shape.dim_size(0)),
                            static_cast<int>(weight_tf_shape.dim_size(1))};

      // Weight dims need to be reversed to create inner-product forward
      // descriptor
      weight_dims = {static_cast<int>(weight_tf_shape.dim_size(1)),
                     static_cast<int>(weight_tf_shape.dim_size(0))};
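      // For example (hypothetical shapes): a TF weight tensor of shape
      // [K, N] = [1024, 256] stored in io order becomes
      // weight_dims = {256, 1024}, i.e., the {N, K} (oi) order expected by
      // the inner-product forward descriptor.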

      // Create memory for user data.
      // Describe what the inputs and outputs of the inner-product look like.
      // Also specify buffers containing the actual input and output data.
      Tensor* dst_tensor = nullptr;
      auto input_output_fmt = memory::format_tag::nc;
      auto input_output_fmt_mkldnn = MklTensorFormat::FORMAT_NC;

      // If input is in MKL layout, then simply take the input layout;
      // otherwise, construct the input TF layout. For the TF layout, although
      // the required input shape (src_dims) is in oneDNN order, the layout is
      // TensorFlow's layout depending on the data format.
      auto src_md =
          src_mkl_shape.IsMklTensor()
              ? src_mkl_shape.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<Tinput>(), input_output_fmt);
      src.SetUsrMem(src_md, &src_tensor);

      // Although the required weight shape (weight_dims) is in oneDNN order,
      // the layout is TensorFlow's layout.
      auto weight_md = weight_mkl_shape.IsMklTensor()
                           ? weight_mkl_shape.GetMklLayout()
                           : memory::desc(weight_dims, MklDnnType<Tweight>(),
                                          memory::format_tag::io);
      weight.SetUsrMem(weight_md, &weight_tensor);

      MklDnnMatMulFwdPrimitive<float, Tinput, Tweight, Tbias, Toutput>*
          matmul_fwd = nullptr;
      memory::dims bias_dims = {static_cast<int>(bias_tensor.dim_size(0))};

      MklDnnMatMulFwdParams matmul_fwd_dims(src_dims, weight_dims, bias_dims,
                                            dst_dims_mkl_order);

      // Extend the basic parameters for data types and fusions.
      this->ExtendMklDnnMatMulFwdParams(context, matmul_fwd_dims);

      // Get a MatMul fwd from primitive pool.
      matmul_fwd =
          MklDnnMatMulFwdPrimitiveFactory<float, Tinput, Tweight, Tbias,
                                          Toutput>::Get(matmul_fwd_dims, 0);

      // Allocate output Tensor.
      std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
          matmul_fwd_pd = matmul_fwd->GetPrimitiveDesc();
      this->AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order,
                                 input_output_fmt_mkldnn, &dst_tensor,
                                 native_format);

      Toutput* dst_data =
          reinterpret_cast<Toutput*>(dst_tensor->flat<Toutput>().data());

      // Check if src and weight data need to be reordered.
      Tinput* src_data = nullptr;
      if (!native_format && src_md != matmul_fwd_pd->src_desc()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(matmul_fwd_pd.get()->src_desc(),
                                this->cpu_engine_, context);
        src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
      } else {
        src_data = static_cast<Tinput*>(
            const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
      }

      Tweight* weight_data = nullptr;
      if (weight_md != matmul_fwd_pd->weights_desc()) {
        bool is_weight_cached = false;
        // For batch size 1, oneDNN expects that weight format is OI whereas
        // TF default format is IO. So in that case convert weight from IO
        // to OI for the first iteration and cache it to reuse in the
        // subsequent iterations, if the weight is constant.
        if (this->is_weight_const_) {
          // Check if the weight is already cached or not
          if (this->IsWeightCacheEmpty(context)) {
            // Cache weight if it is not cached.
            this->CacheWeight(context, matmul_fwd_pd, weight_data,
                              weight_tensor, weight, weight_md);
          }
          weight_data =
              this->GetCachedWeight(context, matmul_fwd_pd->weights_desc());
          is_weight_cached = (weight_data != nullptr);
        }

        if (!is_weight_cached) {
          weight.SetUsrMem(weight_md, &weight_tensor);
          weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_desc(),
                                     this->cpu_engine_, context);
          weight_data =
              static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
        }

      } else {
        weight_data = static_cast<Tweight*>(
            const_cast<Tweight*>(weight_tensor.flat<Tweight>().data()));
      }

      std::shared_ptr<stream> cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      cpu_stream.reset(CreateStream(&eigen_tp, matmul_fwd->GetEngine()));

      UserScratchPad<unsigned char> scratch_pad;
      scratch_pad.AllocateSPTensor(matmul_fwd, context);

      // Execute inner-product
      Tbias* bias_data = this->GetBiasHandle(
          context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream);
      matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data,
                          scratch_pad.Get(), cpu_stream);
    } catch (dnnl::error& e) {
      string error_msg = tensorflow::strings::StrCat(
          "Status: ", e.status, ", message: ", string(e.message), ", in file ",
          __FILE__, ":", __LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
    float min_output_value;
    float max_output_value;
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value) {
      // This is the case where the inner-product and requantization are fused.
      // "min_freezed_output" and "max_freezed_output" are the requested range
      // for the output.
      min_output_value = context->input(7).flat<float>()(0);
      max_output_value = context->input(8).flat<float>()(0);
    } else {
      ComputeOutputRangeForInt32(context, &min_output_value, &max_output_value);
    }

    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value ||
        std::is_same<Toutput, qint32>::value) {
      Tensor* output_min = nullptr;
      Tensor* output_max = nullptr;
      MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
      output_min_mkl_shape.SetMklTensor(false);
      output_max_mkl_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 1, &output_min, {},
                                output_min_mkl_shape, native_format);
      AllocateOutputSetMklShape(context, 2, &output_max, {},
                                output_max_mkl_shape, native_format);
      output_min->flat<float>()(0) = min_output_value;
      output_max->flat<float>()(0) = max_output_value;
    }
  }

 protected:
  void ComputeOutputRangeForInt32(OpKernelContext* context,
                                  float* min_output_value,
                                  float* max_output_value) {
    const float min_input = context->input(3).flat<float>()(0);
    const float max_input = context->input(4).flat<float>()(0);
    const float min_weight = context->input(5).flat<float>()(0);
    const float max_weight = context->input(6).flat<float>()(0);
    MklQuantizationRangeForMultiplication<quint8, qint8, qint32>(
        min_input, max_input, min_weight, max_weight, min_output_value,
        max_output_value);
  }

  virtual void ExtendMklDnnMatMulFwdParams(OpKernelContext* context,
                                           MklDnnMatMulFwdParams& params) {
    // Append data type names of input, weight, bias, and output.
    params.dtypes.append(typeid(Tinput).name());
    params.dtypes.append(typeid(Tweight).name());
    params.dtypes.append(typeid(Tbias).name());
    params.dtypes.append(typeid(Toutput).name());

    // When the output type is quint8 or qint8, the output data is requantized
    // into that type, and when it is float, it is dequantized back to fp32. A
    // post_op "output_scale" is added to do the conversion.
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value ||
        std::is_same<Toutput, float>::value) {
      float min_output_value;
      float max_output_value;
      ComputeOutputRangeForInt32(context, &min_output_value, &max_output_value);
      float scale_int32 =
          std::max(std::abs(min_output_value), std::abs(max_output_value));
      const float min_freezed_output = context->input(7).flat<float>()(0);
      const float max_freezed_output = context->input(8).flat<float>()(0);
      float scale_eightbit =
          std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
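      // Roughly speaking, the qint32 accumulator holds the fp32 value
      // x_int32 * scale_int32 / 2^31. Requantizing that to quint8
      // (~2^8 levels over scale_eightbit) or qint8 (~2^7 levels) therefore
      // needs a multiplier of about scale_int32 / scale_eightbit / 2^23 or
      // / 2^24 respectively, while dequantizing to fp32 uses
      // scale_int32 / 2^31 directly, as computed below.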
      float scale = 1.0;
      if (std::is_same<Toutput, quint8>::value) {
        scale = scale_int32 / scale_eightbit / static_cast<float>(1u << 23);
      } else if (std::is_same<Toutput, qint8>::value) {
        scale = scale_int32 / scale_eightbit / static_cast<float>(1u << 24);
      } else if (std::is_same<Toutput, float>::value) {
        scale = scale_int32 / static_cast<float>(1u << 31);
      } else {
        // TODO(intel-tf): Keep the default qint8 as before.
        // Change to error later.
        scale = scale_int32 / scale_eightbit / static_cast<float>(1u << 24);
      }
      std::vector<float> output_scale;
      output_scale.push_back(scale);
      params.post_op_params.push_back({"output_scale", output_scale});
    }
  }

  // This function handles bias conversion and compensation for MIN_FIRST and
  // SCALE mode. If input is quantized via MIN_FIRST,
  //  B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * Wf32
  // If input is quantized via SCALE,
  //   Bs32 = Qa * Qw * Bf32.
  Tbias* GetBiasHandle(
      OpKernelContext* context,
      std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
          mkldnn_matmul_fwd_pd,
      const Tensor& bias_tensor, const Tensor& weight_tensor,
      std::shared_ptr<stream> reorder_stream) {
    // If the bias is qint32, it was already converted offline and can be
    // added to the matmul output directly.
    if (std::is_same<Tbias, qint32>::value) {
      return static_cast<Tbias*>(
          const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
    } else {
      // If the bias is fp32, it needs to be scaled (and possibly compensated).
      const float min_input = context->input(3).flat<float>()(0);
      const float max_input = context->input(4).flat<float>()(0);
      const float min_weight = context->input(5).flat<float>()(0);
      const float max_weight = context->input(6).flat<float>()(0);

      std::vector<dnnl::primitive> net;
      float out_scale;
      // If the bias is float and the input quantization mode is MIN_FIRST, the
      // bias has to be compensated:
      //   B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * Wf32.
      if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
        int k = weight_tensor.dim_size(0);
        int n = weight_tensor.dim_size(1);
        float* comp_bias = GetCompBiasBuffer(n);

        qint8* wt_buf = static_cast<qint8*>(
            const_cast<qint8*>(weight_tensor.flat<qint8>().data()));

        const float* bias_buf = static_cast<float*>(
            const_cast<float*>(bias_tensor.flat<float>().data()));

        float qa_amin = 255 * min_input / (max_input - min_input);

        out_scale = (255.0 * 127.0) /
                    ((max_input - min_input) *
                     std::max(std::abs(max_weight), std::abs(min_weight)));
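
        // Here out_scale = Q'a * Qw and qa_amin = Q'a * Min(Af32), so the loop
        // below (in either the Shard or the OpenMP variant) computes, for each
        // output column j,
        //   comp_bias[j] = Q'a * Qw * Bf32[j] + Q'a * Min(Af32) * sum_i Ws8[i][j],
        // which is the compensated bias B's32 described at the top of this
        // file.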

#ifndef ENABLE_ONEDNN_OPENMP
        auto parallel_func = [&](int64 start, int64 end) {
          for (int64 j = start; j < end; j++) {
            int x = 0;
            for (int64 i = 0; i < k; ++i) {
              x += wt_buf[i * n + j];
            }
            comp_bias[j] =
                ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
          }
        };

        const float kArithCost = 2.5f;
        const float kMovCost = 1.0f;
        float shard_cost = 4 * kArithCost + kMovCost;
        const DeviceBase::CpuWorkerThreads& worker_threads =
            *(context->device()->tensorflow_cpu_worker_threads());
        Shard(worker_threads.num_threads, worker_threads.workers, n, shard_cost,
              parallel_func);
#else
#pragma omp parallel for schedule(static)
        for (int j = 0; j < n; ++j) {
          int x = 0;
          for (int i = 0; i < k; ++i) {
            x += wt_buf[i * n + j];
          }
          comp_bias[j] =
              ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
        }
#endif  // !ENABLE_ONEDNN_OPENMP
        return reinterpret_cast<Tbias*>(comp_bias_);

      } else if (mode_ == QUANTIZE_MODE_SCALED) {
        // If the bias is float and the input quantization mode is SCALE, the
        // bias only has to be scaled: Bs32 = Qa * Qw * Bf32.
        out_scale = 255.0 * 127.0 /
                    (max_input *
                     std::max(std::abs(max_weight), std::abs(min_weight)));

        std::vector<float> scales;
        scales.push_back(out_scale);
        dnnl::primitive_attr bias_attr;
        bias_attr.set_output_scales(0, scales);

        void* bias_buf = static_cast<void*>(
            const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
        input_bias_ = new memory(mkldnn_matmul_fwd_pd->bias_desc(),
                                 this->cpu_engine_, bias_buf);
        scaled_bias_ =
            new memory(mkldnn_matmul_fwd_pd->bias_desc(), this->cpu_engine_);

        auto reorder_desc = dnnl::reorder::primitive_desc(
            *input_bias_, *scaled_bias_, bias_attr);
        net.push_back(dnnl::reorder(reorder_desc));
        std::unordered_map<int, memory> reorder_net_args = {
            {DNNL_ARG_FROM, *input_bias_}, {DNNL_ARG_TO, *scaled_bias_}};
        net.at(0).execute(*reorder_stream, reorder_net_args);

        return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
      } else {
        context->CtxFailure(
            errors::InvalidArgument("Quantization mode must be "
                                    "either MIN_FIRST or SCALED."));
        return nullptr;
      }
    }
  }

 private:
  memory* input_bias_ = nullptr;
  memory* scaled_bias_ = nullptr;

  // Buffer to save the compensated bias
  float* comp_bias_ = nullptr;

  int mode_;
};

template <typename Device, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput, bool native_format = false>
class MklDnnQuantizedMatMulReluOp
    : public MklDnnQuantizedMatMulOp<Device, Tinput, Tweight, Tbias, Toutput,
                                     native_format> {
 public:
  virtual ~MklDnnQuantizedMatMulReluOp() {}

  explicit MklDnnQuantizedMatMulReluOp(OpKernelConstruction* context)
      : MklDnnQuantizedMatMulOp<Device, Tinput, Tweight, Tbias, Toutput,
                                native_format>(context) {}

 protected:
  void ExtendMklDnnMatMulFwdParams(OpKernelContext* context,
                                   MklDnnMatMulFwdParams& params) override {
    MklDnnQuantizedMatMulOp<Device, quint8, qint8, Tbias, Toutput,
                            native_format>::ExtendMklDnnMatMulFwdParams(context,
                                                                        params);
    params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
  }
};

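// The registrations below follow a two-step pattern: the public op names
// (e.g., "QuantizedMatMulWithBias") are registered as NoOp placeholders with
// empty TEMPLATE_ARGS and LABEL, while the corresponding "_Mkl..." op names,
// produced by the MKL graph rewrite, are registered with the real kernel
// classes above, the CPUDevice template arguments, and the
// kMklQuantizedOpLabel label.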
#define REGISTER_MKL_KERNEL(op, kernel, bias_type, output_type, is_native)   \
  REGISTER_KERNEL_BUILDER(                                                   \
      Name(op)                                                               \
          .Device(DEVICE_CPU)                                                \
          .TypeConstraint<quint8>("T1")                                      \
          .TypeConstraint<qint8>("T2") BIAS_TYPE_CONSTRAINT(bias_type)       \
          .TypeConstraint<output_type>("Toutput") LABEL,                     \
      kernel TEMPLATE_ARGS(CPUDevice, quint8, qint8, bias_type, output_type, \
                           is_native));

#define REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(op, kernel, output_type, is_native) \
  REGISTER_MKL_KERNEL(op, kernel, float, output_type, is_native)               \
  REGISTER_MKL_KERNEL(op, kernel, qint32, output_type, is_native);

#define LABEL
#define TEMPLATE_ARGS(CPUDevice, quint8, qint8, bias_type, output_type, \
                      is_native)
#define BIAS_TYPE_CONSTRAINT(bias_type)
REGISTER_MKL_KERNEL("QuantizedMatMulWithBiasAndRelu", NoOp, float, qint32,
                    false);
#undef BIAS_TYPE_CONSTRAINT

#define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias")
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("QuantizedMatMulWithBias", NoOp, qint32,
                                   false);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
    "QuantizedMatMulWithBiasAndReluAndRequantize", NoOp, quint8, false);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("QuantizedMatMulWithBiasAndRequantize", NoOp,
                                   quint8, false);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("QuantizedMatMulWithBiasAndDequantize", NoOp,
                                   float, false);
#undef BIAS_TYPE_CONSTRAINT
#undef TEMPLATE_ARGS
#undef LABEL

#define LABEL .Label(mkl_op_registry::kMklQuantizedOpLabel)
#define TEMPLATE_ARGS(CPUDevice, quint8, qint8, bias_type, output_type, \
                      is_native)                                        \
<CPUDevice, quint8, qint8, bias_type, output_type, is_native>
#define BIAS_TYPE_CONSTRAINT(bias_type)
REGISTER_MKL_KERNEL("_MklQuantizedMatMulWithBiasAndRelu",
                    MklDnnQuantizedMatMulReluOp, float, qint32, true);
#undef BIAS_TYPE_CONSTRAINT

#define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias")
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("_MklQuantizedMatMulWithBias",
                                   MklDnnQuantizedMatMulOp, qint32, true);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
    "_MklQuantizedMatMulWithBiasAndReluAndRequantize",
    MklDnnQuantizedMatMulReluOp, quint8, true);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("_MklQuantizedMatMulWithBiasAndRequantize",
                                   MklDnnQuantizedMatMulOp, quint8, true);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("_MklQuantizedMatMulWithBiasAndDequantize",
                                   MklDnnQuantizedMatMulOp, float, true);
#undef BIAS_TYPE_CONSTRAINT
#undef TEMPLATE_ARGS
#undef LABEL

}  // namespace tensorflow

#endif  // INTEL_MKL