// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/uniform_quant_ops/uniform_requantize_op.cc
// (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/uniform_quant_ops/math_utils.h"
#include "tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace {

using tensorflow::errors::InvalidArgument;
25 // Requantize from per-tensor to per-tensor.
26 template <typename Tin, typename Tout>
PerTensorToPerTensorRequantize(const Tensor & input,float input_scale,int32_t input_zero_point,float output_scale,int32_t output_zero_point,int32_t quantization_min_val,int32_t quantization_max_val,Tensor & output)27 Status PerTensorToPerTensorRequantize(
28     const Tensor& input, float input_scale, int32_t input_zero_point,
29     float output_scale, int32_t output_zero_point, int32_t quantization_min_val,
30     int32_t quantization_max_val, Tensor& output) {
31   const double effective_multiplier =
32       static_cast<double>(input_scale) / output_scale;
33   int32_t effective_quantized_multiplier;
34   int32_t effective_shift;
35   TF_RETURN_IF_ERROR(QuantizeMultiplier(
36       effective_multiplier, effective_quantized_multiplier, effective_shift));
37 
38   output.flat<Tout>() = input.flat<Tin>().unaryExpr(
39       [effective_quantized_multiplier, effective_shift, input_zero_point,
40        output_zero_point, quantization_min_val,
41        quantization_max_val](Tin input_val) {
42         return AffineRequantizeWithQuantizedMultiplierAndShift<Tin, Tout>(
43             input_val, effective_quantized_multiplier, effective_shift,
44             input_zero_point, output_zero_point, quantization_min_val,
45             quantization_max_val);
46       });
47   return OkStatus();
48 }
50 // Requantize where the input or output contains any per-axis quantized cases.
51 // - From per-tensor to per-axis.
52 // - From per-axis to per-tensor.
53 // - From per-axis to per-axis.
54 template <typename Tin, typename Tout>
PerAxisRequantize(OpKernelContext * context,const Tensor & input,const Tensor & input_scales,const Tensor & input_zero_points,const Tensor & output_scales,const Tensor & output_zero_points,int quantization_axis,int32_t quantization_min_val,int32_t quantization_max_val,Tensor & output)55 Status PerAxisRequantize(OpKernelContext* context, const Tensor& input,
56                          const Tensor& input_scales,
57                          const Tensor& input_zero_points,
58                          const Tensor& output_scales,
59                          const Tensor& output_zero_points,
60                          int quantization_axis, int32_t quantization_min_val,
61                          int32_t quantization_max_val, Tensor& output) {
62   const bool input_per_axis_quantization = input_scales.dims() == 1;
63   const bool output_per_axis_quantization = output_scales.dims() == 1;
64   const auto& per_axis_scales_shape = input_per_axis_quantization
65                                           ? input_scales.shape()
66                                           : output_scales.shape();
67 
68   Tensor effective_quantized_multipliers;
69   TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, per_axis_scales_shape,
70                                             &effective_quantized_multipliers));
71   Tensor effective_shifts;
72   TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, per_axis_scales_shape,
73                                             &effective_shifts));
74 
75   const float* input_scales_data = input_scales.flat<float>().data();
76   const float* output_scales_data = output_scales.flat<float>().data();
77   int32_t* effective_quantized_multipliers_data =
78       effective_quantized_multipliers.flat<int32_t>().data();
79   int32_t* effective_shifts_data = effective_shifts.flat<int32_t>().data();
80 
81   const int64_t quantization_dim_size = output.dim_size(quantization_axis);
82 
83   for (int64_t i = 0; i < quantization_dim_size; ++i) {
84     const double effective_multiplier =
85         static_cast<double>(
86             input_scales_data[input_per_axis_quantization ? i : 0]) /
87         output_scales_data[output_per_axis_quantization ? i : 0];
88     TF_RETURN_IF_ERROR(QuantizeMultiplier(
89         effective_multiplier, effective_quantized_multipliers_data[i],
90         effective_shifts_data[i]));
91   }
92 
93   const int32* input_zero_points_data = input_zero_points.flat<int32>().data();
94   const int32* output_zero_points_data =
95       output_zero_points.flat<int32>().data();
96 
97   auto input_tensor =
98       input.template flat_inner_outer_dims<Tin, 3>(quantization_axis - 1);
99   auto output_tensor =
100       output.template flat_inner_outer_dims<Tout, 3>(quantization_axis - 1);
101 
102   for (int i = 0; i < quantization_dim_size; ++i) {
103     output_tensor.template chip<1>(i) =
104         input_tensor.template chip<1>(i).unaryExpr(
105             [effective_quantized_multipliers_data, effective_shifts_data,
106              input_zero_points_data, output_zero_points_data,
107              quantization_min_val, quantization_max_val,
108              input_per_axis_quantization, output_per_axis_quantization,
109              i](Tin input_val) {
110               return AffineRequantizeWithQuantizedMultiplierAndShift<Tin, Tout>(
111                   input_val, effective_quantized_multipliers_data[i],
112                   effective_shifts_data[i],
113                   input_zero_points_data[input_per_axis_quantization ? i : 0],
114                   output_zero_points_data[output_per_axis_quantization ? i : 0],
115                   quantization_min_val, quantization_max_val);
116             });
117   }
118   return OkStatus();
119 }
121 template <typename Tin, typename Tout>
EvalRequantize(OpKernelContext * context,const Tensor & input,const Tensor & input_scales,const Tensor & input_zero_points,const Tensor & output_scales,const Tensor & output_zero_points,int input_quantization_axis,int output_quantization_axis,int32_t quantization_min_val,int32_t quantization_max_val,Tensor & output)122 Status EvalRequantize(OpKernelContext* context, const Tensor& input,
123                       const Tensor& input_scales,
124                       const Tensor& input_zero_points,
125                       const Tensor& output_scales,
126                       const Tensor& output_zero_points,
127                       int input_quantization_axis, int output_quantization_axis,
128                       int32_t quantization_min_val,
129                       int32_t quantization_max_val, Tensor& output) {
130   if (input_quantization_axis == -1 && output_quantization_axis == -1) {
131     return PerTensorToPerTensorRequantize<Tin, Tout>(
132         input, input_scales.scalar<float>()(),
133         input_zero_points.scalar<int32>()(), output_scales.scalar<float>()(),
134         output_zero_points.scalar<int32>()(), quantization_min_val,
135         quantization_max_val, output);
136   } else {
137     const int quantization_axis = input_quantization_axis >= 0
138                                       ? input_quantization_axis
139                                       : output_quantization_axis;
140     return PerAxisRequantize<Tin, Tout>(
141         context, input, input_scales, input_zero_points, output_scales,
142         output_zero_points, quantization_axis, quantization_min_val,
143         quantization_max_val, output);
144   }
145 }

}  // namespace

