/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/quantization_utils.h"

namespace tensorflow {

namespace {

// A slow but straightforward implementation of batch normalization.
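// It dequantizes every element to float, applies the textbook formula, and
// requantizes the results into a range discovered on a first pass.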
template <typename T1, typename T2>
void ReferenceBatchNorm(const Tensor& input, const float input_min,
                        const float input_max, const Tensor& mean,
                        float mean_min, float mean_max, const Tensor& var,
                        float var_min, float var_max, const Tensor& beta,
                        float beta_min, float beta_max, const Tensor& gamma,
                        float gamma_min, float gamma_max,
                        float variance_epsilon, bool scale_after_normalization,
                        Tensor* output, float* output_min, float* output_max) {
  auto input_flat = input.flat<T1>();
  auto mean_flat = mean.flat<T1>();
  auto var_flat = var.flat<T1>();
  auto beta_flat = beta.flat<T1>();
  auto gamma_flat = gamma.flat<T1>();
  auto output_flat = output->flat<T2>();

  const int depth = mean.dim_size(0);
  const int row_count = input_flat.size() / depth;

  *output_min = std::numeric_limits<float>::max();
  *output_max = std::numeric_limits<float>::lowest();
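  // Two passes over the data: pass 0 only tracks the float min/max of the
  // results, and pass 1 quantizes the results into that discovered range.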
  for (int pass = 0; pass < 2; ++pass) {
    const bool is_range_pass = (pass == 0);
    for (int row_index = 0; row_index < row_count; ++row_index) {
      for (int channel = 0; channel < depth; ++channel) {
        const int input_index = (row_index * depth) + channel;
        const float input_value =
            QuantizedToFloat(input_flat(input_index), input_min, input_max);
        const float mean_value =
            QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
        const float var_value =
            QuantizedToFloat(var_flat(channel), var_min, var_max);
        const float beta_value =
            QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
        const float gamma_value =
            QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
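        // Standard batch normalization in float:
        //   y = ((x - mean) / sqrt(var + epsilon)) * gamma + beta
        // with the gamma factor skipped when scale_after_normalization is
        // false.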
        float output_value;
        if (scale_after_normalization) {
          output_value = (((input_value - mean_value) /
                           sqrtf(var_value + variance_epsilon)) *
                          gamma_value) +
                         beta_value;
        } else {
          output_value = ((input_value - mean_value) /
                          sqrtf(var_value + variance_epsilon)) +
                         beta_value;
        }
        if (is_range_pass) {
          *output_min = std::min(output_value, *output_min);
          *output_max = std::max(output_value, *output_max);
        } else {
          output_flat(input_index) =
              FloatToQuantized<T2>(output_value, *output_min, *output_max);
        }
      }
    }
  }
}

// An implementation of batch normalization that does the main calculations
// using only fixed-point arithmetic. There's a prologue with some
// floating-point calculations, but assuming the weights are constant these
// could be hoisted to an offline process, or baked into the weights.
template <typename T1, typename T2>
void FixedPointBatchNorm(const Tensor& input, const float input_min,
                         const float input_max, const Tensor& mean,
                         float mean_min, float mean_max, const Tensor& var,
                         float var_min, float var_max, const Tensor& beta,
                         float beta_min, float beta_max, const Tensor& gamma,
                         float gamma_min, float gamma_max,
                         float variance_epsilon, bool scale_after_normalization,
                         Tensor* output, float* output_min, float* output_max) {
  auto input_flat = input.flat<T1>();
  auto mean_flat = mean.flat<T1>();
  auto var_flat = var.flat<T1>();
  auto beta_flat = beta.flat<T1>();
  auto gamma_flat = gamma.flat<T1>();
  auto output_flat = output->flat<T2>();

  const int depth = mean.dim_size(0);
  const int row_count = input_flat.size() / depth;

  // The range here is chosen so that typical input values fit in without any
  // overflow or loss of precision, going from -1m to +1m with 10 bits of
  // fixed-point precision.
  *output_min = -(1 << 20);
  *output_max = (1 << 20);

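  // Fold the per-channel statistics into a single scale and offset so that
  // the inner loop below reduces to y = x * scale + offset:
  //   scale  = gamma / sqrt(var + epsilon)   (or without gamma, depending on
  //            scale_after_normalization)
  //   offset = beta - (mean * scale)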
  Tensor scale_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto scale_flat = scale_tensor.flat<T2>();
  Tensor offset_tensor(DataTypeToEnum<T2>::v(), {depth});
  auto offset_flat = offset_tensor.flat<T2>();
  for (int channel = 0; channel < depth; ++channel) {
    const float mean_value =
        QuantizedToFloat(mean_flat(channel), mean_min, mean_max);
    const float var_value =
        QuantizedToFloat(var_flat(channel), var_min, var_max);
    const float beta_value =
        QuantizedToFloat(beta_flat(channel), beta_min, beta_max);
    const float gamma_value =
        QuantizedToFloat(gamma_flat(channel), gamma_min, gamma_max);
    float scale_value;
    if (scale_after_normalization) {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon)) * gamma_value;
    } else {
      scale_value = (1.0f / sqrtf(var_value + variance_epsilon));
    }
    const float offset_value = (-mean_value * scale_value) + beta_value;
    scale_flat(channel) =
        FloatToQuantized<T2>(scale_value, *output_min, *output_max);
    offset_flat(channel) =
        FloatToQuantized<T2>(offset_value, *output_min, *output_max);
  }

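  // The quantized encoding here is linear and symmetric around zero, so the
  // product of two quantized values carries an extra scale factor; dividing
  // by the quantized representation of 1.0 cancels it. For example, with the
  // +/-(1 << 20) range above and a qint32 output, 1.0f maps to roughly
  // 1 << 11.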
  const T2 one_in_output_space =
      FloatToQuantized<T2>(1.0f, *output_min, *output_max);
  for (int row_index = 0; row_index < row_count; ++row_index) {
    for (int channel = 0; channel < depth; ++channel) {
      const int input_index = (row_index * depth) + channel;
      const T2 input_value =
          RequantizeInNewRange<T1, T2>(input_flat(input_index), input_min,
                                       input_max, *output_min, *output_max);
      const T2 scale_value = scale_flat(channel);
      const T2 offset_value = offset_flat(channel);
      const T2 output_value =
          ((input_value * scale_value) / one_in_output_space) + offset_value;
      output_flat(input_index) = output_value;
    }
  }
}

}  // namespace

template <typename T1, typename T2>
class QuantizedBatchNormOp : public OpKernel {
 public:
  explicit QuantizedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("variance_epsilon", &variance_epsilon_));
    OP_REQUIRES_OK(context, context->GetAttr("scale_after_normalization",
                                             &scale_after_normalization_));
  }

  void Compute(OpKernelContext* context) override {
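    // Each quantized input arrives as a triple: the tensor itself followed by
    // scalar floats giving the min and max of the float range it represents.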
    const Tensor& input = context->input(0);
    const auto& input_min_tensor = context->input(1);
    OP_REQUIRES(context, input_min_tensor.NumElements() == 1,
                errors::InvalidArgument("input_min must have 1 element"));
    const float input_min = input_min_tensor.flat<float>()(0);
    const auto& input_max_tensor = context->input(2);
    OP_REQUIRES(context, input_max_tensor.NumElements() == 1,
                errors::InvalidArgument("input_max must have 1 element"));
    const float input_max = input_max_tensor.flat<float>()(0);
    const Tensor& mean = context->input(3);
    const auto& mean_min_tensor = context->input(4);
    OP_REQUIRES(context, mean_min_tensor.NumElements() == 1,
                errors::InvalidArgument("mean_min must have 1 element"));
    const float mean_min = mean_min_tensor.flat<float>()(0);
    const auto& mean_max_tensor = context->input(5);
    OP_REQUIRES(context, mean_max_tensor.NumElements() == 1,
                errors::InvalidArgument("mean_max must have 1 element"));
    const float mean_max = mean_max_tensor.flat<float>()(0);
    const Tensor& var = context->input(6);
    const auto& var_min_tensor = context->input(7);
    OP_REQUIRES(context, var_min_tensor.NumElements() == 1,
                errors::InvalidArgument("var_min must have 1 element"));
    const float var_min = var_min_tensor.flat<float>()(0);
    const auto& var_max_tensor = context->input(8);
    OP_REQUIRES(context, var_max_tensor.NumElements() == 1,
                errors::InvalidArgument("var_max must have 1 element"));
    const float var_max = var_max_tensor.flat<float>()(0);
    const Tensor& beta = context->input(9);
    const auto& beta_min_tensor = context->input(10);
    OP_REQUIRES(context, beta_min_tensor.NumElements() == 1,
                errors::InvalidArgument("beta_min must have 1 element"));
    const float beta_min = beta_min_tensor.flat<float>()(0);
    const auto& beta_max_tensor = context->input(11);
    OP_REQUIRES(context, beta_max_tensor.NumElements() == 1,
                errors::InvalidArgument("beta_max must have 1 element"));
    const float beta_max = beta_max_tensor.flat<float>()(0);
    const Tensor& gamma = context->input(12);
    const auto& gamma_min_tensor = context->input(13);
    OP_REQUIRES(context, gamma_min_tensor.NumElements() == 1,
                errors::InvalidArgument("gamma_min must have 1 element"));
    const float gamma_min = gamma_min_tensor.flat<float>()(0);
    const auto& gamma_max_tensor = context->input(14);
    OP_REQUIRES(context, gamma_max_tensor.NumElements() == 1,
                errors::InvalidArgument("gamma_max must have 1 element"));
    const float gamma_max = gamma_max_tensor.flat<float>()(0);

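    // Validate shapes: a 4-D input with 1-D per-channel statistics whose
    // length matches the input's last dimension.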
    OP_REQUIRES(context, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));
    OP_REQUIRES(context, mean.dims() == 1,
                errors::InvalidArgument("mean must be 1-dimensional",
                                        mean.shape().DebugString()));
    OP_REQUIRES(context, var.dims() == 1,
                errors::InvalidArgument("var must be 1-dimensional",
                                        var.shape().DebugString()));
    OP_REQUIRES(context, beta.dims() == 1,
                errors::InvalidArgument("beta must be 1-dimensional",
                                        beta.shape().DebugString()));
    OP_REQUIRES(context, gamma.dims() == 1,
                errors::InvalidArgument("gamma must be 1-dimensional",
                                        gamma.shape().DebugString()));
    OP_REQUIRES(context, mean.NumElements() > 1,
                errors::InvalidArgument("Must have at least a mean value",
                                        mean.shape().DebugString()));
    const auto last_dim = input.shape().dims() - 1;
    OP_REQUIRES(context,
                mean.shape().dim_size(0) == input.shape().dim_size(last_dim),
                errors::InvalidArgument("Must provide as many means as the "
                                        "last dimension of the input tensor: ",
                                        mean.shape().DebugString(), " vs. ",
                                        input.shape().DebugString()));
    OP_REQUIRES(
        context, mean.shape().dim_size(0) == var.shape().dim_size(0),
        errors::InvalidArgument(
            "Mean and variance tensors must have the same shape: ",
            mean.shape().DebugString(), " vs. ", var.shape().DebugString()));
    OP_REQUIRES(
        context, mean.shape().dim_size(0) == beta.shape().dim_size(0),
        errors::InvalidArgument(
            "Mean and beta tensors must have the same shape: ",
            mean.shape().DebugString(), " vs. ", beta.shape().DebugString()));
    OP_REQUIRES(
        context, mean.shape().dim_size(0) == gamma.shape().dim_size(0),
        errors::InvalidArgument(
            "Mean and gamma tensors must have the same shape: ",
            mean.shape().DebugString(), " vs. ", gamma.shape().DebugString()));

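    // Run the fixed-point kernel, which also decides the float range the
    // quantized output should be interpreted in.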
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
    float output_min;
    float output_max;
    FixedPointBatchNorm<T1, T2>(input, input_min, input_max, mean, mean_min,
                                mean_max, var, var_min, var_max, beta, beta_min,
                                beta_max, gamma, gamma_min, gamma_max,
                                variance_epsilon_, scale_after_normalization_,
                                output, &output_min, &output_max);

    Tensor* output_min_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, {}, &output_min_tensor));
    output_min_tensor->flat<float>()(0) = output_min;

    Tensor* output_max_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, {}, &output_max_tensor));
    output_max_tensor->flat<float>()(0) = output_max;
  }

 private:
  float variance_epsilon_;
  bool scale_after_normalization_;
};

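// The only registered instantiation takes quint8 inputs and produces qint32
// outputs, preserving the extra precision of the fixed-point arithmetic.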
REGISTER_KERNEL_BUILDER(Name("QuantizedBatchNormWithGlobalNormalization")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<qint32>("out_type"),
                        QuantizedBatchNormOp<quint8, qint32>);

}  // namespace tensorflow