1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
17
18 #include <stdint.h>
19 #include <sys/types.h>
20
21 #include <algorithm>
22
23 #include "public/gemmlowp.h"
24 #include "tensorflow/lite/kernels/internal/common.h"
25 #include "tensorflow/lite/kernels/internal/legacy_types.h"
26 #include "tensorflow/lite/kernels/internal/reference/conv.h"
27 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
28 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
29 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
30 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
31 #include "tensorflow/lite/kernels/internal/types.h"
32
33 namespace tflite {
34
35 namespace reference_ops {
36
// Multiplier applied to legacy right-shift amounts to convert them into the
// modern "positive means left shift" convention used by DepthwiseParams.
static constexpr int kDepthwiseReverseShift = -1;
38
// Converts a legacy Dims<4> descriptor into a RuntimeShape. Dims stores sizes
// fastest-varying dimension first, so the order is reversed while copying.
inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
  shape->BuildFrom(
      {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}
43
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)44 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
45 const float* filter_data, const Dims<4>& filter_dims,
46 const float* bias_data, const Dims<4>& bias_dims,
47 int stride_width, int stride_height,
48 int dilation_width_factor, int dilation_height_factor,
49 int pad_width, int pad_height, int depth_multiplier,
50 float output_activation_min,
51 float output_activation_max, float* output_data,
52 const Dims<4>& output_dims) {
53 tflite::DepthwiseParams op_params;
54 // Padding type is ignored, but still set.
55 op_params.padding_type = PaddingType::kSame;
56 op_params.padding_values.width = pad_width;
57 op_params.padding_values.height = pad_height;
58 op_params.stride_width = stride_width;
59 op_params.stride_height = stride_height;
60 op_params.dilation_width_factor = dilation_width_factor;
61 op_params.dilation_height_factor = dilation_height_factor;
62 op_params.depth_multiplier = depth_multiplier;
63 op_params.float_activation_min = output_activation_min;
64 op_params.float_activation_max = output_activation_max;
65
66 DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
67 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
68 bias_data, DimsToShape(output_dims), output_data);
69 }
70
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)71 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
72 const float* filter_data, const Dims<4>& filter_dims,
73 const float* bias_data, const Dims<4>& bias_dims,
74 int stride_width, int stride_height, int pad_width,
75 int pad_height, int depth_multiplier,
76 float output_activation_min,
77 float output_activation_max, float* output_data,
78 const Dims<4>& output_dims) {
79 DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
80 bias_dims, stride_width, stride_height, 1, 1, pad_width,
81 pad_height, depth_multiplier, output_activation_min,
82 output_activation_max, output_data, output_dims);
83 }
84
85 // Legacy, for compatibility with old checked-in code.
86 template <FusedActivationFunctionType Ac>
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float * output_data,const Dims<4> & output_dims)87 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
88 const float* filter_data, const Dims<4>& filter_dims,
89 const float* bias_data, const Dims<4>& bias_dims,
90 int stride_width, int stride_height, int pad_width,
91 int pad_height, int depth_multiplier, float* output_data,
92 const Dims<4>& output_dims) {
93 float output_activation_min, output_activation_max;
94 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
95 DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
96 bias_dims, stride_width, stride_height, pad_width, pad_height,
97 depth_multiplier, output_activation_min, output_activation_max,
98 output_data, output_dims);
99 }
100
101 // Legacy, for compatibility with old checked-in code.
102 template <FusedActivationFunctionType Ac>
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int depth_multiplier,float * output_data,const Dims<4> & output_dims)103 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
104 const float* filter_data, const Dims<4>& filter_dims,
105 const float* bias_data, const Dims<4>& bias_dims, int stride,
106 int pad_width, int pad_height, int depth_multiplier,
107 float* output_data, const Dims<4>& output_dims) {
108 DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
109 bias_dims, stride, stride, pad_width, pad_height,
110 depth_multiplier, output_data, output_dims);
111 }
112
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)113 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
114 int32 input_offset, const uint8* filter_data,
115 const Dims<4>& filter_dims, int32 filter_offset,
116 const int32* bias_data, const Dims<4>& bias_dims,
117 int stride_width, int stride_height,
118 int dilation_width_factor, int dilation_height_factor,
119 int pad_width, int pad_height, int depth_multiplier,
120 int32 output_offset, int32 output_multiplier,
121 int output_shift, int32 output_activation_min,
122 int32 output_activation_max, uint8* output_data,
123 const Dims<4>& output_dims) {
124 tflite::DepthwiseParams op_params;
125 // Padding type is ignored, but still set.
126 op_params.padding_type = PaddingType::kSame;
127 op_params.padding_values.width = pad_width;
128 op_params.padding_values.height = pad_height;
129 op_params.stride_width = stride_width;
130 op_params.stride_height = stride_height;
131 op_params.dilation_width_factor = dilation_width_factor;
132 op_params.dilation_height_factor = dilation_height_factor;
133 op_params.depth_multiplier = depth_multiplier;
134 op_params.quantized_activation_min = output_activation_min;
135 op_params.quantized_activation_max = output_activation_max;
136 op_params.input_offset = input_offset;
137 op_params.weights_offset = filter_offset;
138 op_params.output_offset = output_offset;
139 op_params.output_multiplier = output_multiplier;
140 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
141 op_params.output_shift = kDepthwiseReverseShift * output_shift;
142
143 DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
144 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
145 bias_data, DimsToShape(output_dims), output_data);
146 }
147
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)148 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
149 int32 input_offset, const uint8* filter_data,
150 const Dims<4>& filter_dims, int32 filter_offset,
151 const int32* bias_data, const Dims<4>& bias_dims,
152 int stride_width, int stride_height, int pad_width,
153 int pad_height, int depth_multiplier,
154 int32 output_offset, int32 output_multiplier,
155 int output_shift, int32 output_activation_min,
156 int32 output_activation_max, uint8* output_data,
157 const Dims<4>& output_dims) {
158 DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
159 filter_offset, bias_data, bias_dims, stride_width,
160 stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
161 output_offset, output_multiplier, output_shift,
162 output_activation_min, output_activation_max, output_data,
163 output_dims);
164 }
165
166 // Legacy, for compatibility with old checked-in code.
167 template <FusedActivationFunctionType Ac>
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)168 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
169 int32 input_offset, const uint8* filter_data,
170 const Dims<4>& filter_dims, int32 filter_offset,
171 const int32* bias_data, const Dims<4>& bias_dims,
172 int stride_width, int stride_height, int pad_width,
173 int pad_height, int depth_multiplier, int32 output_offset,
174 int32 output_multiplier, int output_shift,
175 int32 output_activation_min, int32 output_activation_max,
176 uint8* output_data, const Dims<4>& output_dims) {
177 if (Ac == FusedActivationFunctionType::kNone) {
178 TFLITE_DCHECK_EQ(output_activation_min, 0);
179 TFLITE_DCHECK_EQ(output_activation_max, 255);
180 }
181 DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
182 filter_offset, bias_data, bias_dims, stride_width,
183 stride_height, pad_width, pad_height, depth_multiplier,
184 output_offset, output_multiplier, output_shift,
185 output_activation_min, output_activation_max, output_data,
186 output_dims);
187 }
188
189 // Legacy, for compatibility with old checked-in code.
190 template <FusedActivationFunctionType Ac>
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)191 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
192 int32 input_offset, const uint8* filter_data,
193 const Dims<4>& filter_dims, int32 filter_offset,
194 const int32* bias_data, const Dims<4>& bias_dims, int stride,
195 int pad_width, int pad_height, int depth_multiplier,
196 int32 output_offset, int32 output_multiplier,
197 int output_shift, int32 output_activation_min,
198 int32 output_activation_max, uint8* output_data,
199 const Dims<4>& output_dims) {
200 DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
201 filter_dims, filter_offset, bias_data, bias_dims, stride,
202 stride, pad_width, pad_height, depth_multiplier,
203 output_offset, output_multiplier, output_shift,
204 output_activation_min, output_activation_max, output_data,
205 output_dims);
206 }
207
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)208 inline void Conv(const float* input_data, const Dims<4>& input_dims,
209 const float* filter_data, const Dims<4>& filter_dims,
210 const float* bias_data, const Dims<4>& bias_dims,
211 int stride_width, int stride_height, int dilation_width_factor,
212 int dilation_height_factor, int pad_width, int pad_height,
213 float output_activation_min, float output_activation_max,
214 float* output_data, const Dims<4>& output_dims,
215 float* im2col_data, const Dims<4>& im2col_dims) {
216 tflite::ConvParams op_params;
217 // Padding type is ignored, but still set.
218 op_params.padding_type = PaddingType::kSame;
219 op_params.padding_values.width = pad_width;
220 op_params.padding_values.height = pad_height;
221 op_params.stride_width = stride_width;
222 op_params.stride_height = stride_height;
223 op_params.dilation_width_factor = dilation_width_factor;
224 op_params.dilation_height_factor = dilation_height_factor;
225 op_params.float_activation_min = output_activation_min;
226 op_params.float_activation_max = output_activation_max;
227
228 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
229 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
230 output_data, DimsToShape(im2col_dims), im2col_data);
231 }
232
233 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)234 void Conv(const float* input_data, const Dims<4>& input_dims,
235 const float* filter_data, const Dims<4>& filter_dims,
236 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
237 int stride_height, int dilation_width_factor,
238 int dilation_height_factor, int pad_width, int pad_height,
239 float* output_data, const Dims<4>& output_dims, float* im2col_data,
240 const Dims<4>& im2col_dims) {
241 float output_activation_min, output_activation_max;
242 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
243 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
244 stride_width, stride_height, dilation_width_factor,
245 dilation_height_factor, pad_width, pad_height, output_activation_min,
246 output_activation_max, output_data, output_dims, im2col_data,
247 im2col_dims);
248 }
249
250 // legacy, for compatibility with old checked-in code
251 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)252 void Conv(const float* input_data, const Dims<4>& input_dims,
253 const float* filter_data, const Dims<4>& filter_dims,
254 const float* bias_data, const Dims<4>& bias_dims, int stride_width,
255 int stride_height, int pad_width, int pad_height, float* output_data,
256 const Dims<4>& output_dims, float* im2col_data,
257 const Dims<4>& im2col_dims) {
258 float output_activation_min, output_activation_max;
259 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
260 Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
261 stride_width, stride_height, 1, 1, pad_width, pad_height,
262 output_activation_min, output_activation_max, output_data, output_dims,
263 im2col_data, im2col_dims);
264 }
265
266 // legacy, for compatibility with old checked-in code
267 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)268 void Conv(const float* input_data, const Dims<4>& input_dims,
269 const float* filter_data, const Dims<4>& filter_dims,
270 const float* bias_data, const Dims<4>& bias_dims, int stride,
271 int pad_width, int pad_height, float* output_data,
272 const Dims<4>& output_dims, float* im2col_data,
273 const Dims<4>& im2col_dims) {
274 Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
275 bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
276 output_dims, im2col_data, im2col_dims);
277 }
278
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)279 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
280 int32 input_offset, const uint8* filter_data,
281 const Dims<4>& filter_dims, int32 filter_offset,
282 const int32* bias_data, const Dims<4>& bias_dims,
283 int stride_width, int stride_height, int dilation_width_factor,
284 int dilation_height_factor, int pad_width, int pad_height,
285 int32 output_offset, int32 output_multiplier, int output_shift,
286 int32 output_activation_min, int32 output_activation_max,
287 uint8* output_data, const Dims<4>& output_dims,
288 uint8* im2col_data, const Dims<4>& im2col_dims,
289 gemmlowp::GemmContext* gemmlowp_context) {
290 tflite::ConvParams op_params;
291 // Padding type is ignored, but still set.
292 op_params.padding_type = PaddingType::kSame;
293 op_params.padding_values.width = pad_width;
294 op_params.padding_values.height = pad_height;
295 op_params.stride_width = stride_width;
296 op_params.stride_height = stride_height;
297 op_params.dilation_width_factor = dilation_width_factor;
298 op_params.dilation_height_factor = dilation_height_factor;
299 op_params.input_offset = input_offset;
300 op_params.weights_offset = filter_offset;
301 op_params.output_offset = output_offset;
302 op_params.output_multiplier = output_multiplier;
303 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
304 op_params.output_shift = kReverseShift * output_shift;
305 op_params.quantized_activation_min = output_activation_min;
306 op_params.quantized_activation_max = output_activation_max;
307
308 Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
309 filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
310 output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
311 }
312
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)313 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
314 int32 input_offset, const uint8* filter_data,
315 const Dims<4>& filter_dims, int32 filter_offset,
316 const int32* bias_data, const Dims<4>& bias_dims,
317 int stride_width, int stride_height, int pad_width,
318 int pad_height, int32 output_offset, int32 output_multiplier,
319 int output_shift, int32 output_activation_min,
320 int32 output_activation_max, uint8* output_data,
321 const Dims<4>& output_dims, uint8* im2col_data,
322 const Dims<4>& im2col_dims,
323 gemmlowp::GemmContext* gemmlowp_context) {
324 Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
325 filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
326 pad_width, pad_height, output_offset, output_multiplier, output_shift,
327 output_activation_min, output_activation_max, output_data, output_dims,
328 im2col_data, im2col_dims, gemmlowp_context);
329 }
330
// legacy, for compatibility with old checked-in code
// Templated quantized Conv: the fused activation is a compile-time argument,
// while the numeric clamp range is still supplied by the caller. For kNone
// the range must span the full uint8 domain, verified in debug builds below.
template <FusedActivationFunctionType Ac>
inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
                 int32 input_offset, const uint8* filter_data,
                 const Dims<4>& filter_dims, int32 filter_offset,
                 const int32* bias_data, const Dims<4>& bias_dims,
                 int stride_width, int stride_height, int pad_width,
                 int pad_height, int32 output_offset, int32 output_multiplier,
                 int output_shift, int32 output_activation_min,
                 int32 output_activation_max, uint8* output_data,
                 const Dims<4>& output_dims, uint8* im2col_data,
                 const Dims<4>& im2col_dims,
                 gemmlowp::GemmContext* gemmlowp_context) {
  // Restrict the template argument to the known activation set.
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  if (Ac == FusedActivationFunctionType::kNone) {
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
       filter_offset, bias_data, bias_dims, stride_width, stride_height,
       pad_width, pad_height, output_offset, output_multiplier, output_shift,
       output_activation_min, output_activation_max, output_data, output_dims,
       im2col_data, im2col_dims, gemmlowp_context);
}
359
360 // legacy, for compatibility with old checked-in code
361 template <FusedActivationFunctionType Ac>
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)362 void Conv(const uint8* input_data, const Dims<4>& input_dims,
363 int32 input_offset, const uint8* filter_data,
364 const Dims<4>& filter_dims, int32 filter_offset,
365 const int32* bias_data, const Dims<4>& bias_dims, int stride,
366 int pad_width, int pad_height, int32 output_offset,
367 int32 output_multiplier, int output_shift,
368 int32 output_activation_min, int32 output_activation_max,
369 uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
370 const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
371 Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
372 filter_offset, bias_data, bias_dims, stride, stride, pad_width,
373 pad_height, output_offset, output_multiplier, output_shift,
374 output_activation_min, output_activation_max, output_data,
375 output_dims, im2col_data, im2col_dims, gemmlowp_context);
376 }
377
TransposeConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)378 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
379 const float* filter_data, const Dims<4>& filter_dims,
380 int stride_width, int stride_height, int pad_width,
381 int pad_height, float* output_data,
382 const Dims<4>& output_dims, float* im2col_data,
383 const Dims<4>& im2col_dims) {
384 tflite::ConvParams op_params;
385 // Padding type is ignored, but still set.
386 op_params.padding_type = PaddingType::kSame;
387 op_params.padding_values.width = pad_width;
388 op_params.padding_values.height = pad_height;
389 op_params.stride_width = stride_width;
390 op_params.stride_height = stride_height;
391
392 TransposeConv(op_params, DimsToShape(input_dims), input_data,
393 DimsToShape(filter_dims), filter_data,
394 /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr,
395 DimsToShape(output_dims), output_data, DimsToShape(im2col_dims),
396 im2col_data);
397 }
398
TransposeConv(const ConvParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & filter_shape,const float * filter_data,const RuntimeShape & output_shape,float * output_data,const RuntimeShape & im2col_shape,float * im2col_data)399 inline void TransposeConv(
400 const ConvParams& params, const RuntimeShape& input_shape,
401 const float* input_data, const RuntimeShape& filter_shape,
402 const float* filter_data, const RuntimeShape& output_shape,
403 float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
404 TransposeConv(params, input_shape, input_data, filter_shape, filter_data,
405 /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr, output_shape,
406 output_data, im2col_shape, im2col_data);
407 }
408
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)409 inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
410 const float* weights_data,
411 const Dims<4>& weights_dims, const float* bias_data,
412 const Dims<4>& bias_dims,
413 float output_activation_min,
414 float output_activation_max, float* output_data,
415 const Dims<4>& output_dims) {
416 tflite::FullyConnectedParams op_params;
417 op_params.float_activation_min = output_activation_min;
418 op_params.float_activation_max = output_activation_max;
419
420 FullyConnected(op_params, DimsToShape(input_dims), input_data,
421 DimsToShape(weights_dims), weights_data,
422 DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
423 output_data);
424 }
425
426 // legacy, for compatibility with old checked-in code
427 template <FusedActivationFunctionType Ac>
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float * output_data,const Dims<4> & output_dims)428 void FullyConnected(const float* input_data, const Dims<4>& input_dims,
429 const float* weights_data, const Dims<4>& weights_dims,
430 const float* bias_data, const Dims<4>& bias_dims,
431 float* output_data, const Dims<4>& output_dims) {
432 float output_activation_min, output_activation_max;
433 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
434 FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
435 bias_dims, output_activation_min, output_activation_max,
436 output_data, output_dims);
437 }
438
// Compatibility shim (uint8 output): accepts and ignores a gemmlowp context
// so callers written against the gemmlowp-based API resolve to the plain
// reference FullyConnected kernel.
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    uint8* output_data, gemmlowp::GemmContext*) {
  FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
                 bias_shape, bias_data, output_shape, output_data);
}
448
// Compatibility shim (int16 output): accepts and ignores a gemmlowp context
// so callers written against the gemmlowp-based API resolve to the plain
// reference FullyConnected kernel.
inline void FullyConnected(
    const FullyConnectedParams& params, const RuntimeShape& input_shape,
    const uint8* input_data, const RuntimeShape& filter_shape,
    const uint8* filter_data, const RuntimeShape& bias_shape,
    const int32* bias_data, const RuntimeShape& output_shape,
    int16* output_data, gemmlowp::GemmContext*) {
  FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
                 bias_shape, bias_data, output_shape, output_data);
}
458
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)459 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
460 int32 input_offset, const uint8* filter_data,
461 const Dims<4>& filter_dims, int32 filter_offset,
462 const int32* bias_data, const Dims<4>& bias_dims,
463 int32 output_offset, int32 output_multiplier,
464 int output_shift, int32 output_activation_min,
465 int32 output_activation_max, uint8* output_data,
466 const Dims<4>& output_dims,
467 gemmlowp::GemmContext* gemmlowp_context) {
468 tflite::FullyConnectedParams op_params;
469 op_params.input_offset = input_offset;
470 op_params.weights_offset = filter_offset;
471 op_params.output_offset = output_offset;
472 op_params.output_multiplier = output_multiplier;
473 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
474 op_params.output_shift = kReverseShift * output_shift;
475 op_params.quantized_activation_min = output_activation_min;
476 op_params.quantized_activation_max = output_activation_max;
477
478 FullyConnected(op_params, DimsToShape(input_dims), input_data,
479 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
480 bias_data, DimsToShape(output_dims), output_data,
481 gemmlowp_context);
482 }
483
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)484 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
485 int32 input_offset, const uint8* filter_data,
486 const Dims<4>& filter_dims, int32 filter_offset,
487 const int32* bias_data, const Dims<4>& bias_dims,
488 int32 output_offset, int32 output_multiplier,
489 int output_shift, int32 output_activation_min,
490 int32 output_activation_max, int16* output_data,
491 const Dims<4>& output_dims,
492 gemmlowp::GemmContext* gemmlowp_context) {
493 tflite::FullyConnectedParams op_params;
494 op_params.input_offset = input_offset;
495 op_params.weights_offset = filter_offset;
496 op_params.output_offset = output_offset;
497 op_params.output_multiplier = output_multiplier;
498 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
499 op_params.output_shift = kReverseShift * output_shift;
500 op_params.quantized_activation_min = output_activation_min;
501 op_params.quantized_activation_max = output_activation_max;
502
503 FullyConnected(op_params, DimsToShape(input_dims), input_data,
504 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
505 bias_data, DimsToShape(output_dims), output_data,
506 gemmlowp_context);
507 }
508
ShuffledFullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & weights_shape,const uint8 * shuffled_weights_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,int16 * output_data,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext *)509 inline void ShuffledFullyConnected(
510 const FullyConnectedParams& params, const RuntimeShape& input_shape,
511 const uint8* input_data, const RuntimeShape& weights_shape,
512 const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
513 const int32* bias_data, const RuntimeShape& output_shape,
514 int16* output_data, uint8* shuffled_input_workspace_data,
515 gemmlowp::GemmContext*) {
516 ShuffledFullyConnected(params, input_shape, input_data, weights_shape,
517 shuffled_weights_data, bias_shape, bias_data,
518 output_shape, output_data,
519 shuffled_input_workspace_data);
520 }
521
ShuffledFullyConnected(const uint8 * input_data,const Dims<4> & input_dims,const uint8 * shuffled_weights_data,const Dims<4> & weights_dims,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext * gemmlowp_context)522 inline void ShuffledFullyConnected(
523 const uint8* input_data, const Dims<4>& input_dims,
524 const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
525 const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
526 int output_shift, int32 output_activation_min, int32 output_activation_max,
527 int16* output_data, const Dims<4>& output_dims,
528 uint8* shuffled_input_workspace_data,
529 gemmlowp::GemmContext* gemmlowp_context) {
530 tflite::FullyConnectedParams op_params;
531 op_params.output_multiplier = output_multiplier;
532 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
533 op_params.output_shift = kReverseShift * output_shift;
534 op_params.quantized_activation_min = output_activation_min;
535 op_params.quantized_activation_max = output_activation_max;
536
537 ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
538 DimsToShape(weights_dims), shuffled_weights_data,
539 DimsToShape(bias_dims), bias_data,
540 DimsToShape(output_dims), output_data,
541 shuffled_input_workspace_data, gemmlowp_context);
542 }
543
544 // legacy, for compatibility with old checked-in code
545 template <FusedActivationFunctionType Ac>
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)546 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
547 int32 input_offset, const uint8* filter_data,
548 const Dims<4>& filter_dims, int32 filter_offset,
549 const int32* bias_data, const Dims<4>& bias_dims,
550 int32 output_offset, int32 output_multiplier,
551 int output_shift, int32 output_activation_min,
552 int32 output_activation_max, uint8* output_data,
553 const Dims<4>& output_dims,
554 gemmlowp::GemmContext* gemmlowp_context) {
555 static_assert(Ac == FusedActivationFunctionType::kNone ||
556 Ac == FusedActivationFunctionType::kRelu ||
557 Ac == FusedActivationFunctionType::kRelu6 ||
558 Ac == FusedActivationFunctionType::kRelu1,
559 "");
560 if (Ac == FusedActivationFunctionType::kNone) {
561 TFLITE_DCHECK_EQ(output_activation_min, 0);
562 TFLITE_DCHECK_EQ(output_activation_max, 255);
563 }
564 FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
565 filter_offset, bias_data, bias_dims, output_offset,
566 output_multiplier, output_shift, output_activation_min,
567 output_activation_max, output_data, output_dims,
568 gemmlowp_context);
569 }
570
LstmCell(const float * input_data,const Dims<4> & input_dims,const float * prev_activ_data,const Dims<4> & prev_activ_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,const float * prev_state_data,const Dims<4> & prev_state_dims,float * output_state_data,const Dims<4> & output_state_dims,float * output_activ_data,const Dims<4> & output_activ_dims,float * concat_temp_data,const Dims<4> & concat_temp_dims,float * activ_temp_data,const Dims<4> & activ_temp_dims)571 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
572 const float* prev_activ_data,
573 const Dims<4>& prev_activ_dims, const float* weights_data,
574 const Dims<4>& weights_dims, const float* bias_data,
575 const Dims<4>& bias_dims, const float* prev_state_data,
576 const Dims<4>& prev_state_dims, float* output_state_data,
577 const Dims<4>& output_state_dims, float* output_activ_data,
578 const Dims<4>& output_activ_dims, float* concat_temp_data,
579 const Dims<4>& concat_temp_dims, float* activ_temp_data,
580 const Dims<4>& activ_temp_dims) {
581 tflite::LstmCellParams op_params;
582 // Float LSTM cell does not need parameters to be set: leave untouched.
583
584 LstmCell(op_params, DimsToShape(input_dims), input_data,
585 DimsToShape(prev_activ_dims), prev_activ_data,
586 DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
587 bias_data, DimsToShape(prev_state_dims), prev_state_data,
588 DimsToShape(output_state_dims), output_state_data,
589 DimsToShape(output_activ_dims), output_activ_data,
590 DimsToShape(concat_temp_dims), concat_temp_data,
591 DimsToShape(activ_temp_dims), activ_temp_data);
592 }
593
594 template <int StateIntegerBits>
LstmCell(const uint8 * input_data_uint8,const Dims<4> & input_dims,const uint8 * prev_activ_data_uint8,const Dims<4> & prev_activ_dims,const uint8 * weights_data_uint8,const Dims<4> & weights_dims,const int32 * bias_data_int32,const Dims<4> & bias_dims,const int16 * prev_state_data_int16,const Dims<4> & prev_state_dims,int16 * output_state_data_int16,const Dims<4> & output_state_dims,uint8 * output_activ_data_uint8,const Dims<4> & output_activ_dims,uint8 * concat_temp_data_uint8,const Dims<4> & concat_temp_dims,int16 * activ_temp_data_int16,const Dims<4> & activ_temp_dims,int32 weights_zero_point,int32 accum_multiplier,int accum_shift,gemmlowp::GemmContext * gemmlowp_context)595 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
596 const uint8* prev_activ_data_uint8,
597 const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
598 const Dims<4>& weights_dims, const int32* bias_data_int32,
599 const Dims<4>& bias_dims, const int16* prev_state_data_int16,
600 const Dims<4>& prev_state_dims, int16* output_state_data_int16,
601 const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
602 const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
603 const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
604 const Dims<4>& activ_temp_dims, int32 weights_zero_point,
605 int32 accum_multiplier, int accum_shift,
606 gemmlowp::GemmContext* gemmlowp_context) {
607 tflite::LstmCellParams op_params;
608 op_params.weights_zero_point = weights_zero_point;
609 op_params.accum_multiplier = accum_multiplier;
610 op_params.accum_shift = accum_shift;
611
612 LstmCell<StateIntegerBits>(
613 op_params, DimsToShape(input_dims), input_data_uint8,
614 DimsToShape(prev_activ_dims), prev_activ_data_uint8,
615 DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
616 bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
617 DimsToShape(output_state_dims), output_state_data_int16,
618 DimsToShape(output_activ_dims), output_activ_data_uint8,
619 DimsToShape(concat_temp_dims), concat_temp_data_uint8,
620 DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
621 }
622
623 template <typename T>
BroadcastDiv(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)624 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
625 const T* input2_data, const Dims<4>& input2_dims,
626 T output_activation_min, T output_activation_max,
627 T* output_data, const Dims<4>& output_dims) {
628 tflite::ArithmeticParams op_params;
629 SetActivationParams(output_activation_min, output_activation_max, &op_params);
630
631 BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data,
632 DimsToShape(input2_dims), input2_data,
633 DimsToShape(output_dims), output_data);
634 }
635
636 template <typename T>
Div(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)637 inline void Div(const T* input1_data, const Dims<4>& input1_dims,
638 const T* input2_data, const Dims<4>& input2_dims,
639 T output_activation_min, T output_activation_max,
640 T* output_data, const Dims<4>& output_dims) {
641 tflite::ArithmeticParams op_params;
642 SetActivationParams(output_activation_min, output_activation_max, &op_params);
643
644 Div(op_params, DimsToShape(input1_dims), input1_data,
645 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
646 output_data);
647 }
648
649 template <FusedActivationFunctionType Ac, typename Scalar>
Concatenation(int concat_dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)650 inline void Concatenation(int concat_dim, const Scalar* const* input_data,
651 const Dims<4>* const* input_dims, int inputs_count,
652 Scalar* output_data, const Dims<4>& output_dims) {
653 // For now we don't have a model with a Concatenation with fused activation.
654 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
655
656 std::vector<RuntimeShape> input_shapes(inputs_count);
657 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
658 for (int i = 0; i < inputs_count; ++i) {
659 ShapeFromDims(*input_dims[i], &input_shapes[i]);
660 input_shapes_indirect[i] = &input_shapes[i];
661 }
662 tflite::ConcatenationParams op_params;
663 op_params.axis = 3 - concat_dim;
664 op_params.inputs_count = inputs_count;
665
666 Concatenation(op_params, input_shapes_indirect.data(), input_data,
667 DimsToShape(output_dims), output_data);
668 }
669
Concatenation(int concat_dim,const uint8 * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,uint8 * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)670 inline void Concatenation(int concat_dim, const uint8* const* input_data,
671 const Dims<4>* const* input_dims,
672 const int32* input_zeropoint,
673 const float* input_scale, int inputs_count,
674 uint8* output_data, const Dims<4>& output_dims,
675 const int32 output_zeropoint,
676 const float output_scale) {
677 std::vector<RuntimeShape> input_shapes(inputs_count);
678 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
679 for (int i = 0; i < inputs_count; ++i) {
680 ShapeFromDims(*input_dims[i], &input_shapes[i]);
681 input_shapes_indirect[i] = &input_shapes[i];
682 }
683 tflite::ConcatenationParams op_params;
684 op_params.axis = 3 - concat_dim;
685 op_params.input_zeropoint = input_zeropoint;
686 op_params.input_scale = input_scale;
687 op_params.inputs_count = inputs_count;
688 op_params.output_zeropoint = output_zeropoint;
689 op_params.output_scale = output_scale;
690
691 ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
692 DimsToShape(output_dims), output_data);
693 }
694
695 template <FusedActivationFunctionType Ac, typename Scalar>
DepthConcatenation(const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)696 void DepthConcatenation(const Scalar* const* input_data,
697 const Dims<4>* const* input_dims, int inputs_count,
698 Scalar* output_data, const Dims<4>& output_dims) {
699 // For now we don't have a model with a Concatenation with fused activation.
700 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
701
702 std::vector<RuntimeShape> input_shapes(inputs_count);
703 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
704 for (int i = 0; i < inputs_count; ++i) {
705 ShapeFromDims(*input_dims[i], &input_shapes[i]);
706 input_shapes_indirect[i] = &input_shapes[i];
707 }
708 tflite::ConcatenationParams op_params;
709 op_params.inputs_count = inputs_count;
710
711 DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
712 DimsToShape(output_dims), output_data);
713 }
714
715 template <typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int axis,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)716 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
717 int axis, int outputs_count, Scalar* const* output_data,
718 const Dims<4>* const* output_dims) {
719 std::vector<RuntimeShape> output_shapes(outputs_count);
720 std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
721 for (int i = 0; i < outputs_count; ++i) {
722 ShapeFromDims(*output_dims[i], &output_shapes[i]);
723 output_shapes_indirect[i] = &output_shapes[i];
724 }
725 tflite::SplitParams op_params;
726 op_params.axis = 3 - axis;
727 op_params.num_split = outputs_count;
728
729 Split(op_params, DimsToShape(input_dims), input_data,
730 output_shapes_indirect.data(), output_data);
731 }
732
733 template <FusedActivationFunctionType Ac, typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)734 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
735 int outputs_count, Scalar* const* output_data,
736 const Dims<4>* const* output_dims) {
737 TFLITE_DCHECK_GE(outputs_count, 1);
738 for (int i = 0; i < outputs_count; i++) {
739 /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
740 /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
741 /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
742 }
743 // For now we don't have a model with a Split with fused activation.
744 TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
745
746 TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
747 output_data, output_dims);
748 }
749
Softmax(const float * input_data,const RuntimeShape & input_shape,float beta,float * output_data,const RuntimeShape & output_shape)750 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
751 float beta, float* output_data,
752 const RuntimeShape& output_shape) {
753 SoftmaxParams params;
754 params.beta = beta;
755 Softmax(params, input_shape, input_data, output_shape, output_data);
756 }
757
Softmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)758 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
759 int32 input_beta_multiplier, int32 input_beta_left_shift,
760 int diff_min, uint8* output_data,
761 const RuntimeShape& output_shape) {
762 SoftmaxParams params;
763 params.input_multiplier = input_beta_multiplier;
764 params.input_left_shift = input_beta_left_shift;
765 params.diff_min = diff_min;
766 Softmax(params, input_shape, input_data, output_shape, output_data);
767 }
768
LogSoftmax(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)769 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
770 float* output_data, const RuntimeShape& output_shape) {
771 SoftmaxParams params;
772 // No params currently used for float LogSoftmax.
773 LogSoftmax(params, input_shape, input_data, output_shape, output_data);
774 }
775
LogSoftmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)776 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
777 int32 input_multiplier, int32 input_left_shift,
778 int32 reverse_scaling_divisor,
779 int32 reverse_scaling_right_shift, int diff_min,
780 uint8* output_data, const RuntimeShape& output_shape) {
781 SoftmaxParams params;
782 params.input_multiplier = input_multiplier;
783 params.input_left_shift = input_left_shift;
784 params.reverse_scaling_divisor = reverse_scaling_divisor;
785 params.reverse_scaling_right_shift = reverse_scaling_right_shift;
786 params.diff_min = diff_min;
787 LogSoftmax(params, input_shape, input_data, output_shape, output_data);
788 }
789
// Quantized uint8 logistic (sigmoid) using gemmlowp fixed-point arithmetic.
// Each input byte is centered around the input zero point, rescaled into
// fixed-point, passed through gemmlowp::logistic, and requantized to uint8.
// Inputs far outside +/-input_range_radius saturate directly to 0 or 255.
inline void Logistic(const LogisticParams& params,
                     const RuntimeShape& input_shape, const uint8* input_data,
                     const RuntimeShape& output_shape, uint8* output_data) {
  const int32 input_zero_point = params.input_zero_point;
  const int32 input_range_radius = params.input_range_radius;
  const int32 input_multiplier = params.input_multiplier;
  const int input_left_shift = params.input_left_shift;
  // Input and output must have the same number of elements.
  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    const uint8 input_val_u8 = input_data[i];
    // Shift to a signed value centered on the input zero point.
    const int32 input_val_centered =
        static_cast<int32>(input_val_u8) - input_zero_point;
    uint8 output_val;
    if (input_val_centered <= -input_range_radius) {
      // Saturated low end: sigmoid is effectively 0.
      output_val = 0;
    } else if (input_val_centered >= input_range_radius) {
      // Saturated high end: sigmoid is effectively 1 (255 in uint8).
      output_val = 255;
    } else {
      // Rescale the centered value into the fixed-point domain expected by
      // gemmlowp::logistic (Q27.4 raw representation here).
      const int32 input_val_rescaled =
          MultiplyByQuantizedMultiplierGreaterThanOne(
              input_val_centered, input_multiplier, input_left_shift);
      using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
      using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
      const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
      const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
      // Convert from Q0.31 to Q23.8.
      using gemmlowp::RoundingDivideByPOT;
      int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
      // Rounding can produce exactly 256 (i.e. 1.0); clamp into uint8 range.
      if (output_val_s32 == 256) {
        output_val_s32 = 255;
      }
      // Reinterpret as U0.8.
      TFLITE_DCHECK_GE(output_val_s32, 0);
      TFLITE_DCHECK_LE(output_val_s32, 255);
      output_val = static_cast<uint8>(output_val_s32);
    }
    output_data[i] = output_val;
  }
}
830
Logistic(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)831 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
832 int32 input_zero_point, int32 input_range_radius,
833 int32 input_multiplier, int input_left_shift,
834 uint8* output_data, const RuntimeShape& output_shape) {
835 LogisticParams params;
836 params.input_zero_point = input_zero_point;
837 params.input_range_radius = input_range_radius;
838 params.input_multiplier = input_multiplier;
839 params.input_left_shift = input_left_shift;
840 Logistic(params, input_shape, input_data, output_shape, output_data);
841 }
842
Logistic(const RuntimeShape & input_shape,const int16 * input_data,const RuntimeShape & output_shape,int16 * output_data)843 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
844 const RuntimeShape& output_shape, int16* output_data) {
845 LogisticParams params;
846 // No params currently needed by int16 Logistic.
847 Logistic(params, input_shape, input_data, output_shape, output_data);
848 }
849
Tanh(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)850 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
851 int32 input_zero_point, int32 input_range_radius,
852 int32 input_multiplier, int input_left_shift,
853 uint8* output_data, const RuntimeShape& output_shape) {
854 TanhParams params;
855 params.input_zero_point = input_zero_point;
856 params.input_range_radius = input_range_radius;
857 params.input_multiplier = input_multiplier;
858 params.input_left_shift = input_left_shift;
859 Tanh(params, input_shape, input_data, output_shape, output_data);
860 }
861
Tanh(const int16 * input_data,const RuntimeShape & input_shape,int input_left_shift,int16 * output_data,const RuntimeShape & output_shape)862 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
863 int input_left_shift, int16* output_data,
864 const RuntimeShape& output_shape) {
865 TanhParams params;
866 params.input_left_shift = input_left_shift;
867 Tanh(params, input_shape, input_data, output_shape, output_data);
868 }
869
Dequantize(const uint8 * input_data,const Dims<4> & input_dims,int32 zero_point,double scale,float * output_data,const Dims<4> & output_dims)870 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
871 int32 zero_point, double scale, float* output_data,
872 const Dims<4>& output_dims) {
873 tflite::DequantizationParams op_params;
874 op_params.zero_point = zero_point;
875 op_params.scale = scale;
876
877 Dequantize(op_params, DimsToShape(input_dims), input_data,
878 DimsToShape(output_dims), output_data);
879 }
880
FakeQuant(const float * input_data,const Dims<4> & input_dims,float rmin,float rmax,int num_bits,float * output_data,const Dims<4> & output_dims)881 inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
882 float rmin, float rmax, int num_bits, float* output_data,
883 const Dims<4>& output_dims) {
884 tflite::FakeQuantParams op_params;
885 op_params.num_bits = num_bits;
886 op_params.minmax.min = rmin;
887 op_params.minmax.max = rmax;
888
889 FakeQuant(op_params, DimsToShape(input_dims), input_data,
890 DimsToShape(output_dims), output_data);
891 }
892
893 template <typename T>
Gather(const T * input_data,const Dims<4> & input_dims,int input_rank,const int32 * coords_data,const Dims<4> & coords_dims,T * output_data,const Dims<4> & output_dims)894 inline void Gather(const T* input_data, const Dims<4>& input_dims,
895 int input_rank, const int32* coords_data,
896 const Dims<4>& coords_dims, T* output_data,
897 const Dims<4>& output_dims) {
898 tflite::GatherParams op_params;
899 op_params.axis = 4 - input_rank;
900 op_params.batch_dims = 0;
901
902 Gather(op_params, DimsToShape(input_dims), input_data,
903 DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
904 output_data);
905 }
906
// Reverses the order of the 32 bits of `n` (bit 0 <-> bit 31, bit 1 <-> bit
// 30, ...). Used below by StridedSliceReverseIndices to mirror per-axis bit
// masks when axis order is reversed.
inline uint32 LegacyReverseBits32(uint32 n) {
  // Classic divide-and-conquer bit reversal: swap adjacent bits, then 2-bit
  // pairs, then nibbles...
  n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
  n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
  n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
  // ...and finally swap the four bytes.
  return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
          ((n & 0xFF000000) >> 24));
}
914
StridedSliceReverseIndices(tflite::StridedSliceParams * p)915 inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
916 TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
917 TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
918
919 std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
920 std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
921 std::reverse(p->strides, p->strides + p->strides_count);
922
923 p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
924 (32 - p->start_indices_count);
925 p->ellipsis_mask =
926 LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
927 (32 - p->start_indices_count);
928 p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
929 (32 - p->start_indices_count);
930 p->new_axis_mask =
931 LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
932 (32 - p->start_indices_count);
933 p->shrink_axis_mask =
934 LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
935 (32 - p->start_indices_count);
936 }
937
938 template <typename T>
StridedSlice(const T * input_data,const Dims<4> & input_dims,int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides,T * output_data,const Dims<4> & output_dims)939 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
940 int begin_mask, int end_mask, int shrink_axis_mask,
941 const std::vector<int>& start_indices,
942 const std::vector<int>& stop_indices,
943 const std::vector<int>& strides, T* output_data,
944 const Dims<4>& output_dims) {
945 TFLITE_DCHECK_EQ(start_indices.size(), 4);
946 auto op_params = strided_slice::BuildStridedSliceParams(
947 begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
948 strides);
949 StridedSliceReverseIndices(&op_params);
950
951 StridedSlice(op_params, DimsToShape(input_dims), input_data,
952 DimsToShape(output_dims), output_data);
953 }
954
955 template <typename T>
Mean(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & reduction_indices,T * output_data,const Dims<4> & output_dims)956 inline void Mean(const T* input_data, const Dims<4>& input_dims,
957 const std::vector<int>& reduction_indices, T* output_data,
958 const Dims<4>& output_dims) {
959 tflite::MeanParams op_params;
960 op_params.axis_count = reduction_indices.size();
961 for (int i = 0; i < op_params.axis_count; ++i) {
962 op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
963 }
964
965 Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
966 output_data);
967 }
968
969 template <typename T>
Transpose(const T * input,const Dims<4> & input_dims,T * output,const Dims<4> & output_dims,const int * permuted_axes)970 void Transpose(const T* input, const Dims<4>& input_dims, T* output,
971 const Dims<4>& output_dims, const int* permuted_axes) {
972 TransposeParams params;
973 params.perm_count = 4;
974 for (int i = 0; i < 4; ++i) {
975 params.perm[i] = 3 - permuted_axes[3 - i];
976 }
977 Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
978 output);
979 }
980
981 template <typename T, ComparisonFn<T> F>
Comparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)982 inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
983 const T* input2_data, const Dims<4>& input2_dims,
984 bool* output_data, const Dims<4>& output_dims) {
985 ComparisonParams op_params;
986 // No parameters needed.
987 ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
988 DimsToShape(input2_dims), input2_data,
989 DimsToShape(output_dims), output_data);
990 }
991
992 template <typename T, ComparisonFn<int32> F>
Comparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)993 inline void Comparison(int left_shift, const T* input1_data,
994 const Dims<4>& input1_dims, int32 input1_offset,
995 int32 input1_multiplier, int input1_shift,
996 const T* input2_data, const Dims<4>& input2_dims,
997 int32 input2_offset, int32 input2_multiplier,
998 int input2_shift, bool* output_data,
999 const Dims<4>& output_dims) {
1000 tflite::ComparisonParams op_params;
1001 op_params.left_shift = left_shift;
1002 op_params.input1_offset = input1_offset;
1003 op_params.input1_multiplier = input1_multiplier;
1004 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1005 op_params.input1_shift = kReverseShift * input1_shift;
1006 op_params.input2_offset = input2_offset;
1007 op_params.input2_multiplier = input2_multiplier;
1008 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1009 op_params.input2_shift = kReverseShift * input2_shift;
1010
1011 ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
1012 DimsToShape(input2_dims), input2_data,
1013 DimsToShape(output_dims), output_data);
1014 }
1015
1016 template <typename T, ComparisonFn<T> F>
BroadcastComparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)1017 inline void BroadcastComparison(const T* input1_data,
1018 const Dims<4>& input1_dims,
1019 const T* input2_data,
1020 const Dims<4>& input2_dims, bool* output_data,
1021 const Dims<4>& output_dims) {
1022 ComparisonParams op_params;
1023 // No parameters needed.
1024 BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
1025 input1_data, DimsToShape(input2_dims),
1026 input2_data, DimsToShape(output_dims),
1027 output_data);
1028 }
1029
1030 template <typename T, ComparisonFn<int32> F>
BroadcastComparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)1031 inline void BroadcastComparison(int left_shift, const T* input1_data,
1032 const Dims<4>& input1_dims, int32 input1_offset,
1033 int32 input1_multiplier, int input1_shift,
1034 const T* input2_data,
1035 const Dims<4>& input2_dims, int32 input2_offset,
1036 int32 input2_multiplier, int input2_shift,
1037 bool* output_data, const Dims<4>& output_dims) {
1038 ComparisonParams op_params;
1039
1040 op_params.left_shift = left_shift;
1041 op_params.input1_offset = input1_offset;
1042 op_params.input1_multiplier = input1_multiplier;
1043 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1044 op_params.input1_shift = kReverseShift * input1_shift;
1045 op_params.input2_offset = input2_offset;
1046 op_params.input2_multiplier = input2_multiplier;
1047 // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1048 op_params.input2_shift = kReverseShift * input2_shift;
1049
1050 BroadcastComparison4DSlowWithScaling<T, F>(
1051 op_params, DimsToShape(input1_dims), input1_data,
1052 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1053 output_data);
1054 }
1055
// Generates the four legacy Dims<4>-based entry points for comparison op
// `name`:
//   name(...)                  - plain elementwise comparison
//   name(left_shift, ...)      - quantized ("/8bit") comparison with scaling
//   Broadcast##name(...)       - broadcasting variants of both of the above
// Each generated function forwards to the Comparison / BroadcastComparison
// helpers defined above. No comments can appear inside the macro body itself
// because of the line-continuation backslashes.
#define TFLITE_LEGACY_COMPARISON_OP(name)                                      \
  template <typename T>                                                        \
  inline void name(const T* input1_data, const Dims<4>& input1_dims,           \
                   const T* input2_data, const Dims<4>& input2_dims,           \
                   bool* output_data, const Dims<4>& output_dims) {            \
    ruy::profiler::ScopeLabel label(#name);                                    \
    Comparison<T, name##Fn>(input1_data, input1_dims, input2_data,             \
                            input2_dims, output_data, output_dims);            \
  }                                                                            \
  template <typename T>                                                        \
  inline void name(                                                            \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,        \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,          \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,   \
      int32 input2_multiplier, int input2_shift, bool* output_data,            \
      const Dims<4>& output_dims) {                                            \
    ruy::profiler::ScopeLabel label(#name "/8bit");                            \
    Comparison<T, name##Fn>(left_shift, input1_data, input1_dims,              \
                            input1_offset, input1_multiplier, input1_shift,    \
                            input2_data, input2_dims, input2_offset,           \
                            input2_multiplier, input2_shift, output_data,      \
                            output_dims);                                      \
  }                                                                            \
  template <typename T>                                                        \
  inline void Broadcast##name(                                                 \
      const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,  \
      const Dims<4>& input2_dims, bool* output_data,                           \
      const Dims<4>& output_dims) {                                            \
    ruy::profiler::ScopeLabel label("Broadcast" #name);                        \
    BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data,    \
                                     input2_dims, output_data, output_dims);   \
  }                                                                            \
  template <typename T>                                                        \
  inline void Broadcast##name(                                                 \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,        \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,          \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,   \
      int32 input2_multiplier, int input2_shift, bool* output_data,            \
      const Dims<4>& output_dims) {                                            \
    ruy::profiler::ScopeLabel label("Broadcast" #name "/8bit");                \
    BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims,     \
                                     input1_offset, input1_multiplier,         \
                                     input1_shift, input2_data, input2_dims,   \
                                     input2_offset, input2_multiplier,         \
                                     input2_shift, output_data, output_dims);  \
  }
