/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace div {

// This file has three implementations of Div.
enum KernelType {
  kReference,
  kGenericOptimized,  // Neon-free
  kNeonOptimized,
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

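// Per-node state, computed once in Prepare and reused on every Eval call.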
struct OpData {
  bool requires_broadcast;

  // Parameters used in the quantized paths where the output is 8-bit.
  int32_t output_activation_min;
  int32_t output_activation_max;

  // Parameters used in all quantized paths.
  int32_t output_multiplier;
  int output_shift;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  data->requires_broadcast = false;
  return data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
  output->type = input2->type;

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  if (output->type == kTfLiteUInt8) {
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
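    // Rescale factor for quantized division: with real_x = scale_x *
    // (quantized_x - zero_point_x), real_out = real_in1 / real_in2 implies an
    // effective multiplier of input1_scale / (input2_scale * output_scale).
    // QuantizeMultiplier decomposes it into a fixed-point int32 multiplier and
    // a power-of-two shift so Eval can use integer-only arithmetic.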
    const double real_multiplier =
        input1->params.scale / (input2->params.scale * output->params.scale);
    QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                       &data->output_shift);
  }

  return context->ResizeTensor(context, output, output_size);
}

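// Handles the float32 and int32 paths; the quantized uint8 path lives in
// EvalQuantized below.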
template <KernelType kernel_type>
void EvalDiv(TfLiteContext* context, TfLiteNode* node, TfLiteDivParams* params,
             const OpData* data, const TfLiteTensor* input1,
             const TfLiteTensor* input2, TfLiteTensor* output) {
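// TF_LITE_DIV expands to one concrete kernel call: it computes the fused
// activation bounds for `data_type`, packs them into ArithmeticParams, and
// dispatches to `type::opname` (reference_ops or optimized_ops).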
#define TF_LITE_DIV(type, opname, data_type)                             \
  tflite::ArithmeticParams op_params;                                    \
  data_type output_activation_min, output_activation_max;                \
  CalculateActivationRange(params->activation, &output_activation_min,   \
                           &output_activation_max);                      \
  SetActivationParams(output_activation_min, output_activation_max,      \
                      &op_params);                                       \
  type::opname(op_params, GetTensorShape(input1),                        \
               GetTensorData<data_type>(input1), GetTensorShape(input2), \
               GetTensorData<data_type>(input2), GetTensorShape(output), \
               GetTensorData<data_type>(output))
  if (output->type == kTfLiteInt32) {
    if (kernel_type == kReference) {
      if (data->requires_broadcast) {
        TF_LITE_DIV(reference_ops, BroadcastDivSlow, int32_t);
      } else {
        TF_LITE_DIV(reference_ops, Div, int32_t);
      }
    } else {
      if (data->requires_broadcast) {
        TF_LITE_DIV(optimized_ops, BroadcastDivSlow, int32_t);
      } else {
        TF_LITE_DIV(optimized_ops, Div, int32_t);
      }
    }
  } else if (output->type == kTfLiteFloat32) {
    if (kernel_type == kReference) {
      if (data->requires_broadcast) {
        TF_LITE_DIV(reference_ops, BroadcastDivSlow, float);
      } else {
        TF_LITE_DIV(reference_ops, Div, float);
      }
    } else {
      if (data->requires_broadcast) {
        TF_LITE_DIV(optimized_ops, BroadcastDivSlow, float);
      } else {
        TF_LITE_DIV(optimized_ops, Div, float);
      }
    }
  }
#undef TF_LITE_DIV
}

template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteDivParams* params, const OpData* data,
                           const TfLiteTensor* input1,
                           const TfLiteTensor* input2, TfLiteTensor* output) {
  if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
      output->type == kTfLiteUInt8) {
    tflite::ArithmeticParams op_params;
    SetActivationParams(data->output_activation_min,
                        data->output_activation_max, &op_params);
    op_params.input1_offset = -input1->params.zero_point;
    op_params.input2_offset = -input2->params.zero_point;
    op_params.output_offset = output->params.zero_point;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
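    // ProcessBroadcastShapes reports whether the two input shapes differ and,
    // when they do, records the broadcast category in op_params so the
    // optimized kernels can pick a suitable path.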
    bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
        GetTensorShape(input1), GetTensorShape(input2), &op_params);
#define TF_LITE_DIV(type, opname, dtype)                             \
  type::opname(op_params, GetTensorShape(input1),                    \
               GetTensorData<dtype>(input1), GetTensorShape(input2), \
               GetTensorData<dtype>(input2), GetTensorShape(output), \
               GetTensorData<dtype>(output))
    if (kernel_type == kReference) {
      if (need_broadcast) {
        TF_LITE_DIV(reference_ops, BroadcastDivSlow, uint8_t);
      } else {
        TF_LITE_DIV(reference_ops, Div, uint8_t);
      }
    } else {
      if (need_broadcast) {
        TF_LITE_DIV(optimized_ops, BroadcastDivSlow, uint8_t);
      } else {
        TF_LITE_DIV(optimized_ops, Div, uint8_t);
      }
    }
#undef TF_LITE_DIV
  } else {
    TF_LITE_KERNEL_LOG(
        context, "Unsupported combination of input and output types in Div.");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

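// Returns an error if any element of `tensor` is zero. Used to reject integer
// division by zero before the kernel runs.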
template <typename T>
TfLiteStatus CheckNonZero(TfLiteContext* context, const TfLiteTensor* tensor) {
  const auto* data = GetTensorData<T>(tensor);
  const size_t number_elements = tensor->bytes / sizeof(T);
  for (size_t i = 0; i < number_elements; i++) {
    TF_LITE_ENSURE(context, data[i] != 0);
  }
  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteDivParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  if (output->type == kTfLiteFloat32) {
    // Division by zero is acceptable here: as in TF, infinities are returned,
    // so no divisor check is performed.
    EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
  } else if (output->type == kTfLiteInt32) {
    TF_LITE_ENSURE_OK(context, CheckNonZero<int32_t>(context, input2));
    EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8) {
    TF_LITE_ENSURE_OK(context, CheckNonZero<uint8_t>(context, input2));
    TF_LITE_ENSURE_OK(
        context, EvalQuantized<kernel_type>(context, node, params, data,
                                            input1, input2, output));
  } else {
    TF_LITE_KERNEL_LOG(
        context,
        "Div only supports FLOAT32, INT32 and quantized UINT8 now, got %d.",
        output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace div

TfLiteRegistration* Register_DIV_REF() {
  static TfLiteRegistration r = {div::Init, div::Free, div::Prepare,
                                 div::Eval<div::kReference>};
  return &r;
}

TfLiteRegistration* Register_DIV_GENERIC_OPT() {
  static TfLiteRegistration r = {div::Init, div::Free, div::Prepare,
                                 div::Eval<div::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_DIV_NEON_OPT() {
  static TfLiteRegistration r = {div::Init, div::Free, div::Prepare,
                                 div::Eval<div::kNeonOptimized>};
  return &r;
}

TfLiteRegistration* Register_DIV() {
#ifdef USE_NEON
  return Register_DIV_NEON_OPT();
#else
  return Register_DIV_GENERIC_OPT();
#endif
}
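
// A minimal usage sketch (not part of this file): the standard
// BuiltinOpResolver maps the DIV builtin to Register_DIV(), which selects the
// NEON variant when the binary is compiled with USE_NEON and the generic
// optimized variant otherwise.
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   // DIV now resolves to the registration returned by Register_DIV().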

}  // namespace builtin
}  // namespace ops
}  // namespace tflite