/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/uniform_quant_ops/math_utils.h"
#include "tensorflow/core/kernels/uniform_quant_ops/tensor_utils.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace {

using tensorflow::errors::InvalidArgument;

// Requantize from per-tensor to per-tensor.
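// Conceptually, each element is remapped between the two affine quantization
// parameter sets:
//   output_val = clamp(round((input_val - input_zero_point) *
//                            (input_scale / output_scale)) + output_zero_point,
//                      quantization_min_val, quantization_max_val)
// As an illustrative example (values chosen here for exposition only): with
// input_scale = 0.5, input_zero_point = 0, output_scale = 0.25, and
// output_zero_point = 10, a quantized input of 6 represents the real value
// 3.0 and maps to round(6 * 2.0) + 10 = 22. Below, the scale ratio is applied
// in integer arithmetic through a quantized multiplier and shift rather than
// in floating point.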
template <typename Tin, typename Tout>
Status PerTensorToPerTensorRequantize(
    const Tensor& input, float input_scale, int32_t input_zero_point,
    float output_scale, int32_t output_zero_point, int32_t quantization_min_val,
    int32_t quantization_max_val, Tensor& output) {
  const double effective_multiplier =
      static_cast<double>(input_scale) / output_scale;
  int32_t effective_quantized_multiplier;
  int32_t effective_shift;
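  // QuantizeMultiplier (from math_utils.h) decomposes the double multiplier
  // into a 32-bit fixed-point multiplier and a power-of-two shift, so that
  // the per-element rescaling below can run in integer-only arithmetic.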
  TF_RETURN_IF_ERROR(QuantizeMultiplier(
      effective_multiplier, effective_quantized_multiplier, effective_shift));

  output.flat<Tout>() = input.flat<Tin>().unaryExpr(
      [effective_quantized_multiplier, effective_shift, input_zero_point,
       output_zero_point, quantization_min_val,
       quantization_max_val](Tin input_val) {
        return AffineRequantizeWithQuantizedMultiplierAndShift<Tin, Tout>(
            input_val, effective_quantized_multiplier, effective_shift,
            input_zero_point, output_zero_point, quantization_min_val,
            quantization_max_val);
      });
  return OkStatus();
}

// Requantize for the cases where the input and/or the output is per-axis
// quantized:
// - From per-tensor to per-axis.
// - From per-axis to per-tensor.
// - From per-axis to per-axis.
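// As an illustrative example (shapes chosen for exposition only): for an
// input of shape [2, 3, 4] quantized along axis 1, the per-axis scales and
// zero points have shape [3], and parameter set i applies to the slice
// input[:, i, :]. A per-tensor side supplies scalar quantization parameters,
// which the `? i : 0` indexing below broadcasts across the axis.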
template <typename Tin, typename Tout>
Status PerAxisRequantize(OpKernelContext* context, const Tensor& input,
                         const Tensor& input_scales,
                         const Tensor& input_zero_points,
                         const Tensor& output_scales,
                         const Tensor& output_zero_points,
                         int quantization_axis, int32_t quantization_min_val,
                         int32_t quantization_max_val, Tensor& output) {
  const bool input_per_axis_quantization = input_scales.dims() == 1;
  const bool output_per_axis_quantization = output_scales.dims() == 1;
  const auto& per_axis_scales_shape = input_per_axis_quantization
                                          ? input_scales.shape()
                                          : output_scales.shape();

  Tensor effective_quantized_multipliers;
  TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, per_axis_scales_shape,
                                            &effective_quantized_multipliers));
  Tensor effective_shifts;
  TF_RETURN_IF_ERROR(context->allocate_temp(DT_INT32, per_axis_scales_shape,
                                            &effective_shifts));

  const float* input_scales_data = input_scales.flat<float>().data();
  const float* output_scales_data = output_scales.flat<float>().data();
  int32_t* effective_quantized_multipliers_data =
      effective_quantized_multipliers.flat<int32_t>().data();
  int32_t* effective_shifts_data = effective_shifts.flat<int32_t>().data();

  const int64_t quantization_dim_size = output.dim_size(quantization_axis);

  for (int64_t i = 0; i < quantization_dim_size; ++i) {
    const double effective_multiplier =
        static_cast<double>(
            input_scales_data[input_per_axis_quantization ? i : 0]) /
        output_scales_data[output_per_axis_quantization ? i : 0];
    TF_RETURN_IF_ERROR(QuantizeMultiplier(
        effective_multiplier, effective_quantized_multipliers_data[i],
        effective_shifts_data[i]));
  }

  const int32* input_zero_points_data = input_zero_points.flat<int32>().data();
  const int32* output_zero_points_data =
      output_zero_points.flat<int32>().data();

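  // flat_inner_outer_dims<T, 3>(quantization_axis - 1) exposes the tensor as
  // a rank-3 view of (collapsed outer dims, quantization axis, collapsed
  // inner dims), so chip<1>(i) below selects the i-th slice along the
  // quantization axis.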
  auto input_tensor =
      input.template flat_inner_outer_dims<Tin, 3>(quantization_axis - 1);
  auto output_tensor =
      output.template flat_inner_outer_dims<Tout, 3>(quantization_axis - 1);

  for (int i = 0; i < quantization_dim_size; ++i) {
    output_tensor.template chip<1>(i) =
        input_tensor.template chip<1>(i).unaryExpr(
            [effective_quantized_multipliers_data, effective_shifts_data,
             input_zero_points_data, output_zero_points_data,
             quantization_min_val, quantization_max_val,
             input_per_axis_quantization, output_per_axis_quantization,
             i](Tin input_val) {
              return AffineRequantizeWithQuantizedMultiplierAndShift<Tin, Tout>(
                  input_val, effective_quantized_multipliers_data[i],
                  effective_shifts_data[i],
                  input_zero_points_data[input_per_axis_quantization ? i : 0],
                  output_zero_points_data[output_per_axis_quantization ? i : 0],
                  quantization_min_val, quantization_max_val);
            });
  }
  return OkStatus();
}

template <typename Tin, typename Tout>
Status EvalRequantize(OpKernelContext* context, const Tensor& input,
                      const Tensor& input_scales,
                      const Tensor& input_zero_points,
                      const Tensor& output_scales,
                      const Tensor& output_zero_points,
                      int input_quantization_axis, int output_quantization_axis,
                      int32_t quantization_min_val,
                      int32_t quantization_max_val, Tensor& output) {
  if (input_quantization_axis == -1 && output_quantization_axis == -1) {
    return PerTensorToPerTensorRequantize<Tin, Tout>(
        input, input_scales.scalar<float>()(),
        input_zero_points.scalar<int32>()(), output_scales.scalar<float>()(),
        output_zero_points.scalar<int32>()(), quantization_min_val,
        quantization_max_val, output);
  } else {
    const int quantization_axis = input_quantization_axis >= 0
                                      ? input_quantization_axis
                                      : output_quantization_axis;
    return PerAxisRequantize<Tin, Tout>(
        context, input, input_scales, input_zero_points, output_scales,
        output_zero_points, quantization_axis, quantization_min_val,
        quantization_max_val, output);
  }
}

}  // namespace

// Changing input_quantization_min/max_val is a no-op for this kernel.
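// Only the output_quantization_min/max_val attrs participate in clamping;
// the corresponding input attrs are accepted by the op interface but are
// never read here.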
template <typename Tin, typename Tout>
class UniformRequantizeOp : public OpKernel {
 public:
  explicit UniformRequantizeOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES(context,
                (std::is_same<Tin, qint32>() || std::is_same<Tin, qint8>()),
                InvalidArgument("Unsupported input type."));
    OP_REQUIRES(context, (std::is_same<Tout, qint8>()),
                InvalidArgument("Unsupported output type."));

    OP_REQUIRES_OK(context, context->GetAttr("output_quantization_min_val",
                                             &output_quantization_min_val_));
    OP_REQUIRES_OK(context, context->GetAttr("output_quantization_max_val",
                                             &output_quantization_max_val_));

    OP_REQUIRES_OK(context, context->GetAttr("input_quantization_axis",
                                             &input_quantization_axis_));
    OP_REQUIRES_OK(context, context->GetAttr("output_quantization_axis",
                                             &output_quantization_axis_));
    OP_REQUIRES(
        context, (input_quantization_axis_ >= -1),
        InvalidArgument("input_quantization_axis must be >= -1, given: ",
                        input_quantization_axis_));
    OP_REQUIRES(
        context, (output_quantization_axis_ >= -1),
        InvalidArgument("output_quantization_axis must be >= -1, given: ",
                        output_quantization_axis_));
    OP_REQUIRES(
        context,
        (!(input_quantization_axis_ >= 0 && output_quantization_axis_ >= 0) ||
         input_quantization_axis_ == output_quantization_axis_),
        InvalidArgument("If both input and output are per-axis quantized, the "
                        "quantization axis must be the same."));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& input_scales = context->input(1);
    const Tensor& input_zero_points = context->input(2);
    const Tensor& output_scales = context->input(3);
    const Tensor& output_zero_points = context->input(4);

    OP_REQUIRES_OK(context,
                   (QuantizationAxisAndShapeValid(
                       input.shape(), input_scales.shape(),
                       input_zero_points.shape(), input_quantization_axis_)));
    OP_REQUIRES_OK(context,
                   (QuantizationAxisAndShapeValid(
                       input.shape(), output_scales.shape(),
                       output_zero_points.shape(), output_quantization_axis_)));

    OP_REQUIRES(
        context,
        (AllElementsPositive<float>(input_scales) &&
         AllElementsPositive<float>(output_scales)),
        InvalidArgument("input/output scales elements must all be positive."));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));

    OP_REQUIRES_OK(
        context,
        EvalRequantize<Tin, Tout>(
            context, input, input_scales, input_zero_points, output_scales,
            output_zero_points, input_quantization_axis_,
            output_quantization_axis_, output_quantization_min_val_,
            output_quantization_max_val_, *output));
  }

 private:
  int input_quantization_axis_;
  int32_t input_quantization_min_val_;
  int32_t input_quantization_max_val_;
  int output_quantization_axis_;
  int32_t output_quantization_min_val_;
  int32_t output_quantization_max_val_;
};

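// Note that only kernels producing qint8 outputs are registered: qint8 ->
// qint8 and qint32 -> qint8, matching the input/output type checks in the
// constructor.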
REGISTER_KERNEL_BUILDER(Name("UniformRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("Tin")
                            .TypeConstraint<qint8>("Tout"),
                        UniformRequantizeOp<qint8, qint8>);

REGISTER_KERNEL_BUILDER(Name("UniformRequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("Tin")
                            .TypeConstraint<qint8>("Tout"),
                        UniformRequantizeOp<qint32, qint8>);

}  // namespace tensorflow