xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/reference/comparisons.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
17 
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/common.h"
20 #include "tensorflow/lite/kernels/internal/types.h"
21 
22 namespace tflite {
23 
24 namespace reference_ops {
25 
// Element-wise predicate for the "equal" comparison: yields true when
// `lhs == rhs` holds for the given operand type.
template <typename T>
inline bool EqualFn(T lhs, T rhs) {
  const bool is_equal = (lhs == rhs);
  return is_equal;
}
30 
// Element-wise predicate for the "not equal" comparison: yields true when
// `lhs != rhs` holds for the given operand type.
template <typename T>
inline bool NotEqualFn(T lhs, T rhs) {
  const bool is_different = (lhs != rhs);
  return is_different;
}
35 
// Element-wise predicate for the "greater" comparison: yields true when
// `lhs > rhs` holds for the given operand type.
template <typename T>
inline bool GreaterFn(T lhs, T rhs) {
  const bool is_greater = (lhs > rhs);
  return is_greater;
}
// Element-wise predicate for the "greater or equal" comparison: yields true
// when `lhs >= rhs` holds for the given operand type.
template <typename T>
inline bool GreaterEqualFn(T lhs, T rhs) {
  const bool is_greater_or_equal = (lhs >= rhs);
  return is_greater_or_equal;
}
// Element-wise predicate for the "less" comparison: yields true when
// `lhs < rhs` holds for the given operand type.
template <typename T>
inline bool LessFn(T lhs, T rhs) {
  const bool is_less = (lhs < rhs);
  return is_less;
}
// Element-wise predicate for the "less or equal" comparison: yields true when
// `lhs <= rhs` holds for the given operand type.
template <typename T>
inline bool LessEqualFn(T lhs, T rhs) {
  const bool is_less_or_equal = (lhs <= rhs);
  return is_less_or_equal;
}
52 
// Pointer to a binary predicate over two values of type T. The kernels in
// this file take such a pointer as a non-type template argument to select
// which comparison is applied to each element pair.
template <typename T>
using ComparisonFn = bool (*)(T lhs, T rhs);
55 
56 template <typename T, ComparisonFn<T> F>
ComparisonImpl(const ComparisonParams & op_params,const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape & input2_shape,const T * input2_data,const RuntimeShape & output_shape,bool * output_data)57 inline void ComparisonImpl(
58     const ComparisonParams& op_params, const RuntimeShape& input1_shape,
59     const T* input1_data, const RuntimeShape& input2_shape,
60     const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
61   const int64_t flatsize =
62       MatchingFlatSize(input1_shape, input2_shape, output_shape);
63   for (int64_t i = 0; i < flatsize; ++i) {
64     output_data[i] = F(input1_data[i], input2_data[i]);
65   }
66 }
67 
68 template <ComparisonFn<float> F>
Comparison(const ComparisonParams & op_params,const RuntimeShape & input1_shape,const float * input1_data,const RuntimeShape & input2_shape,const float * input2_data,const RuntimeShape & output_shape,bool * output_data)69 inline void Comparison(const ComparisonParams& op_params,
70                        const RuntimeShape& input1_shape,
71                        const float* input1_data,
72                        const RuntimeShape& input2_shape,
73                        const float* input2_data,
74                        const RuntimeShape& output_shape, bool* output_data) {
75   ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
76                            input2_data, output_shape, output_data);
77 }
78 
79 template <typename T, ComparisonFn<int32_t> F>
ComparisonWithScaling(const ComparisonParams & op_params,const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape & input2_shape,const T * input2_data,const RuntimeShape & output_shape,bool * output_data)80 inline void ComparisonWithScaling(
81     const ComparisonParams& op_params, const RuntimeShape& input1_shape,
82     const T* input1_data, const RuntimeShape& input2_shape,
83     const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
84   int left_shift = op_params.left_shift;
85   int32_t input1_offset = op_params.input1_offset;
86   int32_t input1_multiplier = op_params.input1_multiplier;
87   int input1_shift = op_params.input1_shift;
88   int32_t input2_offset = op_params.input2_offset;
89   int32_t input2_multiplier = op_params.input2_multiplier;
90   int input2_shift = op_params.input2_shift;
91 
92   const int64_t flatsize =
93       MatchingFlatSize(input1_shape, input2_shape, output_shape);
94   for (int64_t i = 0; i < flatsize; ++i) {
95     const int32_t input1_val = input1_offset + input1_data[i];
96     const int32_t input2_val = input2_offset + input2_data[i];
97     const int32_t shifted_input1_val = input1_val * (1 << left_shift);
98     const int32_t shifted_input2_val = input2_val * (1 << left_shift);
99     const int32_t scaled_input1_val =
100         MultiplyByQuantizedMultiplierSmallerThanOneExp(
101             shifted_input1_val, input1_multiplier, input1_shift);
102     const int32_t scaled_input2_val =
103         MultiplyByQuantizedMultiplierSmallerThanOneExp(
104             shifted_input2_val, input2_multiplier, input2_shift);
105     output_data[i] = F(scaled_input1_val, scaled_input2_val);
106   }
107 }
108 
109 struct BroadcastComparison4DSlowCommon {
110   const RuntimeShape output_shape;
111   NdArrayDesc<4> desc1;
112   NdArrayDesc<4> desc2;
113 };
114 
BroadcastComparison4DSlowPreprocess(const RuntimeShape & unextended_input1_shape,const RuntimeShape & unextended_input2_shape,const RuntimeShape & unextended_output_shape)115 inline BroadcastComparison4DSlowCommon BroadcastComparison4DSlowPreprocess(
116     const RuntimeShape& unextended_input1_shape,
117     const RuntimeShape& unextended_input2_shape,
118     const RuntimeShape& unextended_output_shape) {
119   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
120   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
121   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
122   NdArrayDesc<4> desc1;
123   NdArrayDesc<4> desc2;
124   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
125                                       unextended_input2_shape, &desc1, &desc2);
126   return {RuntimeShape::ExtendedShape(4, unextended_output_shape), desc1,
127           desc2};
128 }
129 
130 template <typename T, ComparisonFn<T> F>
BroadcastComparison4DSlowImpl(const ComparisonParams & op_params,const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const T * input2_data,const RuntimeShape & unextended_output_shape,bool * output_data)131 inline void BroadcastComparison4DSlowImpl(
132     const ComparisonParams& op_params,
133     const RuntimeShape& unextended_input1_shape, const T* input1_data,
134     const RuntimeShape& unextended_input2_shape, const T* input2_data,
135     const RuntimeShape& unextended_output_shape, bool* output_data) {
136   const BroadcastComparison4DSlowCommon dims =
137       BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
138                                           unextended_input2_shape,
139                                           unextended_output_shape);
140 
141   for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
142     for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
143       for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
144         for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
145           output_data[Offset(dims.output_shape, b, y, x, c)] =
146               F(input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)],
147                 input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)]);
148         }
149       }
150     }
151   }
152 }
153 
154 template <ComparisonFn<float> F>
BroadcastComparison4DSlow(const ComparisonParams & op_params,const RuntimeShape & input1_shape,const float * input1_data,const RuntimeShape & input2_shape,const float * input2_data,const RuntimeShape & output_shape,bool * output_data)155 inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
156                                       const RuntimeShape& input1_shape,
157                                       const float* input1_data,
158                                       const RuntimeShape& input2_shape,
159                                       const float* input2_data,
160                                       const RuntimeShape& output_shape,
161                                       bool* output_data) {
162   BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
163                                           input2_shape, input2_data,
164                                           output_shape, output_data);
165 }
166 
167 template <typename T, ComparisonFn<int32_t> F>
BroadcastComparison4DSlowWithScaling(const ComparisonParams & op_params,const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const T * input2_data,const RuntimeShape & unextended_output_shape,bool * output_data)168 inline void BroadcastComparison4DSlowWithScaling(
169     const ComparisonParams& op_params,
170     const RuntimeShape& unextended_input1_shape, const T* input1_data,
171     const RuntimeShape& unextended_input2_shape, const T* input2_data,
172     const RuntimeShape& unextended_output_shape, bool* output_data) {
173   const BroadcastComparison4DSlowCommon dims =
174       BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
175                                           unextended_input2_shape,
176                                           unextended_output_shape);
177 
178   int left_shift = op_params.left_shift;
179   int32_t input1_offset = op_params.input1_offset;
180   int32_t input1_multiplier = op_params.input1_multiplier;
181   int input1_shift = op_params.input1_shift;
182   int32_t input2_offset = op_params.input2_offset;
183   int32_t input2_multiplier = op_params.input2_multiplier;
184   int input2_shift = op_params.input2_shift;
185 
186   for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
187     for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
188       for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
189         for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
190           const int32_t input1_val =
191               input1_offset +
192               input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)];
193           const int32_t input2_val =
194               input2_offset +
195               input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)];
196           const int32_t shifted_input1_val = input1_val * (1 << left_shift);
197           const int32_t shifted_input2_val = input2_val * (1 << left_shift);
198           const int32_t scaled_input1_val =
199               MultiplyByQuantizedMultiplierSmallerThanOneExp(
200                   shifted_input1_val, input1_multiplier, input1_shift);
201           const int32_t scaled_input2_val =
202               MultiplyByQuantizedMultiplierSmallerThanOneExp(
203                   shifted_input2_val, input2_multiplier, input2_shift);
204           output_data[Offset(dims.output_shape, b, y, x, c)] =
205               F(scaled_input1_val, scaled_input2_val);
206         }
207       }
208     }
209   }
210 }
211 
212 #define TFLITE_COMPARISON_OP(name)                                             \
213   inline void name(const ComparisonParams& op_params,                          \
214                    const RuntimeShape& input1_shape, const float* input1_data, \
215                    const RuntimeShape& input2_shape, const float* input2_data, \
216                    const RuntimeShape& output_shape, bool* output_data) {      \
217     Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape,   \
218                          input2_data, output_shape, output_data);              \
219   }                                                                            \
220   template <typename T>                                                        \
221   inline void name##NoScaling(                                                 \
222       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
223       const T* input1_data, const RuntimeShape& input2_shape,                  \
224       const T* input2_data, const RuntimeShape& output_shape,                  \
225       bool* output_data) {                                                     \
226     ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data,          \
227                                 input2_shape, input2_data, output_shape,       \
228                                 output_data);                                  \
229   }                                                                            \
230   template <typename T>                                                        \
231   inline void name##WithScaling(                                               \
232       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
233       const T* input1_data, const RuntimeShape& input2_shape,                  \
234       const T* input2_data, const RuntimeShape& output_shape,                  \
235       bool* output_data) {                                                     \
236     ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data,   \
237                                        input2_shape, input2_data,              \
238                                        output_shape, output_data);             \
239   }                                                                            \
240   template <typename T>                                                        \
241   inline void Broadcast4DSlow##name##NoScaling(                                \
242       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
243       const T* input1_data, const RuntimeShape& input2_shape,                  \
244       const T* input2_data, const RuntimeShape& output_shape,                  \
245       bool* output_data) {                                                     \
246     BroadcastComparison4DSlowImpl<T, name##Fn>(                                \
247         op_params, input1_shape, input1_data, input2_shape, input2_data,       \
248         output_shape, output_data);                                            \
249   }                                                                            \
250   inline void Broadcast4DSlow##name(                                           \
251       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
252       const float* input1_data, const RuntimeShape& input2_shape,              \
253       const float* input2_data, const RuntimeShape& output_shape,              \
254       bool* output_data) {                                                     \
255     BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data,  \
256                                         input2_shape, input2_data,             \
257                                         output_shape, output_data);            \
258   }                                                                            \
259   template <typename T>                                                        \
260   inline void Broadcast4DSlow##name##WithScaling(                              \
261       const ComparisonParams& op_params, const RuntimeShape& input1_shape,     \
262       const T* input1_data, const RuntimeShape& input2_shape,                  \
263       const T* input2_data, const RuntimeShape& output_shape,                  \
264       bool* output_data) {                                                     \
265     BroadcastComparison4DSlowWithScaling<T, name##Fn>(                         \
266         op_params, input1_shape, input1_data, input2_shape, input2_data,       \
267         output_shape, output_data);                                            \
268   }
269 TFLITE_COMPARISON_OP(Equal);
270 TFLITE_COMPARISON_OP(NotEqual);
271 TFLITE_COMPARISON_OP(Greater);
272 TFLITE_COMPARISON_OP(GreaterEqual);
273 TFLITE_COMPARISON_OP(Less);
274 TFLITE_COMPARISON_OP(LessEqual);
275 #undef TFLITE_COMPARISON_OP
276 
277 }  // namespace reference_ops
278 }  // namespace tflite
279 
280 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
281