1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
17
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/common.h"
20 #include "tensorflow/lite/kernels/internal/types.h"
21
22 namespace tflite {
23
24 namespace reference_ops {
25
// Element-wise predicate for the Equal kernels: true iff `lhs == rhs` under
// T's operator==. Top-level const on the by-value parameters does not change
// the function type, so the address still matches ComparisonFn<T>.
template <typename T>
inline bool EqualFn(const T lhs, const T rhs) {
  const bool is_equal = (lhs == rhs);
  return is_equal;
}
30
// Element-wise predicate for the NotEqual kernels: true iff `lhs != rhs`
// under T's operator!=.
template <typename T>
inline bool NotEqualFn(const T lhs, const T rhs) {
  const bool is_different = (lhs != rhs);
  return is_different;
}
35
// Element-wise predicate for the Greater kernels: true iff `lhs > rhs`.
template <typename T>
inline bool GreaterFn(const T lhs, const T rhs) {
  const bool is_greater = (lhs > rhs);
  return is_greater;
}
// Element-wise predicate for the GreaterEqual kernels: true iff `lhs >= rhs`.
template <typename T>
inline bool GreaterEqualFn(const T lhs, const T rhs) {
  const bool is_greater_or_equal = (lhs >= rhs);
  return is_greater_or_equal;
}
// Element-wise predicate for the Less kernels: true iff `lhs < rhs`.
template <typename T>
inline bool LessFn(const T lhs, const T rhs) {
  const bool is_less = (lhs < rhs);
  return is_less;
}
// Element-wise predicate for the LessEqual kernels: true iff `lhs <= rhs`.
template <typename T>
inline bool LessEqualFn(const T lhs, const T rhs) {
  const bool is_less_or_equal = (lhs <= rhs);
  return is_less_or_equal;
}
52
// Plain function-pointer type for a binary element-wise predicate on T.
// The kernels below take it as a non-type template parameter, so the
// predicate is bound at compile time. Spelled with a trailing return type;
// the resulting type is exactly `bool (*)(T, T)`.
template <typename T>
using ComparisonFn = auto (*)(T, T) -> bool;
55
56 template <typename T, ComparisonFn<T> F>
ComparisonImpl(const ComparisonParams & op_params,const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape & input2_shape,const T * input2_data,const RuntimeShape & output_shape,bool * output_data)57 inline void ComparisonImpl(
58 const ComparisonParams& op_params, const RuntimeShape& input1_shape,
59 const T* input1_data, const RuntimeShape& input2_shape,
60 const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
61 const int64_t flatsize =
62 MatchingFlatSize(input1_shape, input2_shape, output_shape);
63 for (int64_t i = 0; i < flatsize; ++i) {
64 output_data[i] = F(input1_data[i], input2_data[i]);
65 }
66 }
67
68 template <ComparisonFn<float> F>
Comparison(const ComparisonParams & op_params,const RuntimeShape & input1_shape,const float * input1_data,const RuntimeShape & input2_shape,const float * input2_data,const RuntimeShape & output_shape,bool * output_data)69 inline void Comparison(const ComparisonParams& op_params,
70 const RuntimeShape& input1_shape,
71 const float* input1_data,
72 const RuntimeShape& input2_shape,
73 const float* input2_data,
74 const RuntimeShape& output_shape, bool* output_data) {
75 ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
76 input2_data, output_shape, output_data);
77 }
78
79 template <typename T, ComparisonFn<int32_t> F>
ComparisonWithScaling(const ComparisonParams & op_params,const RuntimeShape & input1_shape,const T * input1_data,const RuntimeShape & input2_shape,const T * input2_data,const RuntimeShape & output_shape,bool * output_data)80 inline void ComparisonWithScaling(
81 const ComparisonParams& op_params, const RuntimeShape& input1_shape,
82 const T* input1_data, const RuntimeShape& input2_shape,
83 const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
84 int left_shift = op_params.left_shift;
85 int32_t input1_offset = op_params.input1_offset;
86 int32_t input1_multiplier = op_params.input1_multiplier;
87 int input1_shift = op_params.input1_shift;
88 int32_t input2_offset = op_params.input2_offset;
89 int32_t input2_multiplier = op_params.input2_multiplier;
90 int input2_shift = op_params.input2_shift;
91
92 const int64_t flatsize =
93 MatchingFlatSize(input1_shape, input2_shape, output_shape);
94 for (int64_t i = 0; i < flatsize; ++i) {
95 const int32_t input1_val = input1_offset + input1_data[i];
96 const int32_t input2_val = input2_offset + input2_data[i];
97 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
98 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
99 const int32_t scaled_input1_val =
100 MultiplyByQuantizedMultiplierSmallerThanOneExp(
101 shifted_input1_val, input1_multiplier, input1_shift);
102 const int32_t scaled_input2_val =
103 MultiplyByQuantizedMultiplierSmallerThanOneExp(
104 shifted_input2_val, input2_multiplier, input2_shift);
105 output_data[i] = F(scaled_input1_val, scaled_input2_val);
106 }
107 }
108
109 struct BroadcastComparison4DSlowCommon {
110 const RuntimeShape output_shape;
111 NdArrayDesc<4> desc1;
112 NdArrayDesc<4> desc2;
113 };
114
BroadcastComparison4DSlowPreprocess(const RuntimeShape & unextended_input1_shape,const RuntimeShape & unextended_input2_shape,const RuntimeShape & unextended_output_shape)115 inline BroadcastComparison4DSlowCommon BroadcastComparison4DSlowPreprocess(
116 const RuntimeShape& unextended_input1_shape,
117 const RuntimeShape& unextended_input2_shape,
118 const RuntimeShape& unextended_output_shape) {
119 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
120 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
121 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
122 NdArrayDesc<4> desc1;
123 NdArrayDesc<4> desc2;
124 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
125 unextended_input2_shape, &desc1, &desc2);
126 return {RuntimeShape::ExtendedShape(4, unextended_output_shape), desc1,
127 desc2};
128 }
129
130 template <typename T, ComparisonFn<T> F>
BroadcastComparison4DSlowImpl(const ComparisonParams & op_params,const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const T * input2_data,const RuntimeShape & unextended_output_shape,bool * output_data)131 inline void BroadcastComparison4DSlowImpl(
132 const ComparisonParams& op_params,
133 const RuntimeShape& unextended_input1_shape, const T* input1_data,
134 const RuntimeShape& unextended_input2_shape, const T* input2_data,
135 const RuntimeShape& unextended_output_shape, bool* output_data) {
136 const BroadcastComparison4DSlowCommon dims =
137 BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
138 unextended_input2_shape,
139 unextended_output_shape);
140
141 for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
142 for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
143 for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
144 for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
145 output_data[Offset(dims.output_shape, b, y, x, c)] =
146 F(input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)],
147 input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)]);
148 }
149 }
150 }
151 }
152 }
153
154 template <ComparisonFn<float> F>
BroadcastComparison4DSlow(const ComparisonParams & op_params,const RuntimeShape & input1_shape,const float * input1_data,const RuntimeShape & input2_shape,const float * input2_data,const RuntimeShape & output_shape,bool * output_data)155 inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
156 const RuntimeShape& input1_shape,
157 const float* input1_data,
158 const RuntimeShape& input2_shape,
159 const float* input2_data,
160 const RuntimeShape& output_shape,
161 bool* output_data) {
162 BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
163 input2_shape, input2_data,
164 output_shape, output_data);
165 }
166
167 template <typename T, ComparisonFn<int32_t> F>
BroadcastComparison4DSlowWithScaling(const ComparisonParams & op_params,const RuntimeShape & unextended_input1_shape,const T * input1_data,const RuntimeShape & unextended_input2_shape,const T * input2_data,const RuntimeShape & unextended_output_shape,bool * output_data)168 inline void BroadcastComparison4DSlowWithScaling(
169 const ComparisonParams& op_params,
170 const RuntimeShape& unextended_input1_shape, const T* input1_data,
171 const RuntimeShape& unextended_input2_shape, const T* input2_data,
172 const RuntimeShape& unextended_output_shape, bool* output_data) {
173 const BroadcastComparison4DSlowCommon dims =
174 BroadcastComparison4DSlowPreprocess(unextended_input1_shape,
175 unextended_input2_shape,
176 unextended_output_shape);
177
178 int left_shift = op_params.left_shift;
179 int32_t input1_offset = op_params.input1_offset;
180 int32_t input1_multiplier = op_params.input1_multiplier;
181 int input1_shift = op_params.input1_shift;
182 int32_t input2_offset = op_params.input2_offset;
183 int32_t input2_multiplier = op_params.input2_multiplier;
184 int input2_shift = op_params.input2_shift;
185
186 for (int b = 0; b < dims.output_shape.Dims(0); ++b) {
187 for (int y = 0; y < dims.output_shape.Dims(1); ++y) {
188 for (int x = 0; x < dims.output_shape.Dims(2); ++x) {
189 for (int c = 0; c < dims.output_shape.Dims(3); ++c) {
190 const int32_t input1_val =
191 input1_offset +
192 input1_data[SubscriptToIndex(dims.desc1, b, y, x, c)];
193 const int32_t input2_val =
194 input2_offset +
195 input2_data[SubscriptToIndex(dims.desc2, b, y, x, c)];
196 const int32_t shifted_input1_val = input1_val * (1 << left_shift);
197 const int32_t shifted_input2_val = input2_val * (1 << left_shift);
198 const int32_t scaled_input1_val =
199 MultiplyByQuantizedMultiplierSmallerThanOneExp(
200 shifted_input1_val, input1_multiplier, input1_shift);
201 const int32_t scaled_input2_val =
202 MultiplyByQuantizedMultiplierSmallerThanOneExp(
203 shifted_input2_val, input2_multiplier, input2_shift);
204 output_data[Offset(dims.output_shape, b, y, x, c)] =
205 F(scaled_input1_val, scaled_input2_val);
206 }
207 }
208 }
209 }
210 }
211
// TFLITE_COMPARISON_OP(name) stamps out the public entry points for one
// comparison operator. It requires a predicate template `name##Fn` declared
// earlier in this header. For an operator `name` it generates:
//   name                                 float inputs of matching flat size;
//                                        forwards to Comparison<name##Fn>.
//   name##NoScaling<T>                   T inputs of matching flat size,
//                                        compared as-is (ComparisonImpl).
//   name##WithScaling<T>                 T inputs of matching flat size,
//                                        offset/rescaled to int32_t via
//                                        op_params first
//                                        (ComparisonWithScaling).
//   Broadcast4DSlow##name##NoScaling<T>  broadcasting, rank <= 4, T compared
//                                        as-is (BroadcastComparison4DSlowImpl).
//   Broadcast4DSlow##name                broadcasting, rank <= 4, float
//                                        (BroadcastComparison4DSlow).
//   Broadcast4DSlow##name##WithScaling<T>
//                                        broadcasting, rank <= 4, T
//                                        offset/rescaled to int32_t first
//                                        (BroadcastComparison4DSlowWithScaling).
// In the *WithScaling variants, `name##Fn` is bound to name##Fn<int32_t>,
// because those helpers take a ComparisonFn<int32_t> predicate.
// Documentation lives outside the macro body: a `//` comment on a
// backslash-continued line would swallow the rest of the macro.
#define TFLITE_COMPARISON_OP(name) \
  inline void name(const ComparisonParams& op_params, \
                   const RuntimeShape& input1_shape, const float* input1_data, \
                   const RuntimeShape& input2_shape, const float* input2_data, \
                   const RuntimeShape& output_shape, bool* output_data) { \
    Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
                         input2_data, output_shape, output_data); \
  } \
  template <typename T> \
  inline void name##NoScaling( \
      const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
      const T* input1_data, const RuntimeShape& input2_shape, \
      const T* input2_data, const RuntimeShape& output_shape, \
      bool* output_data) { \
    ComparisonImpl<T, name##Fn>(op_params, input1_shape, input1_data, \
                                input2_shape, input2_data, output_shape, \
                                output_data); \
  } \
  template <typename T> \
  inline void name##WithScaling( \
      const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
      const T* input1_data, const RuntimeShape& input2_shape, \
      const T* input2_data, const RuntimeShape& output_shape, \
      bool* output_data) { \
    ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
                                       input2_shape, input2_data, \
                                       output_shape, output_data); \
  } \
  template <typename T> \
  inline void Broadcast4DSlow##name##NoScaling( \
      const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
      const T* input1_data, const RuntimeShape& input2_shape, \
      const T* input2_data, const RuntimeShape& output_shape, \
      bool* output_data) { \
    BroadcastComparison4DSlowImpl<T, name##Fn>( \
        op_params, input1_shape, input1_data, input2_shape, input2_data, \
        output_shape, output_data); \
  } \
  inline void Broadcast4DSlow##name( \
      const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
      const float* input1_data, const RuntimeShape& input2_shape, \
      const float* input2_data, const RuntimeShape& output_shape, \
      bool* output_data) { \
    BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
                                        input2_shape, input2_data, \
                                        output_shape, output_data); \
  } \
  template <typename T> \
  inline void Broadcast4DSlow##name##WithScaling( \
      const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
      const T* input1_data, const RuntimeShape& input2_shape, \
      const T* input2_data, const RuntimeShape& output_shape, \
      bool* output_data) { \
    BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
        op_params, input1_shape, input1_data, input2_shape, input2_data, \
        output_shape, output_data); \
  }
// Generate the entry points for each supported comparison operator; each
// name below has a matching *Fn predicate template earlier in this header.
TFLITE_COMPARISON_OP(Equal);
TFLITE_COMPARISON_OP(NotEqual);
TFLITE_COMPARISON_OP(Greater);
TFLITE_COMPARISON_OP(GreaterEqual);
TFLITE_COMPARISON_OP(Less);
TFLITE_COMPARISON_OP(LessEqual);
// The generator macro is only needed above; undefine it so it does not leak
// into files that include this header.
#undef TFLITE_COMPARISON_OP
276
277 } // namespace reference_ops
278 } // namespace tflite
279
280 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_COMPARISONS_H_
281