149 // Changing input_quantization_min/max_val is no-op for this kernel.
150 template <typename Tin, typename Tout>
151 class UniformRequantizeOp : public OpKernel {
152  public:
UniformRequantizeOp(OpKernelConstruction * context)153   explicit UniformRequantizeOp(OpKernelConstruction* context)
154       : OpKernel(context) {
155     OP_REQUIRES(context,
156                 (std::is_same<Tin, qint32>() || std::is_same<Tin, qint8>()),
157                 InvalidArgument("Unsupported input type."));
158     OP_REQUIRES(context, (std::is_same<Tout, qint8>()),
159                 InvalidArgument("Unsupported output type."));
160 
161     OP_REQUIRES_OK(context, context->GetAttr("output_quantization_min_val",
162                                              &output_quantization_min_val_));
163     OP_REQUIRES_OK(context, context->GetAttr("output_quantization_max_val",
164                                              &output_quantization_max_val_));
165 
166     OP_REQUIRES_OK(context, context->GetAttr("input_quantization_axis",
167                                              &input_quantization_axis_));
168     OP_REQUIRES_OK(context, context->GetAttr("output_quantization_axis",
169                                              &output_quantization_axis_));
170     OP_REQUIRES(
171         context, (input_quantization_axis_ >= -1),
172         InvalidArgument("input_quantization_axis must be >= -1, given: ",
173                         input_quantization_axis_));
174     OP_REQUIRES(
175         context, (output_quantization_axis_ >= -1),
176         InvalidArgument("output_quantization_axis must be >= -1, given: ",
177                         output_quantization_axis_));
178     OP_REQUIRES(
179         context,
180         (!(input_quantization_axis_ >= 0 && output_quantization_axis_ >= 0) ||
181          input_quantization_axis_ == output_quantization_axis_),
182         InvalidArgument("If input and output is both per-axis quantized, the "
183                         "quantization axis must be same."));
184   }
185 
Compute(OpKernelContext * context)186   void Compute(OpKernelContext* context) override {
187     const Tensor& input = context->input(0);
188     const Tensor& input_scales = context->input(1);
189     const Tensor& input_zero_points = context->input(2);
190     const Tensor& output_scales = context->input(3);
191     const Tensor& output_zero_points = context->input(4);
192 
193     OP_REQUIRES_OK(context,
194                    (QuantizationAxisAndShapeValid(
195                        input.shape(), input_scales.shape(),
196                        input_zero_points.shape(), input_quantization_axis_)));
197     OP_REQUIRES_OK(context,
198                    (QuantizationAxisAndShapeValid(
199                        input.shape(), output_scales.shape(),
200                        output_zero_points.shape(), output_quantization_axis_)));
201 
202     OP_REQUIRES(
203         context,
204         (AllElementsPositive<float>(input_scales) &&
205          AllElementsPositive<float>(output_scales)),
206         InvalidArgument("input/output scales elements must be all positive."));
207 
208     Tensor* output = nullptr;
209     OP_REQUIRES_OK(context,
210                    context->allocate_output(0, input.shape(), &output));
211 
212     OP_REQUIRES_OK(
213         context,
214         EvalRequantize<Tin, Tout>(
215             context, input, input_scales, input_zero_points, output_scales,
216             output_zero_points, input_quantization_axis_,
217             output_quantization_axis_, output_quantization_min_val_,
218             output_quantization_max_val_, *output));
219   }
220 
221  private:
222   int input_quantization_axis_;
223   int32_t input_quantization_min_val_;
224   int32_t input_quantization_max_val_;
225   int output_quantization_axis_;
226   int32_t output_quantization_min_val_;
227   int32_t output_quantization_max_val_;
228 };
230 REGISTER_KERNEL_BUILDER(Name("UniformRequantize")
231                             .Device(DEVICE_CPU)
232                             .TypeConstraint<qint8>("Tin")
233                             .TypeConstraint<qint8>("Tout"),
234                         UniformRequantizeOp<qint8, qint8>);
235 
236 REGISTER_KERNEL_BUILDER(Name("UniformRequantize")
237                             .Device(DEVICE_CPU)
238                             .TypeConstraint<qint32>("Tin")
239                             .TypeConstraint<qint8>("Tout"),
240                         UniformRequantizeOp<qint32, qint8>);

}  // namespace tensorflow