/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include <math.h>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// Requantizes a per-channel-quantized qint32 NHWC tensor down to qint8 or
// quint8 by running a oneDNN reorder with one output scale per channel.
//
// Inputs:  0: input (qint32, 4D NHWC)
//          1: input_min  (float vector, one entry per channel)
//          2: input_max  (float vector, one entry per channel)
//          3: requested_output_min (float scalar)
//          4: requested_output_max (float scalar)
// Outputs: 0: requantized tensor (out_type), 1/2: output min/max scalars.
template <typename Device, typename Toutput>
class MklRequantizePerChannelOp : public OpKernel {
 public:
  explicit MklRequantizePerChannelOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_type_));
    OP_REQUIRES(ctx, out_type_ == DT_QINT8 || out_type_ == DT_QUINT8,
                errors::InvalidArgument(
                    "out_type must be qint8 or quint8, but got: ", out_type_));
  }
  virtual ~MklRequantizePerChannelOp() {}

  void Compute(OpKernelContext* ctx) override {
    try {
      const Tensor& input = ctx->input(kInputTensorIndex);
      // Fixed message: the original string concatenation produced
      // "operatorsupports" (missing space).
      OP_REQUIRES(
          ctx, input.dims() == 4,
          errors::InvalidArgument("Current RequantizePerChannel operator "
                                  "supports 4D tensors only."));

      const Tensor& input_min_vec = ctx->input(kInputMinVecIndex);
      const size_t depth = input_min_vec.NumElements();
      // The scales built below are applied per channel (oneDNN mask 2, i.e.
      // dim 1 of the NCHW-ordered memory descriptor == NHWC dim 3), so the
      // min/max vectors must carry exactly one entry per input channel.
      OP_REQUIRES(
          ctx, input.dim_size(3) == static_cast<int64_t>(depth),
          errors::InvalidArgument("input_min must have one element per "
                                  "channel of input, expected ",
                                  input.dim_size(3), " was ", depth));
      // Read-only access: no const_cast needed.
      const float* input_min_vec_data = input_min_vec.flat<float>().data();

      const Tensor& input_max_vec = ctx->input(kInputMaxVecIndex);
      OP_REQUIRES(
          ctx, static_cast<size_t>(input_max_vec.NumElements()) == depth,
          errors::InvalidArgument("input_max has incorrect size, expected ",
                                  depth, " was ", input_max_vec.NumElements()));
      const float* input_max_vec_data = input_max_vec.flat<float>().data();

      const Tensor& input_requested_min = ctx->input(kRequestMinIndex);
      OP_REQUIRES(
          ctx, input_requested_min.NumElements() == 1,
          errors::InvalidArgument("requested_output_min must be a scalar"));
      const float input_requested_min_float =
          input_requested_min.flat<float>()(0);

      const Tensor& input_requested_max = ctx->input(kRequestMaxIndex);
      // Bug fix: the original re-checked input_requested_min here, so a
      // non-scalar requested_output_max was never rejected.
      OP_REQUIRES(
          ctx, input_requested_max.NumElements() == 1,
          errors::InvalidArgument("requested_output_max must be a scalar"));
      const float input_requested_max_float =
          input_requested_max.flat<float>()(0);

      if (out_type_ == DT_QINT8) {
        // Bug fix: the message named the wrong tensor and condition; the
        // check requires the requested *min* to be negative so that the
        // signed qint8 range is usable.
        OP_REQUIRES(ctx, input_requested_min_float < 0.0f,
                    errors::InvalidArgument(
                        "If out_type is QINT8, requested_output_min must be "
                        "negative, got ",
                        input_requested_min_float));
      }

      const float factor = (out_type_ == DT_QINT8) ? 127.0f : 255.0f;
      const float requested_min_max =
          std::max(std::abs(input_requested_min_float),
                   std::abs(input_requested_max_float));
      // Guard the division below: both requested bounds being zero would
      // yield inf/NaN scales handed to the reorder primitive.
      OP_REQUIRES(ctx, requested_min_max > 0.0f,
                  errors::InvalidArgument(
                      "requested_output_min and requested_output_max must "
                      "not both be zero."));

      Tensor* output = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputTensorIndex,
                                               input.shape(), &output));

      // Per-channel scale: map the qint32 input range (channel range scaled
      // by 2^31) onto the requested output range times the output factor.
      std::vector<float> scales(depth);
      for (size_t i = 0; i < depth; ++i) {
        const float min_max_from_vec = std::max(
            std::abs(input_min_vec_data[i]), std::abs(input_max_vec_data[i]));
        scales[i] = factor * (min_max_from_vec / requested_min_max /
                              static_cast<float>(1L << 31));
      }

      // Mask 2 selects dim 1 (channels, NCHW order) for per-channel scaling.
      dnnl::primitive_attr reorder_attr;
      reorder_attr.set_output_scales(2, scales);

      memory::dims dims_mkl_order =
          TFShapeToMklDnnDimsInNCHW(input.shape(), FORMAT_NHWC);
      memory::desc input_md = memory::desc(dims_mkl_order, MklDnnType<qint32>(),
                                           memory::format_tag::nhwc);
      memory::desc output_md =
          (out_type_ == DT_QINT8)
              ? memory::desc(dims_mkl_order, MklDnnType<qint8>(),
                             memory::format_tag::nhwc)
              : memory::desc(dims_mkl_order, MklDnnType<quint8>(),
                             memory::format_tag::nhwc);

      // oneDNN takes non-const buffer pointers even for inputs it only reads.
      void* input_buf =
          static_cast<void*>(const_cast<qint32*>(input.flat<qint32>().data()));
      void* output_buf;
      if (out_type_ == DT_QINT8) {
        output_buf = static_cast<void*>(
            const_cast<qint8*>(output->flat<qint8>().data()));
      } else {
        output_buf = static_cast<void*>(
            const_cast<quint8*>(output->flat<quint8>().data()));
      }

      std::unique_ptr<memory> input_mem_prim(
          new memory(input_md, cpu_engine_, input_buf));
      std::unique_ptr<memory> output_mem_prim(
          new memory(output_md, cpu_engine_, output_buf));

      dnnl::reorder::primitive_desc reorder_pd =
          ReorderPd(cpu_engine_, input_mem_prim->get_desc(), cpu_engine_,
                    output_mem_prim->get_desc(), reorder_attr);
      // Run the reorder on a stream backed by Eigen's thread pool so oneDNN
      // shares TF's intra-op threads.
      std::shared_ptr<stream> reorder_stream;
      MklDnnThreadPool eigen_tp(ctx);
      reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine_));
      std::unordered_map<int, dnnl::memory> reorder_args = {
          {DNNL_ARG_FROM, *input_mem_prim}, {DNNL_ARG_TO, *output_mem_prim}};
      std::unique_ptr<dnnl::primitive> reorder_prim(
          new dnnl::reorder(reorder_pd));
      reorder_prim->execute(*reorder_stream, reorder_args);

      // The requested range is passed through unchanged as the output range.
      Tensor* output_min = nullptr;
      Tensor* output_max = nullptr;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_output(kOutputMinIndex, {}, &output_min));
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_output(kOutputMaxIndex, {}, &output_max));

      output_min->flat<float>()(0) = input_requested_min_float;
      output_max->flat<float>()(0) = input_requested_max_float;
    } catch (dnnl::error& e) {
      // Surface oneDNN exceptions as a TF status instead of crashing.
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + std::string(e.message) + ", in file " +
                         std::string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          ctx, errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  const int kInputTensorIndex = 0;
  const int kInputMinVecIndex = 1;
  const int kInputMaxVecIndex = 2;
  const int kRequestMinIndex = 3;
  const int kRequestMaxIndex = 4;
  const int kOutputTensorIndex = 0;
  const int kOutputMinIndex = 1;
  const int kOutputMaxIndex = 2;
  DataType out_type_;
  engine cpu_engine_ = engine(engine::kind::cpu, 0);
};

// Registration for out_type: qint8
REGISTER_KERNEL_BUILDER(Name("RequantizePerChannel")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("T")
                            .TypeConstraint<qint8>("out_type"),
                        MklRequantizePerChannelOp<CPUDevice, qint8>);
// Registration for out_type: quint8
REGISTER_KERNEL_BUILDER(Name("RequantizePerChannel")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("T")
                            .TypeConstraint<quint8>("out_type"),
                        MklRequantizePerChannelOp<CPUDevice, quint8>);

}  // namespace tensorflow
#endif  // INTEL_MKL