xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_requantize_per_channel_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/array_ops.cc.
17 
18 #ifdef INTEL_MKL
19 #define EIGEN_USE_THREADS
20 
21 #include <math.h>
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "dnnl.hpp"
25 #include "tensorflow/core/framework/op.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 #include "tensorflow/core/framework/type_traits.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/kernels/meta_support.h"
30 #include "tensorflow/core/kernels/no_op.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/util/mkl_util.h"
33 
34 namespace tensorflow {
35 
36 typedef Eigen::ThreadPoolDevice CPUDevice;
37 
38 template <typename Device, typename Toutput>
39 class MklRequantizePerChannelOp : public OpKernel {
40  public:
MklRequantizePerChannelOp(OpKernelConstruction * ctx)41   explicit MklRequantizePerChannelOp(OpKernelConstruction* ctx)
42       : OpKernel(ctx) {
43     OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_type_));
44     OP_REQUIRES(ctx, out_type_ == DT_QINT8 || out_type_ == DT_QUINT8,
45                 errors::InvalidArgument(
46                     "out_type must be qint8 or quint8, but got: ", out_type_));
47   }
~MklRequantizePerChannelOp()48   virtual ~MklRequantizePerChannelOp() {}
Compute(OpKernelContext * ctx)49   void Compute(OpKernelContext* ctx) override {
50     try {
51       const Tensor& input = ctx->input(kInputTensorIndex);
52       OP_REQUIRES(
53           ctx, input.dims() == 4,
54           errors::InvalidArgument("Current RequantizePerChannel operator"
55                                   "supports 4D tensors only."));
56 
57       const Tensor& input_min_vec = ctx->input(kInputMinVecIndex);
58       size_t depth = input_min_vec.NumElements();
59       float* input_min_vec_data = (float*)const_cast<void*>(
60           static_cast<const void*>(input_min_vec.flat<float>().data()));
61 
62       const Tensor& input_max_vec = ctx->input(kInputMaxVecIndex);
63       OP_REQUIRES(
64           ctx, input_max_vec.NumElements() == depth,
65           errors::InvalidArgument("input_max has incorrect size, expected ",
66                                   depth, " was ", input_max_vec.NumElements()));
67       float* input_max_vec_data = (float*)const_cast<void*>(
68           static_cast<const void*>(input_max_vec.flat<float>().data()));
69 
70       const Tensor& input_requested_min = ctx->input(this->kRequestMinIndex);
71       OP_REQUIRES(
72           ctx, input_requested_min.NumElements() == 1,
73           errors::InvalidArgument("requested_output_min must be a scalar"));
74       const float input_requested_min_float =
75           input_requested_min.flat<float>()(0);
76 
77       const Tensor& input_requested_max = ctx->input(this->kRequestMaxIndex);
78       OP_REQUIRES(
79           ctx, input_requested_min.NumElements() == 1,
80           errors::InvalidArgument("requested_output_max must be a scalar"));
81       const float input_requested_max_float =
82           input_requested_max.flat<float>()(0);
83 
84       if (out_type_ == DT_QINT8) {
85         OP_REQUIRES(ctx, input_requested_min_float < 0.0f,
86                     errors::InvalidArgument(
87                         "If out_type is QINT8, requested_output_max must be "
88                         "non negative, got ",
89                         input_requested_min_float));
90       }
91 
92       const float factor = (out_type_ == DT_QINT8) ? 127.0f : 255.0f;
93       const float requested_min_max =
94           std::max(std::abs(input_requested_min_float),
95                    std::abs(input_requested_max_float));
96       Tensor* output = nullptr;
97       OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputTensorIndex,
98                                                input.shape(), &output));
99 
100       std::vector<float> scales(depth);
101       for (int i = 0; i < depth; ++i) {
102         float min_max_from_vec = std::max(std::abs(input_min_vec_data[i]),
103                                           std::abs(input_max_vec_data[i]));
104         scales[i] = factor * (min_max_from_vec / requested_min_max /
105                               static_cast<float>(1L << 31));
106       }
107 
108       dnnl::primitive_attr reorder_attr;
109       reorder_attr.set_output_scales(2, scales);
110 
111       memory::dims dims_mkl_order =
112           TFShapeToMklDnnDimsInNCHW(input.shape(), FORMAT_NHWC);
113       memory::desc input_md = memory::desc(dims_mkl_order, MklDnnType<qint32>(),
114                                            memory::format_tag::nhwc);
115       memory::desc output_md =
116           (out_type_ == DT_QINT8)
117               ? memory::desc(dims_mkl_order, MklDnnType<qint8>(),
118                              memory::format_tag::nhwc)
119               : memory::desc(dims_mkl_order, MklDnnType<quint8>(),
120                              memory::format_tag::nhwc);
121 
122       void* input_buf =
123           static_cast<void*>(const_cast<qint32*>(input.flat<qint32>().data()));
124       void* output_buf;
125       if (out_type_ == DT_QINT8) {
126         output_buf = static_cast<void*>(
127             const_cast<qint8*>(output->flat<qint8>().data()));
128       } else {
129         output_buf = static_cast<void*>(
130             const_cast<quint8*>(output->flat<quint8>().data()));
131       }
132 
133       std::unique_ptr<memory> input_mem_prim(
134           new memory(input_md, cpu_engine_, input_buf));
135       std::unique_ptr<memory> output_mem_prim(
136           new memory(output_md, cpu_engine_, output_buf));
137 
138       dnnl::reorder::primitive_desc reorder_pd =
139           ReorderPd(cpu_engine_, input_mem_prim->get_desc(), cpu_engine_,
140                     output_mem_prim->get_desc(), reorder_attr);
141       std::shared_ptr<stream> reorder_stream;
142       MklDnnThreadPool eigen_tp(ctx);
143       reorder_stream.reset(CreateStream(&eigen_tp, cpu_engine_));
144       std::unordered_map<int, dnnl::memory> reorder_args = {
145           {DNNL_ARG_FROM, *input_mem_prim}, {DNNL_ARG_TO, *output_mem_prim}};
146       std::unique_ptr<dnnl::primitive> reorder_prim(
147           new dnnl::reorder(reorder_pd));
148       reorder_prim->execute(*reorder_stream, reorder_args);
149 
150       Tensor* output_min = nullptr;
151       Tensor* output_max = nullptr;
152       OP_REQUIRES_OK(ctx,
153                      ctx->allocate_output(kOutputMinIndex, {}, &output_min));
154       OP_REQUIRES_OK(ctx,
155                      ctx->allocate_output(kOutputMaxIndex, {}, &output_max));
156 
157       output_min->flat<float>()(0) = input_requested_min_float;
158       output_max->flat<float>()(0) = input_requested_max_float;
159     } catch (dnnl::error& e) {
160       string error_msg = "Status: " + std::to_string(e.status) +
161                          ", message: " + std::string(e.message) + ", in file " +
162                          std::string(__FILE__) + ":" + std::to_string(__LINE__);
163       OP_REQUIRES_OK(
164           ctx, errors::Aborted("Operation received an exception:", error_msg));
165     }
166   }
167 
168  private:
169   const int kInputTensorIndex = 0;
170   const int kInputMinVecIndex = 1;
171   const int kInputMaxVecIndex = 2;
172   const int kRequestMinIndex = 3;
173   const int kRequestMaxIndex = 4;
174   const int kOutputTensorIndex = 0;
175   const int kOutputMinIndex = 1;
176   const int kOutputMaxIndex = 2;
177   DataType out_type_;
178   engine cpu_engine_ = engine(engine::kind::cpu, 0);
179 };
180 
181 // Registration for out_type: qint8
182 REGISTER_KERNEL_BUILDER(Name("RequantizePerChannel")
183                             .Device(DEVICE_CPU)
184                             .TypeConstraint<qint32>("T")
185                             .TypeConstraint<qint8>("out_type"),
186                         MklRequantizePerChannelOp<CPUDevice, qint8>);
187 // Registration for out_type: quint8
188 REGISTER_KERNEL_BUILDER(Name("RequantizePerChannel")
189                             .Device(DEVICE_CPU)
190                             .TypeConstraint<qint32>("T")
191                             .TypeConstraint<quint8>("out_type"),
192                         MklRequantizePerChannelOp<CPUDevice, quint8>);
193 
194 }  // namespace tensorflow
195 #endif  // INTEL_MKL
196