1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // See docs in ../ops/array_ops.cc.
17 #ifdef INTEL_MKL
18 #define EIGEN_USE_THREADS
19 
#include <math.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_threadpool.h"
#include "tensorflow/core/util/mkl_util.h"
34 
35 namespace tensorflow {
36 
37 typedef Eigen::ThreadPoolDevice CPUDevice;
38 
39 class MklRequantizationRangePerChannelOp : public OpKernel {
40  public:
MklRequantizationRangePerChannelOp(OpKernelConstruction * ctx)41   explicit MklRequantizationRangePerChannelOp(OpKernelConstruction* ctx)
42       : OpKernel(ctx) {
43     OP_REQUIRES_OK(ctx, ctx->GetAttr("clip_value_max", &clip_value_max_));
44   }
45 
Compute(OpKernelContext * ctx)46   void Compute(OpKernelContext* ctx) override {
47     const Tensor& input = ctx->input(kInputTensorIndex);
48     const Tensor& input_min = ctx->input(kInputMinIndex);
49     const Tensor& input_max = ctx->input(kInputMaxIndex);
50 
51     const size_t depth = input_max.NumElements();
52     OP_REQUIRES(
53         ctx, input_min.dim_size(0) == depth,
54         errors::InvalidArgument("input_min has incorrect size, expected ",
55                                 depth, " was ", input_min.dim_size(0)));
56     OP_REQUIRES(
57         ctx, input_max.dim_size(0) == depth,
58         errors::InvalidArgument("input_max has incorrect size, expected ",
59                                 depth, " was ", input_max.dim_size(0)));
60     OP_REQUIRES(
61         ctx, input_min.NumElements() == depth,
62         errors::InvalidArgument("input_min must have the same number of "
63                                 "elements as input_max, got ",
64                                 input_min.NumElements(), " and ", depth));
65     OP_REQUIRES(ctx, input.NumElements() > 0,
66                 errors::InvalidArgument("input must not be empty"));
67     OP_REQUIRES(ctx, input.dims() == 4,
68                 errors::InvalidArgument("input must be in NHWC format"));
69     OP_REQUIRES(
70         ctx, input.dim_size(3) == depth,
71         errors::InvalidArgument(
72             "input must have same number of channels as length of input_min: ",
73             input.dim_size(3), " vs ", depth));
74 
75     const float* input_min_data = input_min.flat<float>().data();
76     const float* input_max_data = input_max.flat<float>().data();
77     std::vector<float> ranges(depth);
78     bool is_non_negative = true;
79     Eigen::array<int, 2> shuffling({1, 0});
80     auto input_matrix = input.flat_inner_dims<qint32>();
81 
82     // TODO(intel-tf): Verify performance of not transposing and finding min max
83     // directly from input_matrix vs the one presented below of transposing and
84     // using the transposed matrix as the transposing operation in itself might
85     // be more costly.
86     // Note that this operation is a calibration step for quantization and will
87     // cease to exist in the final inference graph(will exist as a const node).
88     auto transposed_input = input_matrix.shuffle(shuffling);
89 
90     // Find the ranges of each channel in parallel.
91     float out_min_max = std::numeric_limits<float>::min();
92 
93 #ifdef ENABLE_ONEDNN_OPENMP
94 #ifdef _MSC_VER
95 #pragma omp parallel for
96 #else
97 #pragma omp parallel for reduction(max : out_min_max)
98 #endif
99 #endif  // ENABLE_ONEDNN_OPENMP
100     // TODO(intel-tf): Add eigen parallel_for
101     for (int64_t i = 0; i < depth; ++i) {
102       Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
103           transposed_input.chip<0>(i).minimum();
104       Eigen::Tensor<qint32, 0, Eigen::RowMajor> max =
105           transposed_input.chip<0>(i).maximum();
106       const int32_t min_per_channel = min();
107       const int32_t max_per_channel = max();
108       const int32_t abs_max =
109           std::max(std::abs(min_per_channel), std::abs(max_per_channel));
110       float scale =
111           std::max(std::abs(input_min_data[i]), std::abs(input_max_data[i]));
112       ranges[i] =
113           scale * static_cast<float>(abs_max) / static_cast<float>(1L << 31);
114       if (min_per_channel < 0) is_non_negative = false;
115 
116       // Thread-local out_min_max.
117       out_min_max = std::max(out_min_max, ranges[i]);
118     }
119 
120     // All local out_min_max gets max-reduced into one global out_min_max at
121     // the end of the loop by specifying reduction(max:out_min_max) along with
122     // omp parallel for.
123 
124     // Fixing max to clip_value_max_ (example 6.0 to support relu6)
125     if (out_min_max > clip_value_max_) out_min_max = clip_value_max_;
126 
127     Tensor* output_min = nullptr;
128     Tensor* output_max = nullptr;
129     OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputMinIndex, {}, &output_min));
130     OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputMaxIndex, {}, &output_max));
131     output_min->flat<float>()(0) = is_non_negative ? 0.0f : -out_min_max;
132     output_max->flat<float>()(0) = out_min_max;
133   }
134 
135  private:
136   float clip_value_max_ = std::numeric_limits<float>::infinity();
137   const int kInputTensorIndex = 0;
138   const int kInputMinIndex = 1;
139   const int kInputMaxIndex = 2;
140   const int kOutputMinIndex = 0;
141   const int kOutputMaxIndex = 1;
142 };
143 
144 REGISTER_KERNEL_BUILDER(Name("RequantizationRangePerChannel")
145                             .Device(DEVICE_CPU)
146                             .TypeConstraint<qint32>("T"),
147                         MklRequantizationRangePerChannelOp);
148 }  // namespace tensorflow
149 #endif  // INTEL_MKL
150