/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements a quantized eight-bit version of the matmul operation with bias,
// relu and requantization fusion support, utilizing the oneDNN u8s8s32 inner
// product API. Right now, this version can support
// - Input: quantized as uint8 via either MIN_FIRST or SCALE mode.
//          SCALE mode is selected when the input is guaranteed to be
//          non-negative, e.g., when MatMul is fed by Relu. Otherwise,
//          MIN_FIRST is selected.
// - Weight: quantized to int8 via SCALE mode.
// - Bias: float32/int32. For int32, it is quantized according to input and
//         filter min-max values.
// Other input combinations are not supported yet.
// When the input is quantized to uint8 via MIN_FIRST, the bias needs
// compensation. The detailed algorithm is illustrated below:
//
// Af32 is the original fp32 activation 2D tensor.
// Min(Af32) is the minimum scalar value of Af32.
// Max(Af32) is the maximum scalar value of Af32.
// Qa is the quantization scale for the activation.
// Au8 is the quantized unsigned int8 activation tensor.
// With SCALE quantization (used for non-negative Af32), Qa and Au8 are
// calculated as below:
//    Qa = 255.0 / Max(Af32)
//    Au8 = round(Qa * Af32),
// With MIN_FIRST quantization, Q'a and A'u8 are calculated as below:
//    Q'a = 255.0 / (Max(Af32) - Min(Af32))
//    A'u8 = round(Q'a * (Af32 - Min(Af32) * ones(Af32))),
// where ones(.) is a tensor of all 1s with the same shape as its argument and
// round(.) rounds a number to its nearest integer.
//
// Wf32 is the original fp32 2D weight tensor.
// MaxAbs(Wf32) is the maximum absolute scalar value of Wf32.
// Qw is the quantization scale of the weight.
// Ws8 is the quantized signed int8 weight tensor.
// Qw and Ws8 are calculated as below:
//    Qw = 127.0 / MaxAbs(Wf32)
//    Ws8 = round(Qw * Wf32).
//
// Bf32 is the original fp32 1D bias tensor matching the innermost dim of
// Wf32.
// With SCALE quantization of the activation, the scaled bias, Bs32, is
// calculated as below:
//    Bs32 = Qa * Qw * Bf32.
// With MIN_FIRST quantization of the activation, the scaled bias tensor with
// compensation, B's32, is calculated as below:
//    B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * Wf32
//          = Q'a * Qw * Bf32 + Q'a * Min(Af32) * 1 * Ws8,
// where 1 denotes a row vector matching the outermost dim of Wf32.
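//
// As a purely illustrative example (the numbers are made up, not taken from
// any model): if Af32 ranges over [-1.0, 3.0] and MaxAbs(Wf32) = 2.0, then
//    Q'a = 255.0 / (3.0 - (-1.0)) = 63.75
//    Qw  = 127.0 / 2.0            = 63.5,
// and the MIN_FIRST compensation adds
//    Q'a * Min(Af32) * (1 * Ws8) = 63.75 * (-1.0) * (column sums of Ws8)
// to the scaled bias Q'a * Qw * Bf32 = 4048.125 * Bf32.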
//
// The QuantizedMatMulWithBias op calculates the 32-bit integer output as
// below:
//  - with SCALE activation quantization:
//    Xs32 = Au8 * Ws8 + 1' * Bs32
//         = Qa * Qw * Af32 * Wf32 + Qa * Qw * 1' * Bf32
//         = Qa * Qw * (Af32 * Wf32 + 1' * Bf32) = Qa * Qw * Xf32,
//    where 1' denotes a column vector matching the outermost dim of Af32 and
//    Xf32 represents the output of the original fp32 MatMul with BiasAdd
//    fusion.
//
//  - with MIN_FIRST activation quantization:
//    Xs32 = A'u8 * Ws8 + 1' * B's32
//         = Q'a * (Af32 - Min(Af32) * ones(Af32)) * Qw * Wf32 +
//           Q'a * Qw * 1' * Bf32 + Q'a * Qw * Min(Af32) * 1' * 1 * Wf32
//         = Q'a * Qw * (Af32 * Wf32 + 1' * Bf32)
//         = Q'a * Qw * Xf32.
//    Note that 1' * 1 = ones(Af32).
//
// The QuantizedMatMulWithBiasAndRelu op does the same calculation as above,
// except that it applies a relu function to the 32-bit integer output.
//
// The QuantizedMatMulWithBiasAndReluAndRequantize op does one more
// requantization step on top of the above. Since the fusion ends with a Relu,
// the activation Xf32 at the Relu in the original fp32 graph is guaranteed to
// be non-negative. The requantize scale Qr is calculated from offline
// calibration:
//    Qr = 255 / Max(Xf32)
//    Xu8 = Qr * Xf32.
//
// More information about this implementation can be found in
// https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training

#ifdef INTEL_MKL

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h"
#include "tensorflow/core/kernels/mkl/mkl_quantized_conv_ops.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_threadpool.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/work_sharder.h"

namespace {
enum {
  QUANTIZE_MODE_MIN_FIRST,
  QUANTIZE_MODE_SCALED,
};
}  // namespace

namespace tensorflow {

template <typename Device, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput, bool native_format = false>
class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
 public:
  virtual ~MklDnnQuantizedMatMulOp() {
    if (this->input_bias_ != nullptr) {
      delete this->input_bias_;
      input_bias_ = nullptr;
    }
    if (this->scaled_bias_ != nullptr) {
      delete this->scaled_bias_;
      scaled_bias_ = nullptr;
    }
    if (this->comp_bias_ != nullptr) {
      // comp_bias_ is allocated with new[], so release it with delete[].
      delete[] this->comp_bias_;
      comp_bias_ = nullptr;
    }
  }

  float* GetCompBiasBuffer(int size) {
    if (comp_bias_ == nullptr) {
      comp_bias_ = new float[size];
    }
    return comp_bias_;
  }

  explicit MklDnnQuantizedMatMulOp(OpKernelConstruction* context)
      : MklDnnMatMulOpBase<Tweight, Toutput>(context) {
    string mode_string;
    OP_REQUIRES_OK(context, context->GetAttr("input_quant_mode", &mode_string));
    if (mode_string == "MIN_FIRST") {
      mode_ = QUANTIZE_MODE_MIN_FIRST;
    } else if (mode_string == "SCALED") {
      mode_ = QUANTIZE_MODE_SCALED;
    } else {
      context->CtxFailure(errors::InvalidArgument(
          "Quantization mode must be either MIN_FIRST or SCALED, but received ",
          mode_string));
    }
    this->is_weight_const_ = false;
    if (context->HasAttr("is_weight_const")) {
      OP_REQUIRES_OK(context, context->GetAttr("is_weight_const",
                                               &(this->is_weight_const_)));
    }
  }

  void Compute(OpKernelContext* context) override {
    try {
      // Input tensors.
      const Tensor& src_tensor = MklGetInput(context, this->kInputIndexSrc);
      const Tensor& weight_tensor =
          MklGetInput(context, this->kInputIndexWeight);
      const Tensor& bias_tensor = MklGetInput(context, this->kInputIndexBias);

      MklDnnShape src_mkl_shape, weight_mkl_shape;
      GetMklShape(context, this->kInputIndexSrc, &src_mkl_shape, native_format);
      GetMklShape(context, this->kInputIndexWeight, &weight_mkl_shape,
                  native_format);
      OP_REQUIRES(context, !weight_mkl_shape.IsMklTensor(),
                  errors::InvalidArgument("Weight should not be in "
                                          "MKL Layout"));

      MklDnnData<Tinput> src(&(this->cpu_engine_));
      MklDnnData<Tweight> weight(&(this->cpu_engine_));

      memory::dims src_dims, weight_dims;
      memory::dims dst_dims_tf_order, dst_dims_mkl_order;

      // Get the shapes of the input tensors in oneDNN order.
      auto src_tf_shape = src_mkl_shape.IsMklTensor()
                              ? src_mkl_shape.GetTfShape()
                              : src_tensor.shape();
      auto weight_tf_shape = weight_mkl_shape.IsMklTensor()
                                 ? weight_mkl_shape.GetTfShape()
                                 : weight_tensor.shape();

      src_dims = TFShapeToMklDnnDims(src_tf_shape);
      weight_dims = TFShapeToMklDnnDims(weight_tf_shape);
      dst_dims_mkl_order = {static_cast<int>(src_tf_shape.dim_size(0)),
                            static_cast<int>(weight_tf_shape.dim_size(1))};

      // Weight dims need to be reversed to create the inner-product forward
      // descriptor.
      weight_dims = {static_cast<int>(weight_tf_shape.dim_size(1)),
                     static_cast<int>(weight_tf_shape.dim_size(0))};

      // Create memory for user data.
      // Describe what the inputs and outputs of the inner-product look like,
      // and specify the buffers containing the actual input and output data.
      Tensor* dst_tensor = nullptr;
      auto input_output_fmt = memory::format_tag::nc;
      auto input_output_fmt_mkldnn = MklTensorFormat::FORMAT_NC;

      // If the input is in MKL layout, simply take that layout; otherwise,
      // construct a TF layout for the input. For the TF layout, although the
      // required input shape (src_dims) is in oneDNN order, the layout is
      // TensorFlow's layout depending on the data format.
      auto src_md =
          src_mkl_shape.IsMklTensor()
              ? src_mkl_shape.GetMklLayout()
              : memory::desc(src_dims, MklDnnType<Tinput>(), input_output_fmt);
      src.SetUsrMem(src_md, &src_tensor);

      // Although the required weight shape (weight_dims) is in oneDNN order,
      // the layout is TensorFlow's layout.
      auto weight_md = weight_mkl_shape.IsMklTensor()
                           ? weight_mkl_shape.GetMklLayout()
                           : memory::desc(weight_dims, MklDnnType<Tweight>(),
                                          memory::format_tag::io);
      weight.SetUsrMem(weight_md, &weight_tensor);

      MklDnnMatMulFwdPrimitive<float, Tinput, Tweight, Tbias, Toutput>*
          matmul_fwd = nullptr;
      memory::dims bias_dims = {static_cast<int>(bias_tensor.dim_size(0))};

      MklDnnMatMulFwdParams matmul_fwd_dims(src_dims, weight_dims, bias_dims,
                                            dst_dims_mkl_order);

      // Extend the basic parameters for data types and fusions.
      this->ExtendMklDnnMatMulFwdParams(context, matmul_fwd_dims);

      // Get a MatMul fwd primitive from the primitive pool.
      matmul_fwd =
          MklDnnMatMulFwdPrimitiveFactory<float, Tinput, Tweight, Tbias,
                                          Toutput>::Get(matmul_fwd_dims, 0);

      // Allocate the output tensor.
      std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
          matmul_fwd_pd = matmul_fwd->GetPrimitiveDesc();
      this->AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order,
                                 input_output_fmt_mkldnn, &dst_tensor,
                                 native_format);

      Toutput* dst_data =
          reinterpret_cast<Toutput*>(dst_tensor->flat<Toutput>().data());

      // Check if src and weight data need to be reordered.
      Tinput* src_data = nullptr;
      if (!native_format && src_md != matmul_fwd_pd->src_desc()) {
        src.SetUsrMem(src_md, &src_tensor);
        src.CheckReorderToOpMem(matmul_fwd_pd.get()->src_desc(),
                                this->cpu_engine_, context);
        src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
      } else {
        src_data = static_cast<Tinput*>(
            const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
      }

      Tweight* weight_data = nullptr;
      if (weight_md != matmul_fwd_pd->weights_desc()) {
        bool is_weight_cached = false;
        // For batch size 1, oneDNN expects the weight format to be OI, whereas
        // the TF default format is IO. In that case, convert the weight from
        // IO to OI for the first iteration and, if the weight is constant,
        // cache it for reuse in subsequent iterations.
        if (this->is_weight_const_) {
          // Check if the weight is already cached.
          if (this->IsWeightCacheEmpty(context)) {
            // Cache the weight if it is not cached yet.
            this->CacheWeight(context, matmul_fwd_pd, weight_data,
                              weight_tensor, weight, weight_md);
          }
          weight_data =
              this->GetCachedWeight(context, matmul_fwd_pd->weights_desc());
          is_weight_cached = (weight_data != nullptr);
        }

        if (!is_weight_cached) {
          weight.SetUsrMem(weight_md, &weight_tensor);
          weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_desc(),
                                     this->cpu_engine_, context);
          weight_data =
              static_cast<Tweight*>(weight.GetOpMem().get_data_handle());
        }
      } else {
        weight_data = static_cast<Tweight*>(
            const_cast<Tweight*>(weight_tensor.flat<Tweight>().data()));
      }

      std::shared_ptr<stream> cpu_stream;
      MklDnnThreadPool eigen_tp(context);
      cpu_stream.reset(CreateStream(&eigen_tp, matmul_fwd->GetEngine()));

      UserScratchPad<unsigned char> scratch_pad;
      scratch_pad.AllocateSPTensor(matmul_fwd, context);

      // Execute the inner-product.
      Tbias* bias_data = this->GetBiasHandle(
          context, matmul_fwd_pd, bias_tensor, weight_tensor, cpu_stream);
      matmul_fwd->Execute(src_data, weight_data, bias_data, dst_data,
                          scratch_pad.Get(), cpu_stream);
    } catch (dnnl::error& e) {
      string error_msg = tensorflow::strings::StrCat(
          "Status: ", e.status, ", message: ", string(e.message), ", in file ",
          __FILE__, ":", __LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception: ", error_msg));
    }
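    // The remaining inputs carry the quantization ranges: inputs 3/4 hold the
    // min/max of the quantized input, inputs 5/6 hold the min/max of the
    // quantized weight, and, for the requantize/dequantize fusions, inputs 7/8
    // hold the frozen (calibrated) output range.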
    float min_output_value;
    float max_output_value;
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value) {
      // This is the case where the inner-product and requantization are fused.
      // "min_freezed_output" and "max_freezed_output" are the requested range
      // for the output.
      min_output_value = context->input(7).flat<float>()(0);
      max_output_value = context->input(8).flat<float>()(0);
    } else {
      ComputeOutputRangeForInt32(context, &min_output_value, &max_output_value);
    }

    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value ||
        std::is_same<Toutput, qint32>::value) {
      Tensor* output_min = nullptr;
      Tensor* output_max = nullptr;
      MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
      output_min_mkl_shape.SetMklTensor(false);
      output_max_mkl_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 1, &output_min, {},
                                output_min_mkl_shape, native_format);
      AllocateOutputSetMklShape(context, 2, &output_max, {},
                                output_max_mkl_shape, native_format);
      output_min->flat<float>()(0) = min_output_value;
      output_max->flat<float>()(0) = max_output_value;
    }
  }

 protected:
  void ComputeOutputRangeForInt32(OpKernelContext* context,
                                  float* min_output_value,
                                  float* max_output_value) {
    const float min_input = context->input(3).flat<float>()(0);
    const float max_input = context->input(4).flat<float>()(0);
    const float min_weight = context->input(5).flat<float>()(0);
    const float max_weight = context->input(6).flat<float>()(0);
    MklQuantizationRangeForMultiplication<quint8, qint8, qint32>(
        min_input, max_input, min_weight, max_weight, min_output_value,
        max_output_value);
  }

  virtual void ExtendMklDnnMatMulFwdParams(OpKernelContext* context,
                                           MklDnnMatMulFwdParams& params) {
    // Append the data type names of input, weight, bias, and output.
    params.dtypes.append(typeid(Tinput).name());
    params.dtypes.append(typeid(Tweight).name());
    params.dtypes.append(typeid(Tbias).name());
    params.dtypes.append(typeid(Toutput).name());

    // When the output type is quint8 or qint8, the output is requantized into
    // that type; when it is float, the output is dequantized. A post_op
    // "output_scale" is added to do the conversion.
    if (std::is_same<Toutput, quint8>::value ||
        std::is_same<Toutput, qint8>::value ||
        std::is_same<Toutput, float>::value) {
      float min_output_value;
      float max_output_value;
      ComputeOutputRangeForInt32(context, &min_output_value, &max_output_value);
      float scale_int32 =
          std::max(std::abs(min_output_value), std::abs(max_output_value));
      const float min_freezed_output = context->input(7).flat<float>()(0);
      const float max_freezed_output = context->input(8).flat<float>()(0);
      float scale_eightbit =
          std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
      float scale = 1.0;
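      // The qint32 accumulator is interpreted as covering the float range
      // [-scale_int32, scale_int32] over 2^31 steps, so scale_int32 / 2^31
      // converts it back to float. Requantizing that float into the frozen
      // eight-bit range [-scale_eightbit, scale_eightbit] multiplies by
      // roughly 2^8 / scale_eightbit (quint8) or 2^7 / scale_eightbit (qint8),
      // which is why the combined scales below divide by 2^23 and 2^24.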
      if (std::is_same<Toutput, quint8>::value) {
        scale = scale_int32 / scale_eightbit / static_cast<float>(1u << 23);
      } else if (std::is_same<Toutput, qint8>::value) {
        scale = scale_int32 / scale_eightbit / static_cast<float>(1u << 24);
      } else if (std::is_same<Toutput, float>::value) {
        scale = scale_int32 / static_cast<float>(1u << 31);
      } else {
        // TODO(intel-tf): Keep the default qint8 as before.
        // Change to an error later.
        scale = scale_int32 / scale_eightbit / static_cast<float>(1u << 24);
      }
      std::vector<float> output_scale;
      output_scale.push_back(scale);
      params.post_op_params.push_back({"output_scale", output_scale});
    }
  }

  // This function handles bias conversion and compensation for MIN_FIRST and
  // SCALE mode. If the input is quantized via MIN_FIRST,
  //   B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * Wf32.
  // If the input is quantized via SCALE,
  //   Bs32 = Qa * Qw * Bf32.
  Tbias* GetBiasHandle(
      OpKernelContext* context,
      std::shared_ptr<dnnl::inner_product_forward::primitive_desc>&
          mkldnn_matmul_fwd_pd,
      const Tensor& bias_tensor, const Tensor& weight_tensor,
      std::shared_ptr<stream> reorder_stream) {
    // If the bias is qint32, it has already been converted offline and can be
    // added to the matmul output directly.
    if (std::is_same<Tbias, qint32>::value) {
      return static_cast<Tbias*>(
          const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
    } else {
      // If the bias is fp32, it needs to be scaled here.
      const float min_input = context->input(3).flat<float>()(0);
      const float max_input = context->input(4).flat<float>()(0);
      const float min_weight = context->input(5).flat<float>()(0);
      const float max_weight = context->input(6).flat<float>()(0);

      std::vector<dnnl::primitive> net;
      float out_scale;
      // If the bias is float and the input is quantized via MIN_FIRST, the
      // bias has to be compensated with
      //   B's32 = Q'a * Qw * Bf32 + Q'a * Qw * Min(Af32) * 1 * Wf32.
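      // In the code below, out_scale corresponds to Q'a * Qw, qa_amin
      // corresponds to Q'a * Min(Af32), and the per-output-channel sums of the
      // int8 weights provide the "1 * Ws8" term of the compensation.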
      if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
        int k = weight_tensor.dim_size(0);
        int n = weight_tensor.dim_size(1);
        float* comp_bias = GetCompBiasBuffer(n);

        qint8* wt_buf = static_cast<qint8*>(
            const_cast<qint8*>(weight_tensor.flat<qint8>().data()));

        const float* bias_buf = static_cast<float*>(
            const_cast<float*>(bias_tensor.flat<float>().data()));

        float qa_amin = 255 * min_input / (max_input - min_input);

        out_scale = (255.0 * 127.0) /
                    ((max_input - min_input) *
                     std::max(std::abs(max_weight), std::abs(min_weight)));

#ifndef ENABLE_ONEDNN_OPENMP
        auto parallel_func = [&](int64 start, int64 end) {
          for (int64 j = start; j < end; j++) {
            int x = 0;
            for (int64 i = 0; i < k; ++i) {
              x += wt_buf[i * n + j];
            }
            comp_bias[j] =
                ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
          }
        };

        const float kArithCost = 2.5f;
        const float kMovCost = 1.0f;
        float shard_cost = 4 * kArithCost + kMovCost;
        const DeviceBase::CpuWorkerThreads& worker_threads =
            *(context->device()->tensorflow_cpu_worker_threads());
        Shard(worker_threads.num_threads, worker_threads.workers, n, shard_cost,
              parallel_func);
#else
#pragma omp parallel for schedule(static)
        for (int j = 0; j < n; ++j) {
          int x = 0;
          for (int i = 0; i < k; ++i) {
            x += wt_buf[i * n + j];
          }
          comp_bias[j] =
              ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
        }
#endif  // !ENABLE_ONEDNN_OPENMP
        return reinterpret_cast<Tbias*>(comp_bias_);

      } else if (mode_ == QUANTIZE_MODE_SCALED) {
        // If the bias is float and the input is quantized via SCALE, the bias
        // has to be scaled with Bs32 = Qa * Qw * Bf32, i.e.,
        // 255/max_input * 127/max_abs_weight.
        out_scale = 255.0 * 127.0 /
                    (max_input *
                     std::max(std::abs(max_weight), std::abs(min_weight)));

        std::vector<float> scales;
        scales.push_back(out_scale);
        dnnl::primitive_attr bias_attr;
        bias_attr.set_output_scales(0, scales);

        void* bias_buf = static_cast<void*>(
            const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
        input_bias_ = new memory(mkldnn_matmul_fwd_pd->bias_desc(),
                                 this->cpu_engine_, bias_buf);
        scaled_bias_ =
            new memory(mkldnn_matmul_fwd_pd->bias_desc(), this->cpu_engine_);

        auto reorder_desc = dnnl::reorder::primitive_desc(
            *input_bias_, *scaled_bias_, bias_attr);
        net.push_back(dnnl::reorder(reorder_desc));
        std::unordered_map<int, memory> reorder_net_args = {
            {DNNL_ARG_FROM, *input_bias_}, {DNNL_ARG_TO, *scaled_bias_}};
        net.at(0).execute(*reorder_stream, reorder_net_args);

        return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
      } else {
        context->CtxFailure(
            errors::InvalidArgument("Quantization mode must be "
                                    "either MIN_FIRST or SCALED."));
        return nullptr;
      }
    }
  }

 private:
  memory* input_bias_ = nullptr;
  memory* scaled_bias_ = nullptr;

  // Buffer to save the compensated bias.
  float* comp_bias_ = nullptr;

  int mode_;
};
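
// MklDnnQuantizedMatMulReluOp extends MklDnnQuantizedMatMulOp by appending a
// "relu" post op to the inner-product parameters, so that the activation is
// fused into the same oneDNN primitive.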
template <typename Device, typename Tinput, typename Tweight, typename Tbias,
          typename Toutput, bool native_format = false>
class MklDnnQuantizedMatMulReluOp
    : public MklDnnQuantizedMatMulOp<Device, Tinput, Tweight, Tbias, Toutput,
                                     native_format> {
 public:
  virtual ~MklDnnQuantizedMatMulReluOp() {}

  explicit MklDnnQuantizedMatMulReluOp(OpKernelConstruction* context)
      : MklDnnQuantizedMatMulOp<Device, Tinput, Tweight, Tbias, Toutput,
                                native_format>(context) {}

 protected:
  void ExtendMklDnnMatMulFwdParams(OpKernelContext* context,
                                   MklDnnMatMulFwdParams& params) override {
    MklDnnQuantizedMatMulOp<Device, quint8, qint8, Tbias, Toutput,
                            native_format>::ExtendMklDnnMatMulFwdParams(context,
                                                                        params);
    params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}});
  }
};

#define REGISTER_MKL_KERNEL(op, kernel, bias_type, output_type, is_native)   \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name(op)                                                                \
          .Device(DEVICE_CPU)                                                 \
          .TypeConstraint<quint8>("T1")                                       \
          .TypeConstraint<qint8>("T2") BIAS_TYPE_CONSTRAINT(bias_type)        \
          .TypeConstraint<output_type>("Toutput") LABEL,                      \
      kernel TEMPLATE_ARGS(CPUDevice, quint8, qint8, bias_type, output_type,  \
                           is_native));

#define REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(op, kernel, output_type, is_native) \
  REGISTER_MKL_KERNEL(op, kernel, float, output_type, is_native)               \
  REGISTER_MKL_KERNEL(op, kernel, qint32, output_type, is_native);

#define LABEL
#define TEMPLATE_ARGS(CPUDevice, quint8, qint8, bias_type, output_type, \
                      is_native)
#define BIAS_TYPE_CONSTRAINT(bias_type)
REGISTER_MKL_KERNEL("QuantizedMatMulWithBiasAndRelu", NoOp, float, qint32,
                    false);
#undef BIAS_TYPE_CONSTRAINT

#define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias")
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("QuantizedMatMulWithBias", NoOp, qint32,
                                   false);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
    "QuantizedMatMulWithBiasAndReluAndRequantize", NoOp, quint8, false);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("QuantizedMatMulWithBiasAndRequantize", NoOp,
                                   quint8, false);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("QuantizedMatMulWithBiasAndDequantize", NoOp,
                                   float, false);
#undef BIAS_TYPE_CONSTRAINT
#undef TEMPLATE_ARGS
#undef LABEL

#define LABEL .Label(mkl_op_registry::kMklQuantizedOpLabel)
#define TEMPLATE_ARGS(CPUDevice, quint8, qint8, bias_type, output_type, \
                      is_native)                                        \
  <CPUDevice, quint8, qint8, bias_type, output_type, is_native>
#define BIAS_TYPE_CONSTRAINT(bias_type)
REGISTER_MKL_KERNEL("_MklQuantizedMatMulWithBiasAndRelu",
                    MklDnnQuantizedMatMulReluOp, float, qint32, true);
#undef BIAS_TYPE_CONSTRAINT

#define BIAS_TYPE_CONSTRAINT(bias_type) .TypeConstraint<bias_type>("Tbias")
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("_MklQuantizedMatMulWithBias",
                                   MklDnnQuantizedMatMulOp, qint32, true);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES(
    "_MklQuantizedMatMulWithBiasAndReluAndRequantize",
    MklDnnQuantizedMatMulReluOp, quint8, true);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("_MklQuantizedMatMulWithBiasAndRequantize",
                                   MklDnnQuantizedMatMulOp, quint8, true);
REGISTER_MKL_KERNEL_ALL_BIAS_TYPES("_MklQuantizedMatMulWithBiasAndDequantize",
                                   MklDnnQuantizedMatMulOp, float, true);
#undef BIAS_TYPE_CONSTRAINT
#undef TEMPLATE_ARGS
#undef LABEL

}  // namespace tensorflow

#endif  // INTEL_MKL