// Instantiate the legacy wrappers for every supported comparison function.
TFLITE_LEGACY_COMPARISON_OP(Equal);
TFLITE_LEGACY_COMPARISON_OP(NotEqual);
TFLITE_LEGACY_COMPARISON_OP(Greater);
TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
TFLITE_LEGACY_COMPARISON_OP(Less);
TFLITE_LEGACY_COMPARISON_OP(LessEqual);
#undef TFLITE_LEGACY_COMPARISON_OP
1109
1110 template <typename D, typename T>
Select(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1111 inline void Select(const D* input_condition_data,
1112 const Dims<4>& input_condition_dims, const T* input_x_data,
1113 const Dims<4>& input_x_dims, const T* input_y_data,
1114 const Dims<4>& input_y_dims, T* output_data,
1115 const Dims<4>& output_dims) {
1116 Select(DimsToShape(input_condition_dims), input_condition_data,
1117 DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
1118 input_y_data, DimsToShape(output_dims), output_data);
1119 }
1120
1121 template <typename D, typename T>
RankOneSelect(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1122 inline void RankOneSelect(const D* input_condition_data,
1123 const Dims<4>& input_condition_dims,
1124 const T* input_x_data, const Dims<4>& input_x_dims,
1125 const T* input_y_data, const Dims<4>& input_y_dims,
1126 T* output_data, const Dims<4>& output_dims) {
1127 RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
1128 DimsToShape(input_x_dims), input_x_data,
1129 DimsToShape(input_y_dims), input_y_data,
1130 DimsToShape(output_dims), output_data);
1131 }
1132
1133 template <typename T, typename TI>
SparseToDense(const std::vector<std::vector<TI>> & indices,const T * values,T default_value,T * output_data,const Dims<4> & output_dims,bool value_is_scalar)1134 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
1135 const T* values, T default_value, T* output_data,
1136 const Dims<4>& output_dims, bool value_is_scalar) {
1137 SparseToDense(indices, values, default_value, value_is_scalar,
1138 DimsToShape(output_dims), output_data);
1139 }
1140
1141 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)1142 void Pack(int dim, const Scalar* const* input_data,
1143 const Dims<4>* const* input_dims, int inputs_count,
1144 Scalar* output_data, const Dims<4>& output_dims) {
1145 std::vector<RuntimeShape> input_shapes(inputs_count);
1146 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1147 for (int i = 0; i < inputs_count; ++i) {
1148 ShapeFromDims(*input_dims[i], &input_shapes[i]);
1149 input_shapes_indirect[i] = &input_shapes[i];
1150 }
1151 tflite::PackParams op_params;
1152 op_params.axis = 3 - dim;
1153 op_params.inputs_count = inputs_count;
1154
1155 Pack(op_params, input_shapes_indirect.data(), input_data,
1156 DimsToShape(output_dims), output_data);
1157 }
1158
1159 template <typename Scalar>
Unpack(int axis,const Scalar * input_data,const Dims<4> & input_dims,int dimensions,int outputs_count,Scalar * const * output_datas,const Dims<4> & output_dims)1160 void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
1161 int dimensions, int outputs_count, Scalar* const* output_datas,
1162 const Dims<4>& output_dims) {
1163 tflite::UnpackParams op_params;
1164 op_params.axis = 3 - axis;
1165 op_params.num_split = outputs_count;
1166
1167 Unpack(op_params, DimsToShape(input_dims), input_data,
1168 DimsToShape(output_dims), output_datas);
1169 }
1170
1171 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,Scalar * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)1172 void Pack(int dim, const Scalar* const* input_data,
1173 const Dims<4>* const* input_dims, const int32* input_zeropoint,
1174 const float* input_scale, int inputs_count, Scalar* output_data,
1175 const Dims<4>& output_dims, const int32 output_zeropoint,
1176 const float output_scale) {
1177 std::vector<RuntimeShape> input_shapes(inputs_count);
1178 std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1179 for (int i = 0; i < inputs_count; ++i) {
1180 ShapeFromDims(*input_dims[i], &input_shapes[i]);
1181 input_shapes_indirect[i] = &input_shapes[i];
1182 }
1183 tflite::PackParams op_params;
1184 op_params.axis = 3 - dim;
1185 op_params.input_zeropoint = input_zeropoint;
1186 op_params.input_scale = input_scale;
1187 op_params.inputs_count = inputs_count;
1188 op_params.output_zeropoint = output_zeropoint;
1189 op_params.output_scale = output_scale;
1190
1191 PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
1192 DimsToShape(output_dims), output_data);
1193 }
1194
1195 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)1196 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
1197 float* output_data, const RuntimeShape& output_shape) {
1198 static_assert(Ac == FusedActivationFunctionType::kNone, "");
1199 tflite::L2NormalizationParams op_params;
1200 // No params need to be set for float.
1201
1202 L2Normalization(op_params, input_shape, input_data, output_shape,
1203 output_data);
1204 }
1205
L2Normalization(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,uint8 * output_data,const RuntimeShape & output_shape)1206 inline void L2Normalization(const uint8* input_data,
1207 const RuntimeShape& input_shape,
1208 int32 input_zero_point, uint8* output_data,
1209 const RuntimeShape& output_shape) {
1210 tflite::L2NormalizationParams op_params;
1211 op_params.input_zero_point = input_zero_point;
1212
1213 L2Normalization(op_params, input_shape, input_data, output_shape,
1214 output_data);
1215 }
1216
1217 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1218 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
1219 float* output_data, const Dims<4>& output_dims) {
1220 L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
1221 DimsToShape(output_dims));
1222 }
1223
L2Normalization(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,uint8 * output_data,const Dims<4> & output_dims)1224 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
1225 int32 input_zero_point, uint8* output_data,
1226 const Dims<4>& output_dims) {
1227 L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
1228 output_data, DimsToShape(output_dims));
1229 }
1230
Relu(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1231 inline void Relu(const float* input_data, const Dims<4>& input_dims,
1232 float* output_data, const Dims<4>& output_dims) {
1233 Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1234 output_data);
1235 }
1236
Relu1(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1237 inline void Relu1(const float* input_data, const Dims<4>& input_dims,
1238 float* output_data, const Dims<4>& output_dims) {
1239 Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1240 output_data);
1241 }
1242
Relu6(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1243 inline void Relu6(const float* input_data, const Dims<4>& input_dims,
1244 float* output_data, const Dims<4>& output_dims) {
1245 Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1246 output_data);
1247 }
1248
ReluX(uint8 min_value,uint8 max_value,const uint8 * input_data,const RuntimeShape & input_shape,uint8 * output_data,const RuntimeShape & output_shape)1249 inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
1250 const RuntimeShape& input_shape, uint8* output_data,
1251 const RuntimeShape& output_shape) {
1252 tflite::ActivationParams params;
1253 params.quantized_activation_max = max_value;
1254 params.quantized_activation_min = min_value;
1255 ReluX(params, input_shape, input_data, output_shape, output_data);
1256 }
1257
1258 template <FusedActivationFunctionType Ac>
Add(int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1259 inline void Add(int left_shift, const uint8* input1_data,
1260 const Dims<4>& input1_dims, int32 input1_offset,
1261 int32 input1_multiplier, int input1_shift,
1262 const uint8* input2_data, const Dims<4>& input2_dims,
1263 int32 input2_offset, int32 input2_multiplier, int input2_shift,
1264 int32 output_offset, int32 output_multiplier, int output_shift,
1265 int32 output_activation_min, int32 output_activation_max,
1266 uint8* output_data, const Dims<4>& output_dims) {
1267 constexpr int kReverseShift = -1;
1268 static_assert(Ac == FusedActivationFunctionType::kNone ||
1269 Ac == FusedActivationFunctionType::kRelu ||
1270 Ac == FusedActivationFunctionType::kRelu6 ||
1271 Ac == FusedActivationFunctionType::kRelu1,
1272 "");
1273 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1274 if (Ac == FusedActivationFunctionType::kNone) {
1275 TFLITE_DCHECK_EQ(output_activation_min, 0);
1276 TFLITE_DCHECK_EQ(output_activation_max, 255);
1277 }
1278
1279 tflite::ArithmeticParams op_params;
1280 op_params.left_shift = left_shift;
1281 op_params.input1_offset = input1_offset;
1282 op_params.input1_multiplier = input1_multiplier;
1283 op_params.input1_shift = kReverseShift * input1_shift;
1284 op_params.input2_offset = input2_offset;
1285 op_params.input2_multiplier = input2_multiplier;
1286 op_params.input2_shift = kReverseShift * input2_shift;
1287 op_params.output_offset = output_offset;
1288 op_params.output_multiplier = output_multiplier;
1289 op_params.output_shift = kReverseShift * output_shift;
1290 op_params.quantized_activation_min = output_activation_min;
1291 op_params.quantized_activation_max = output_activation_max;
1292 Add(op_params, DimsToShape(input1_dims), input1_data,
1293 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1294 output_data);
1295 }
1296
1297 template <FusedActivationFunctionType Ac>
Add(const int32 * input1_data,const Dims<4> & input1_dims,const int32 * input2_data,const Dims<4> & input2_dims,int32 * output_data,const Dims<4> & output_dims)1298 void Add(const int32* input1_data, const Dims<4>& input1_dims,
1299 const int32* input2_data, const Dims<4>& input2_dims,
1300 int32* output_data, const Dims<4>& output_dims) {
1301 ruy::profiler::ScopeLabel label("Add/int32");
1302 TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
1303
1304 tflite::ArithmeticParams op_params;
1305 op_params.quantized_activation_min = std::numeric_limits<int32>::min();
1306 op_params.quantized_activation_max = std::numeric_limits<int32>::max();
1307 Add(op_params, DimsToShape(input1_dims), input1_data,
1308 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1309 output_data);
1310 }
1311
1312 template <FusedActivationFunctionType Ac>
BroadcastAdd(int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1313 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
1314 const Dims<4>& input1_dims, int32 input1_offset,
1315 int32 input1_multiplier, int input1_shift,
1316 const uint8* input2_data, const Dims<4>& input2_dims,
1317 int32 input2_offset, int32 input2_multiplier,
1318 int input2_shift, int32 output_offset,
1319 int32 output_multiplier, int output_shift,
1320 int32 output_activation_min,
1321 int32 output_activation_max, uint8* output_data,
1322 const Dims<4>& output_dims) {
1323 constexpr int kReverseShift = -1;
1324 static_assert(Ac == FusedActivationFunctionType::kNone ||
1325 Ac == FusedActivationFunctionType::kRelu ||
1326 Ac == FusedActivationFunctionType::kRelu6 ||
1327 Ac == FusedActivationFunctionType::kRelu1,
1328 "");
1329 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1330 if (Ac == FusedActivationFunctionType::kNone) {
1331 TFLITE_DCHECK_EQ(output_activation_min, 0);
1332 TFLITE_DCHECK_EQ(output_activation_max, 255);
1333 }
1334
1335 tflite::ArithmeticParams op_params;
1336 op_params.left_shift = left_shift;
1337 op_params.input1_offset = input1_offset;
1338 op_params.input1_multiplier = input1_multiplier;
1339 op_params.input1_shift = kReverseShift * input1_shift;
1340 op_params.input2_offset = input2_offset;
1341 op_params.input2_multiplier = input2_multiplier;
1342 op_params.input2_shift = kReverseShift * input2_shift;
1343 op_params.output_offset = output_offset;
1344 op_params.output_multiplier = output_multiplier;
1345 op_params.output_shift = kReverseShift * output_shift;
1346 op_params.quantized_activation_min = output_activation_min;
1347 op_params.quantized_activation_max = output_activation_max;
1348 BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1349 DimsToShape(input2_dims), input2_data,
1350 DimsToShape(output_dims), output_data);
1351 }
1352
1353 template <FusedActivationFunctionType Ac>
Add(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1354 void Add(const float* input1_data, const Dims<4>& input1_dims,
1355 const float* input2_data, const Dims<4>& input2_dims,
1356 float* output_data, const Dims<4>& output_dims) {
1357 float output_activation_min, output_activation_max;
1358 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1359
1360 tflite::ArithmeticParams op_params;
1361 op_params.float_activation_min = output_activation_min;
1362 op_params.float_activation_max = output_activation_max;
1363 Add(op_params, DimsToShape(input1_dims), input1_data,
1364 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1365 output_data);
1366 }
1367
1368 template <typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1369 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1370 const T* input2_data, const Dims<4>& input2_dims,
1371 T output_activation_min, T output_activation_max,
1372 T* output_data, const Dims<4>& output_dims) {
1373 tflite::ArithmeticParams op_params;
1374 op_params.float_activation_min = output_activation_min;
1375 op_params.float_activation_max = output_activation_max;
1376 BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1377 DimsToShape(input2_dims), input2_data,
1378 DimsToShape(output_dims), output_data);
1379 }
1380
// Legacy quantized broadcast Add for the "fivefold" broadcast pattern, with
// the five loop extents passed explicitly as y0..y4.
template <FusedActivationFunctionType Ac>
inline void BroadcastAddFivefold(
    int y0, int y1, int y2, int y3, int y4, int left_shift,
    const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
    int32 input1_multiplier, int input1_shift, const uint8* input2_data,
    const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
    int input2_shift, int32 output_offset, int32 output_multiplier,
    int output_shift, int32 output_activation_min, int32 output_activation_max,
    uint8* output_data, const Dims<4>& output_dims) {
  // Legacy shifts are right shifts; negated below for +ve-means-left.
  constexpr int kReverseShift = -1;
  static_assert(Ac == FusedActivationFunctionType::kNone ||
                    Ac == FusedActivationFunctionType::kRelu ||
                    Ac == FusedActivationFunctionType::kRelu6 ||
                    Ac == FusedActivationFunctionType::kRelu1,
                "");
  TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
  if (Ac == FusedActivationFunctionType::kNone) {
    // Without a fused activation the clamp must span the full uint8 range.
    TFLITE_DCHECK_EQ(output_activation_min, 0);
    TFLITE_DCHECK_EQ(output_activation_max, 255);
  }
  tflite::ArithmeticParams op_params;
  // This legacy entry point always assumes input1 is the fast-broadcasting
  // side.
  op_params.broadcast_category =
      tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
  op_params.left_shift = left_shift;
  op_params.input1_offset = input1_offset;
  op_params.input1_multiplier = input1_multiplier;
  op_params.input1_shift = kReverseShift * input1_shift;
  op_params.input2_offset = input2_offset;
  op_params.input2_multiplier = input2_multiplier;
  op_params.input2_shift = kReverseShift * input2_shift;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = output_multiplier;
  op_params.output_shift = kReverseShift * output_shift;
  op_params.quantized_activation_min = output_activation_min;
  op_params.quantized_activation_max = output_activation_max;
  // NOTE(review): y0..y4 map to broadcast_shape[4]..[0]; confirm the intended
  // loop-nesting order against the BroadcastAddFivefold implementation.
  op_params.broadcast_shape[4] = y0;
  op_params.broadcast_shape[3] = y1;
  op_params.broadcast_shape[2] = y2;
  op_params.broadcast_shape[1] = y3;
  op_params.broadcast_shape[0] = y4;
  BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
                       DimsToShape(input2_dims), input2_data,
                       DimsToShape(output_dims), output_data);
}
1425
1426 // legacy, for compatibility with old checked-in code
1427 template <FusedActivationFunctionType Ac, typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1428 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1429 const T* input2_data, const Dims<4>& input2_dims,
1430 T* output_data, const Dims<4>& output_dims) {
1431 T output_activation_min, output_activation_max;
1432 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1433
1434 BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
1435 output_activation_min, output_activation_max, output_data,
1436 output_dims);
1437 }
1438
1439 template <FusedActivationFunctionType Ac>
Add(const int16 * input1_data,const Dims<4> & input1_dims,int input1_shift,const int16 * input2_data,const Dims<4> & input2_dims,int input2_shift,int16 output_activation_min,int16 output_activation_max,int16 * output_data,const Dims<4> & output_dims)1440 inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
1441 int input1_shift, const int16* input2_data,
1442 const Dims<4>& input2_dims, int input2_shift,
1443 int16 output_activation_min, int16 output_activation_max,
1444 int16* output_data, const Dims<4>& output_dims) {
1445 static_assert(Ac == FusedActivationFunctionType::kNone ||
1446 Ac == FusedActivationFunctionType::kRelu ||
1447 Ac == FusedActivationFunctionType::kRelu6 ||
1448 Ac == FusedActivationFunctionType::kRelu1,
1449 "");
1450 TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1451 if (Ac == FusedActivationFunctionType::kNone) {
1452 TFLITE_DCHECK_EQ(output_activation_min, -32768);
1453 TFLITE_DCHECK_EQ(output_activation_max, 32767);
1454 }
1455
1456 tflite::ArithmeticParams op_params;
1457 op_params.input1_shift = kReverseShift * input1_shift;
1458 op_params.input2_shift = kReverseShift * input2_shift;
1459 op_params.quantized_activation_min = output_activation_min;
1460 op_params.quantized_activation_max = output_activation_max;
1461 Add(op_params, DimsToShape(input1_dims), input1_data,
1462 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1463 output_data);
1464 }
1465
Sub(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1466 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
1467 const float* input2_data, const Dims<4>& input2_dims,
1468 float* output_data, const Dims<4>& output_dims) {
1469 float output_activation_min, output_activation_max;
1470 GetActivationMinMax(FusedActivationFunctionType::kNone,
1471 &output_activation_min, &output_activation_max);
1472 tflite::ArithmeticParams op_params;
1473 op_params.float_activation_min = output_activation_min;
1474 op_params.float_activation_max = output_activation_max;
1475 Sub(op_params, DimsToShape(input1_dims), input1_data,
1476 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1477 output_data);
1478 }
1479
1480 template <typename T>
Sub(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1481 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
1482 const Dims<4>& input2_dims, T* output_data,
1483 const Dims<4>& output_dims) {
1484 tflite::ArithmeticParams op_params;
1485 op_params.quantized_activation_min = std::numeric_limits<T>::min();
1486 op_params.quantized_activation_max = std::numeric_limits<T>::max();
1487 Sub(op_params, DimsToShape(input1_dims), input1_data,
1488 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1489 output_data);
1490 }
1491
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1492 inline bool AveragePool(const float* input_data, const Dims<4>& input_dims,
1493 int stride_width, int stride_height, int pad_width,
1494 int pad_height, int kwidth, int kheight,
1495 float output_activation_min,
1496 float output_activation_max, float* output_data,
1497 const Dims<4>& output_dims) {
1498 tflite::PoolParams params;
1499 params.stride_height = stride_height;
1500 params.stride_width = stride_width;
1501 params.filter_height = kheight;
1502 params.filter_width = kwidth;
1503 params.padding_values.height = pad_height;
1504 params.padding_values.width = pad_width;
1505 params.float_activation_min = output_activation_min;
1506 params.float_activation_max = output_activation_max;
1507 return AveragePool(params, DimsToShape(input_dims), input_data,
1508 DimsToShape(output_dims), output_data);
1509 }
1510
1511 // Transitional version that will be moved shortly to legacy_reference_ops, as
1512 // part of RuntimeShape revisions.
BroadcastMul4DSlow(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1513 inline void BroadcastMul4DSlow(const uint8* input1_data,
1514 const Dims<4>& input1_dims, int32 input1_offset,
1515 const uint8* input2_data,
1516 const Dims<4>& input2_dims, int32 input2_offset,
1517 int32 output_offset, int32 output_multiplier,
1518 int output_shift, int32 output_activation_min,
1519 int32 output_activation_max, uint8* output_data,
1520 const Dims<4>& output_dims) {
1521 tflite::ArithmeticParams op_params;
1522 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1523 op_params.input1_offset = input1_offset;
1524 op_params.input2_offset = input2_offset;
1525 op_params.output_offset = output_offset;
1526 op_params.output_multiplier = output_multiplier;
1527 op_params.output_shift = output_shift;
1528
1529 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1530 DimsToShape(input2_dims), input2_data,
1531 DimsToShape(output_dims), output_data);
1532 }
1533
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1534 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1535 int32 input1_offset, const uint8* input2_data,
1536 const Dims<4>& input2_dims, int32 input2_offset,
1537 int32 output_offset, int32 output_multiplier,
1538 int output_shift, int32 output_activation_min,
1539 int32 output_activation_max, uint8* output_data,
1540 const Dims<4>& output_dims) {
1541 BroadcastMul4DSlow(
1542 input1_data, input1_dims, input1_offset, input2_data, input2_dims,
1543 input2_offset, output_offset, output_multiplier,
1544 //
1545 kReverseShift * output_shift,
1546 //
1547 output_activation_min, output_activation_max, output_data, output_dims);
1548 }
1549
1550 // legacy, for compatibility with old checked-in code
1551 template <FusedActivationFunctionType Ac>
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1552 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1553 int32 input1_offset, const uint8* input2_data,
1554 const Dims<4>& input2_dims, int32 input2_offset,
1555 int32 output_offset, int32 output_multiplier,
1556 int output_shift, int32 output_activation_min,
1557 int32 output_activation_max, uint8* output_data,
1558 const Dims<4>& output_dims) {
1559 BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
1560 input2_dims, input2_offset, output_offset, output_multiplier,
1561 output_shift, output_activation_min, output_activation_max,
1562 output_data, output_dims);
1563 }
1564
1565 // legacy, for compatibility with old checked-in code
1566 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1567 bool AveragePool(const float* input_data, const Dims<4>& input_dims,
1568 int stride_width, int stride_height, int pad_width,
1569 int pad_height, int kwidth, int kheight, float* output_data,
1570 const Dims<4>& output_dims) {
1571 float output_activation_min, output_activation_max;
1572 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1573
1574 return AveragePool(input_data, input_dims, stride_width, stride_height,
1575 pad_width, pad_height, kwidth, kheight,
1576 output_activation_min, output_activation_max, output_data,
1577 output_dims);
1578 }
1579
1580 // legacy, for compatibility with old checked-in code
1581 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1582 bool AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
1583 int pad_width, int pad_height, int filter_width,
1584 int filter_height, float* output_data,
1585 const Dims<4>& output_dims) {
1586 return AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width,
1587 pad_height, filter_width, filter_height, output_data,
1588 output_dims);
1589 }
1590
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1591 inline bool AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1592 int stride_width, int stride_height, int pad_width,
1593 int pad_height, int filter_width, int filter_height,
1594 int32 output_activation_min,
1595 int32 output_activation_max, uint8* output_data,
1596 const Dims<4>& output_dims) {
1597 tflite::PoolParams params;
1598 params.stride_height = stride_height;
1599 params.stride_width = stride_width;
1600 params.filter_height = filter_height;
1601 params.filter_width = filter_width;
1602 params.padding_values.height = pad_height;
1603 params.padding_values.width = pad_width;
1604 params.quantized_activation_min = output_activation_min;
1605 params.quantized_activation_max = output_activation_max;
1606 return AveragePool(params, DimsToShape(input_dims), input_data,
1607 DimsToShape(output_dims), output_data);
1608 }
1609
1610 // legacy, for compatibility with old checked-in code
1611 template <FusedActivationFunctionType Ac>
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1612 bool AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1613 int stride_width, int stride_height, int pad_width,
1614 int pad_height, int filter_width, int filter_height,
1615 int32 output_activation_min, int32 output_activation_max,
1616 uint8* output_data, const Dims<4>& output_dims) {
1617 static_assert(Ac == FusedActivationFunctionType::kNone ||
1618 Ac == FusedActivationFunctionType::kRelu ||
1619 Ac == FusedActivationFunctionType::kRelu6 ||
1620 Ac == FusedActivationFunctionType::kRelu1,
1621 "");
1622 if (Ac == FusedActivationFunctionType::kNone) {
1623 TFLITE_DCHECK_EQ(output_activation_min, 0);
1624 TFLITE_DCHECK_EQ(output_activation_max, 255);
1625 }
1626 return AveragePool(input_data, input_dims, stride_width, stride_height,
1627 pad_width, pad_height, filter_width, filter_height,
1628 output_activation_min, output_activation_max, output_data,
1629 output_dims);
1630 }
1631
1632 // legacy, for compatibility with old checked-in code
1633 template <FusedActivationFunctionType Ac>
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1634 bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1635 int pad_width, int pad_height, int filter_width,
1636 int filter_height, int32 output_activation_min,
1637 int32 output_activation_max, uint8* output_data,
1638 const Dims<4>& output_dims) {
1639 return AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width,
1640 pad_height, filter_width, filter_height,
1641 output_activation_min, output_activation_max,
1642 output_data, output_dims);
1643 }
1644
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1645 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
1646 int stride_width, int stride_height, int pad_width,
1647 int pad_height, int kwidth, int kheight,
1648 float output_activation_min, float output_activation_max,
1649 float* output_data, const Dims<4>& output_dims) {
1650 tflite::PoolParams params;
1651 params.stride_height = stride_height;
1652 params.stride_width = stride_width;
1653 params.filter_height = kheight;
1654 params.filter_width = kwidth;
1655 params.padding_values.height = pad_height;
1656 params.padding_values.width = pad_width;
1657 params.float_activation_min = output_activation_min;
1658 params.float_activation_max = output_activation_max;
1659 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1660 output_data);
1661 }
1662
1663 // legacy, for compatibility with old checked-in code
1664 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1665 void MaxPool(const float* input_data, const Dims<4>& input_dims,
1666 int stride_width, int stride_height, int pad_width, int pad_height,
1667 int kwidth, int kheight, float* output_data,
1668 const Dims<4>& output_dims) {
1669 float output_activation_min, output_activation_max;
1670 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1671 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1672 pad_height, kwidth, kheight, output_activation_min,
1673 output_activation_max, output_data, output_dims);
1674 }
1675
1676 // legacy, for compatibility with old checked-in code
1677 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1678 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
1679 int pad_width, int pad_height, int filter_width, int filter_height,
1680 float* output_data, const Dims<4>& output_dims) {
1681 MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1682 filter_width, filter_height, output_data, output_dims);
1683 }
1684
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1685 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1686 int stride_width, int stride_height, int pad_width,
1687 int pad_height, int filter_width, int filter_height,
1688 int32 output_activation_min, int32 output_activation_max,
1689 uint8* output_data, const Dims<4>& output_dims) {
1690 PoolParams params;
1691 params.stride_height = stride_height;
1692 params.stride_width = stride_width;
1693 params.filter_height = filter_height;
1694 params.filter_width = filter_width;
1695 params.padding_values.height = pad_height;
1696 params.padding_values.width = pad_width;
1697 params.quantized_activation_min = output_activation_min;
1698 params.quantized_activation_max = output_activation_max;
1699 MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1700 output_data);
1701 }
1702
1703 // legacy, for compatibility with old checked-in code
1704 template <FusedActivationFunctionType Ac>
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1705 void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1706 int stride_width, int stride_height, int pad_width, int pad_height,
1707 int filter_width, int filter_height, int32 output_activation_min,
1708 int32 output_activation_max, uint8* output_data,
1709 const Dims<4>& output_dims) {
1710 static_assert(Ac == FusedActivationFunctionType::kNone ||
1711 Ac == FusedActivationFunctionType::kRelu ||
1712 Ac == FusedActivationFunctionType::kRelu6 ||
1713 Ac == FusedActivationFunctionType::kRelu1,
1714 "");
1715 if (Ac == FusedActivationFunctionType::kNone) {
1716 TFLITE_DCHECK_EQ(output_activation_min, 0);
1717 TFLITE_DCHECK_EQ(output_activation_max, 255);
1718 }
1719 MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1720 pad_height, filter_width, filter_height, output_activation_min,
1721 output_activation_max, output_data, output_dims);
1722 }
1723
1724 // legacy, for compatibility with old checked-in code
1725 template <FusedActivationFunctionType Ac>
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1726 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1727 int pad_width, int pad_height, int filter_width, int filter_height,
1728 int32 output_activation_min, int32 output_activation_max,
1729 uint8* output_data, const Dims<4>& output_dims) {
1730 MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1731 filter_width, filter_height, output_activation_min,
1732 output_activation_max, output_data, output_dims);
1733 }
1734
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1735 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
1736 int stride_width, int stride_height, int pad_width,
1737 int pad_height, int filter_width, int filter_height,
1738 float output_activation_min, float output_activation_max,
1739 float* output_data, const Dims<4>& output_dims) {
1740 PoolParams params;
1741 params.stride_height = stride_height;
1742 params.stride_width = stride_width;
1743 params.filter_height = filter_height;
1744 params.filter_width = filter_width;
1745 params.padding_values.height = pad_height;
1746 params.padding_values.width = pad_width;
1747 params.float_activation_min = output_activation_min;
1748 params.float_activation_max = output_activation_max;
1749 L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1750 output_data);
1751 }
1752
1753 // legacy, for compatibility with old checked-in code
1754 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1755 void L2Pool(const float* input_data, const Dims<4>& input_dims,
1756 int stride_width, int stride_height, int pad_width, int pad_height,
1757 int filter_width, int filter_height, float* output_data,
1758 const Dims<4>& output_dims) {
1759 float output_activation_min, output_activation_max;
1760 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1761 L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
1762 pad_height, filter_width, filter_height, output_activation_min,
1763 output_activation_max, output_data, output_dims);
1764 }
1765
1766 // legacy, for compatibility with old checked-in code
1767 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1768 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
1769 int pad_width, int pad_height, int filter_width, int filter_height,
1770 float* output_data, const Dims<4>& output_dims) {
1771 L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1772 filter_width, filter_height, output_data, output_dims);
1773 }
1774
Softmax(const float * input_data,const Dims<4> & input_dims,float beta,float * output_data,const Dims<4> & output_dims)1775 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
1776 float beta, float* output_data,
1777 const Dims<4>& output_dims) {
1778 Softmax(input_data, DimsToShape(input_dims), beta, output_data,
1779 DimsToShape(output_dims));
1780 }
1781
Softmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1782 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
1783 int32 input_beta_multiplier, int32 input_beta_left_shift,
1784 int diff_min, uint8* output_data,
1785 const Dims<4>& output_dims) {
1786 Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
1787 input_beta_left_shift, diff_min, output_data,
1788 DimsToShape(output_dims));
1789 }
1790
LogSoftmax(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1791 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
1792 float* output_data, const Dims<4>& output_dims) {
1793 LogSoftmax(input_data, DimsToShape(input_dims), output_data,
1794 DimsToShape(output_dims));
1795 }
1796
LogSoftmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1797 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
1798 int32 input_multiplier, int32 input_left_shift,
1799 int32 reverse_scaling_divisor,
1800 int32 reverse_scaling_right_shift, int diff_min,
1801 uint8* output_data, const Dims<4>& output_dims) {
1802 LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
1803 input_left_shift, reverse_scaling_divisor,
1804 reverse_scaling_right_shift, diff_min, output_data,
1805 DimsToShape(output_dims));
1806 }
1807
Logistic(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1808 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
1809 float* output_data, const Dims<4>& output_dims) {
1810 Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1811 output_data);
1812 }
1813
Logistic(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1814 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
1815 int32 input_zero_point, int32 input_range_radius,
1816 int32 input_multiplier, int input_left_shift,
1817 uint8* output_data, const Dims<4>& output_dims) {
1818 Logistic(input_data, DimsToShape(input_dims), input_zero_point,
1819 input_range_radius, input_multiplier, input_left_shift, output_data,
1820 DimsToShape(output_dims));
1821 }
1822
Logistic(const int16 * input_data,const Dims<4> & input_dims,int16 * output_data,const Dims<4> & output_dims)1823 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
1824 int16* output_data, const Dims<4>& output_dims) {
1825 Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1826 output_data);
1827 }
1828
Tanh(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1829 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
1830 float* output_data, const Dims<4>& output_dims) {
1831 Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1832 output_data);
1833 }
1834
Tanh(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1835 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
1836 int32 input_zero_point, int32 input_range_radius,
1837 int32 input_multiplier, int input_left_shift,
1838 uint8* output_data, const Dims<4>& output_dims) {
1839 Tanh(input_data, DimsToShape(input_dims), input_zero_point,
1840 input_range_radius, input_multiplier, input_left_shift, output_data,
1841 DimsToShape(output_dims));
1842 }
1843
Tanh(const int16 * input_data,const Dims<4> & input_dims,int input_left_shift,int16 * output_data,const Dims<4> & output_dims)1844 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
1845 int input_left_shift, int16* output_data,
1846 const Dims<4>& output_dims) {
1847 Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
1848 DimsToShape(output_dims));
1849 }
1850
1851 template <typename T>
DepthToSpace(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1852 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
1853 int block_size, T* output_data,
1854 const Dims<4>& output_dims) {
1855 tflite::DepthToSpaceParams op_params;
1856 op_params.block_size = block_size;
1857
1858 DepthToSpace(op_params, DimsToShape(input_dims), input_data,
1859 DimsToShape(output_dims), output_data);
1860 }
1861
1862 template <typename T>
SpaceToDepth(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1863 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
1864 int block_size, T* output_data,
1865 const Dims<4>& output_dims) {
1866 tflite::SpaceToDepthParams op_params;
1867 op_params.block_size = block_size;
1868
1869 SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
1870 DimsToShape(output_dims), output_data);
1871 }
1872
1873 template <typename T>
Mul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1874 inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
1875 const T* input2_data, const Dims<4>& input2_dims,
1876 T output_activation_min, T output_activation_max,
1877 T* output_data, const Dims<4>& output_dims) {
1878 tflite::ArithmeticParams op_params;
1879 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1880
1881 Mul(op_params, DimsToShape(input1_dims), input1_data,
1882 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1883 output_data);
1884 }
1885
1886 // legacy, for compatibility with old checked-in code
1887 template <FusedActivationFunctionType Ac>
Mul(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1888 void Mul(const float* input1_data, const Dims<4>& input1_dims,
1889 const float* input2_data, const Dims<4>& input2_dims,
1890 float* output_data, const Dims<4>& output_dims) {
1891 float output_activation_min, output_activation_max;
1892 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1893
1894 tflite::ArithmeticParams op_params;
1895 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1896
1897 Mul(op_params, DimsToShape(input1_dims), input1_data,
1898 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1899 output_data);
1900 }
1901
1902 template <typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1903 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1904 const T* input2_data, const Dims<4>& input2_dims,
1905 T output_activation_min, T output_activation_max,
1906 T* output_data, const Dims<4>& output_dims) {
1907 tflite::ArithmeticParams op_params;
1908 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1909
1910 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1911 DimsToShape(input2_dims), input2_data,
1912 DimsToShape(output_dims), output_data);
1913 }
1914
1915 // legacy, for compatibility with old checked-in code
1916 template <FusedActivationFunctionType Ac, typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1917 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1918 const T* input2_data, const Dims<4>& input2_dims,
1919 T* output_data, const Dims<4>& output_dims) {
1920 T output_activation_min, output_activation_max;
1921 GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1922
1923 tflite::ArithmeticParams op_params;
1924 SetActivationParams(output_activation_min, output_activation_max, &op_params);
1925
1926 BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1927 DimsToShape(input2_dims), input2_data,
1928 DimsToShape(output_dims), output_data);
1929 }
1930
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int16 * output_data,const Dims<4> & output_dims)1931 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1932 const int16* input2_data, const Dims<4>& input2_dims,
1933 int16* output_data, const Dims<4>& output_dims) {
1934 tflite::ArithmeticParams op_params;
1935 // No params in this version.
1936
1937 Mul(op_params, DimsToShape(input1_dims), input1_data,
1938 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1939 output_data);
1940 }
1941
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int32 output_offset,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1942 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1943 const int16* input2_data, const Dims<4>& input2_dims,
1944 int32 output_offset, int32 output_activation_min,
1945 int32 output_activation_max, uint8* output_data,
1946 const Dims<4>& output_dims) {
1947 tflite::ArithmeticParams op_params;
1948 op_params.quantized_activation_min = output_activation_min;
1949 op_params.quantized_activation_max = output_activation_max;
1950 op_params.output_offset = output_offset;
1951
1952 Mul(op_params, DimsToShape(input1_dims), input1_data,
1953 DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1954 output_data);
1955 }
1956
LocalResponseNormalization(const float * input_data,const Dims<4> & input_dims,int range,float bias,float alpha,float beta,float * output_data,const Dims<4> & output_dims)1957 inline void LocalResponseNormalization(const float* input_data,
1958 const Dims<4>& input_dims, int range,
1959 float bias, float alpha, float beta,
1960 float* output_data,
1961 const Dims<4>& output_dims) {
1962 tflite::LocalResponseNormalizationParams op_params;
1963 op_params.range = range;
1964 op_params.bias = bias;
1965 op_params.alpha = alpha;
1966 op_params.beta = beta;
1967
1968 LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
1969 DimsToShape(output_dims), output_data);
1970 }
1971
1972 template <typename SrcT, typename DstT>
Cast(const SrcT * input_data,const Dims<4> & input_dims,DstT * output_data,const Dims<4> & output_dims)1973 void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
1974 const Dims<4>& output_dims) {
1975 Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1976 output_data);
1977 }
1978
Floor(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1979 inline void Floor(const float* input_data, const Dims<4>& input_dims,
1980 float* output_data, const Dims<4>& output_dims) {
1981 Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1982 output_data);
1983 }
1984
1985 template <typename T>
ResizeBilinear(const T * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,T * output_data,const Dims<4> & output_dims,bool align_corners)1986 inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
1987 const int32* output_size_data,
1988 const Dims<4>& output_size_dims, T* output_data,
1989 const Dims<4>& output_dims, bool align_corners) {
1990 tflite::ResizeBilinearParams op_params;
1991 op_params.align_corners = align_corners;
1992 op_params.half_pixel_centers = false;
1993 ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
1994 DimsToShape(output_size_dims), output_size_data,
1995 DimsToShape(output_dims), output_data);
1996 }
1997
1998 // legacy, for compatibility with old checked-in code
ResizeBilinear(const float * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,float * output_data,const Dims<4> & output_dims)1999 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
2000 const int32* output_size_data,
2001 const Dims<4>& output_size_dims, float* output_data,
2002 const Dims<4>& output_dims) {
2003 ResizeBilinear<float>(input_data, input_dims, output_size_data,
2004 output_size_dims, output_data, output_dims,
2005 /*align_corners=*/false);
2006 }
2007
ResizeBilinear(const uint8 * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,uint8 * output_data,const Dims<4> & output_dims)2008 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
2009 const int32* output_size_data,
2010 const Dims<4>& output_size_dims, uint8* output_data,
2011 const Dims<4>& output_dims) {
2012 ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
2013 output_size_dims, output_data, output_dims,
2014 /*align_corners=*/false);
2015 }
2016
2017 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)2018 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
2019 const int32* block_shape_data,
2020 const Dims<4>& block_shape_dims,
2021 const int32* paddings_data,
2022 const Dims<4>& paddings_dims, T* output_data,
2023 const Dims<4>& output_dims,
2024 const int32_t pad_value) {
2025 tflite::SpaceToBatchParams op_params;
2026 op_params.output_offset = pad_value;
2027
2028 SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
2029 DimsToShape(block_shape_dims), block_shape_data,
2030 DimsToShape(paddings_dims), paddings_data,
2031 DimsToShape(output_dims), output_data);
2032 }
2033
2034 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims)2035 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
2036 const int32* block_shape_data,
2037 const Dims<4>& block_shape_dims,
2038 const int32* paddings_data,
2039 const Dims<4>& paddings_dims, T* output_data,
2040 const Dims<4>& output_dims) {
2041 tflite::SpaceToBatchParams op_params;
2042 op_params.output_offset = 0;
2043
2044 SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
2045 DimsToShape(block_shape_dims), block_shape_data,
2046 DimsToShape(paddings_dims), paddings_data,
2047 DimsToShape(output_dims), output_data);
2048 }
2049
2050 template <typename T>
BatchToSpaceND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * crops_data,const Dims<4> & crops_dims,T * output_data,const Dims<4> & output_dims)2051 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
2052 const int32* block_shape_data,
2053 const Dims<4>& block_shape_dims,
2054 const int32* crops_data, const Dims<4>& crops_dims,
2055 T* output_data, const Dims<4>& output_dims) {
2056 BatchToSpaceND(DimsToShape(input_dims), input_data,
2057 DimsToShape(block_shape_dims), block_shape_data,
2058 DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
2059 output_data);
2060 }
2061
2062 // Legacy signature, function covered both Pad and PadV2.
2063 template <typename T>
PadV2(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const T pad_value)2064 inline void PadV2(const T* input_data, const Dims<4>& input_dims,
2065 const std::vector<int>& left_paddings,
2066 const std::vector<int>& right_paddings, T* output_data,
2067 const Dims<4>& output_dims, const T pad_value) {
2068 TFLITE_DCHECK_EQ(left_paddings.size(), 4);
2069 TFLITE_DCHECK_EQ(right_paddings.size(), 4);
2070 tflite::PadParams op_params;
2071 op_params.left_padding_count = 4;
2072 op_params.right_padding_count = 4;
2073 for (int i = 0; i < 4; ++i) {
2074 op_params.left_padding[i] = left_paddings[3 - i];
2075 op_params.right_padding[i] = right_paddings[3 - i];
2076 }
2077 // SetFloatOrInt(pad_value, &op_params.pad_value);
2078 const T pad_value_copy = pad_value;
2079
2080 Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
2081 DimsToShape(output_dims), output_data);
2082 }
2083
2084 // Old Pad that calls legacy PadV2.
2085 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)2086 inline void Pad(const T* input_data, const Dims<4>& input_dims,
2087 const std::vector<int>& left_paddings,
2088 const std::vector<int>& right_paddings, T* output_data,
2089 const Dims<4>& output_dims, const int32_t pad_value) {
2090 const T converted_pad_value = static_cast<T>(pad_value);
2091 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2092 output_dims, converted_pad_value);
2093 }
2094
2095 // Old Pad that only padded with 0.
2096 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims)2097 inline void Pad(const T* input_data, const Dims<4>& input_dims,
2098 const std::vector<int>& left_paddings,
2099 const std::vector<int>& right_paddings, T* output_data,
2100 const Dims<4>& output_dims) {
2101 const T pad_value = static_cast<T>(0);
2102 PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2103 output_dims, pad_value);
2104 }
2105
2106 template <typename T>
TensorFlowMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2107 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
2108 const T* input2_data, T* output_data,
2109 const Dims<4>& output_dims) {
2110 Minimum(DimsToShape(input1_dims), input1_data, input2_data,
2111 DimsToShape(output_dims), output_data);
2112 }
2113
2114 template <typename T>
TensorFlowMaximum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2115 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
2116 const T* input2_data, T* output_data,
2117 const Dims<4>& output_dims) {
2118 Maximum(DimsToShape(input1_dims), input1_data, input2_data,
2119 DimsToShape(output_dims), output_data);
2120 }
2121
2122 template <typename T, typename Op>
TensorFlowMaximumMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims,Op op)2123 void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
2124 const T* input2_data, const Dims<4>& input2_dims,
2125 T* output_data, const Dims<4>& output_dims,
2126 Op op) {
2127 MaximumMinimumBroadcastSlow(DimsToShape(input1_dims), input1_data,
2128 DimsToShape(input2_dims), input2_data,
2129 DimsToShape(output_dims), output_data, op);
2130 }
2131
2132 template <typename T1, typename T2, typename T3>
ArgMax(const T3 * axis,const T1 * input_data,const tflite::Dims<4> & input_dims,T2 * output_data,const tflite::Dims<4> & output_dims)2133 void ArgMax(const T3* axis, const T1* input_data,
2134 const tflite::Dims<4>& input_dims, T2* output_data,
2135 const tflite::Dims<4>& output_dims) {
2136 // Assumes the input always has 4 dimensions, and therefore,
2137 // output always has three dimensions.
2138 auto output_shape = RuntimeShape(
2139 {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
2140 // Another way to interpret this is that output_dims.sizes[4] is always 1.
2141 TFLITE_DCHECK_EQ(output_shape.FlatSize(),
2142 DimsToShape(output_dims).FlatSize());
2143 // Legacy path only supported this.
2144 TFLITE_DCHECK_EQ(axis[0], 3);
2145 ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
2146 output_data, std::greater<T1>());
2147 }
2148
2149 template <typename T1, typename T2, typename T3, typename Cmp>
ArgMinMax(const T3 * axis,const T1 * input_data,const Dims<4> & input_dims,T2 * output_data,const Dims<4> & output_dims,const Cmp & cmp)2150 void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
2151 T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
2152 ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
2153 output_data, cmp);
2154 }
2155
2156 template <typename T>
Pow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2157 inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
2158 const T* input2_data, const Dims<4>& input2_dims,
2159 T* output_data, const Dims<4>& output_dims) {
2160 Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
2161 input2_data, DimsToShape(output_dims), output_data);
2162 }
2163
2164 template <typename T>
BroadcastPow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2165 inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
2166 const T* input2_data, const Dims<4>& input2_dims,
2167 T* output_data, const Dims<4>& output_dims) {
2168 BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
2169 DimsToShape(input2_dims), input2_data,
2170 DimsToShape(output_dims), output_data);
2171 }
2172
2173 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2174 template <typename R, typename T1, typename T2>
BroadcastBinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2175 inline void BroadcastBinaryFunction(const T1* input1_data,
2176 const Dims<4>& input1_dims,
2177 const T2* input2_data,
2178 const Dims<4>& input2_dims, R* output_data,
2179 const Dims<4>& output_dims,
2180 R (*func)(T1, T2)) {
2181 BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
2182 DimsToShape(input2_dims), input2_data,
2183 DimsToShape(output_dims), output_data, func);
2184 }
2185
2186 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2187 template <typename R, typename T1, typename T2>
BinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2188 inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
2189 const T2* input2_data, const Dims<4>& input2_dims,
2190 R* output_data, const Dims<4>& output_dims,
2191 R (*func)(T1, T2)) {
2192 BinaryFunction(DimsToShape(input1_dims), input1_data,
2193 DimsToShape(input2_dims), input2_data,
2194 DimsToShape(output_dims), output_data, func);
2195 }
2196
2197 template <typename T>
Slice(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & begin,const std::vector<int> & size,T * output_data,const Dims<4> & output_dims)2198 inline void Slice(const T* input_data, const Dims<4>& input_dims,
2199 const std::vector<int>& begin, const std::vector<int>& size,
2200 T* output_data, const Dims<4>& output_dims) {
2201 tflite::SliceParams op_params;
2202 op_params.begin_count = 4;
2203 op_params.size_count = 4;
2204 for (int i = 0; i < 4; ++i) {
2205 op_params.begin[i] = begin[3 - i];
2206 op_params.size[i] = size[3 - i];
2207 }
2208
2209 Slice(op_params, DimsToShape(input_dims), input_data,
2210 DimsToShape(output_dims), output_data);
2211 }
2212
2213 } // namespace reference_ops
2214 } // namespace tflite
2215 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
2216