xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
17 
18 #include <stdint.h>
19 #include <sys/types.h>
20 
21 #include <algorithm>
22 
23 #include "public/gemmlowp.h"
24 #include "tensorflow/lite/kernels/internal/common.h"
25 #include "tensorflow/lite/kernels/internal/legacy_types.h"
26 #include "tensorflow/lite/kernels/internal/reference/conv.h"
27 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
28 #include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
29 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
30 #include "tensorflow/lite/kernels/internal/reference/tanh.h"
31 #include "tensorflow/lite/kernels/internal/types.h"
32 
33 namespace tflite {
34 
35 namespace reference_ops {
36 
// Legacy depthwise-conv callers pass output_shift with the opposite sign
// convention ("+ve means right shift"); multiplying by this constant converts
// to the modern "+ve means left shift" convention.
static constexpr int kDepthwiseReverseShift = -1;
38 
ShapeFromDims(const tflite::Dims<4> & dims,RuntimeShape * shape)39 inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
40   shape->BuildFrom(
41       {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
42 }
43 
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)44 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
45                           const float* filter_data, const Dims<4>& filter_dims,
46                           const float* bias_data, const Dims<4>& bias_dims,
47                           int stride_width, int stride_height,
48                           int dilation_width_factor, int dilation_height_factor,
49                           int pad_width, int pad_height, int depth_multiplier,
50                           float output_activation_min,
51                           float output_activation_max, float* output_data,
52                           const Dims<4>& output_dims) {
53   tflite::DepthwiseParams op_params;
54   // Padding type is ignored, but still set.
55   op_params.padding_type = PaddingType::kSame;
56   op_params.padding_values.width = pad_width;
57   op_params.padding_values.height = pad_height;
58   op_params.stride_width = stride_width;
59   op_params.stride_height = stride_height;
60   op_params.dilation_width_factor = dilation_width_factor;
61   op_params.dilation_height_factor = dilation_height_factor;
62   op_params.depth_multiplier = depth_multiplier;
63   op_params.float_activation_min = output_activation_min;
64   op_params.float_activation_max = output_activation_max;
65 
66   DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
67                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
68                 bias_data, DimsToShape(output_dims), output_data);
69 }
70 
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)71 inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
72                           const float* filter_data, const Dims<4>& filter_dims,
73                           const float* bias_data, const Dims<4>& bias_dims,
74                           int stride_width, int stride_height, int pad_width,
75                           int pad_height, int depth_multiplier,
76                           float output_activation_min,
77                           float output_activation_max, float* output_data,
78                           const Dims<4>& output_dims) {
79   DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
80                 bias_dims, stride_width, stride_height, 1, 1, pad_width,
81                 pad_height, depth_multiplier, output_activation_min,
82                 output_activation_max, output_data, output_dims);
83 }
84 
85 // Legacy, for compatibility with old checked-in code.
86 template <FusedActivationFunctionType Ac>
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,float * output_data,const Dims<4> & output_dims)87 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
88                    const float* filter_data, const Dims<4>& filter_dims,
89                    const float* bias_data, const Dims<4>& bias_dims,
90                    int stride_width, int stride_height, int pad_width,
91                    int pad_height, int depth_multiplier, float* output_data,
92                    const Dims<4>& output_dims) {
93   float output_activation_min, output_activation_max;
94   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
95   DepthwiseConv(input_data, input_dims, filter_data, filter_dims, bias_data,
96                 bias_dims, stride_width, stride_height, pad_width, pad_height,
97                 depth_multiplier, output_activation_min, output_activation_max,
98                 output_data, output_dims);
99 }
100 
101 // Legacy, for compatibility with old checked-in code.
102 template <FusedActivationFunctionType Ac>
DepthwiseConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int depth_multiplier,float * output_data,const Dims<4> & output_dims)103 void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
104                    const float* filter_data, const Dims<4>& filter_dims,
105                    const float* bias_data, const Dims<4>& bias_dims, int stride,
106                    int pad_width, int pad_height, int depth_multiplier,
107                    float* output_data, const Dims<4>& output_dims) {
108   DepthwiseConv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
109                     bias_dims, stride, stride, pad_width, pad_height,
110                     depth_multiplier, output_data, output_dims);
111 }
112 
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)113 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
114                           int32 input_offset, const uint8* filter_data,
115                           const Dims<4>& filter_dims, int32 filter_offset,
116                           const int32* bias_data, const Dims<4>& bias_dims,
117                           int stride_width, int stride_height,
118                           int dilation_width_factor, int dilation_height_factor,
119                           int pad_width, int pad_height, int depth_multiplier,
120                           int32 output_offset, int32 output_multiplier,
121                           int output_shift, int32 output_activation_min,
122                           int32 output_activation_max, uint8* output_data,
123                           const Dims<4>& output_dims) {
124   tflite::DepthwiseParams op_params;
125   // Padding type is ignored, but still set.
126   op_params.padding_type = PaddingType::kSame;
127   op_params.padding_values.width = pad_width;
128   op_params.padding_values.height = pad_height;
129   op_params.stride_width = stride_width;
130   op_params.stride_height = stride_height;
131   op_params.dilation_width_factor = dilation_width_factor;
132   op_params.dilation_height_factor = dilation_height_factor;
133   op_params.depth_multiplier = depth_multiplier;
134   op_params.quantized_activation_min = output_activation_min;
135   op_params.quantized_activation_max = output_activation_max;
136   op_params.input_offset = input_offset;
137   op_params.weights_offset = filter_offset;
138   op_params.output_offset = output_offset;
139   op_params.output_multiplier = output_multiplier;
140   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
141   op_params.output_shift = kDepthwiseReverseShift * output_shift;
142 
143   DepthwiseConv(op_params, DimsToShape(input_dims), input_data,
144                 DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
145                 bias_data, DimsToShape(output_dims), output_data);
146 }
147 
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)148 inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
149                           int32 input_offset, const uint8* filter_data,
150                           const Dims<4>& filter_dims, int32 filter_offset,
151                           const int32* bias_data, const Dims<4>& bias_dims,
152                           int stride_width, int stride_height, int pad_width,
153                           int pad_height, int depth_multiplier,
154                           int32 output_offset, int32 output_multiplier,
155                           int output_shift, int32 output_activation_min,
156                           int32 output_activation_max, uint8* output_data,
157                           const Dims<4>& output_dims) {
158   DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
159                 filter_offset, bias_data, bias_dims, stride_width,
160                 stride_height, 1, 1, pad_width, pad_height, depth_multiplier,
161                 output_offset, output_multiplier, output_shift,
162                 output_activation_min, output_activation_max, output_data,
163                 output_dims);
164 }
165 
166 // Legacy, for compatibility with old checked-in code.
167 template <FusedActivationFunctionType Ac>
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)168 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
169                    int32 input_offset, const uint8* filter_data,
170                    const Dims<4>& filter_dims, int32 filter_offset,
171                    const int32* bias_data, const Dims<4>& bias_dims,
172                    int stride_width, int stride_height, int pad_width,
173                    int pad_height, int depth_multiplier, int32 output_offset,
174                    int32 output_multiplier, int output_shift,
175                    int32 output_activation_min, int32 output_activation_max,
176                    uint8* output_data, const Dims<4>& output_dims) {
177   if (Ac == FusedActivationFunctionType::kNone) {
178     TFLITE_DCHECK_EQ(output_activation_min, 0);
179     TFLITE_DCHECK_EQ(output_activation_max, 255);
180   }
181   DepthwiseConv(input_data, input_dims, input_offset, filter_data, filter_dims,
182                 filter_offset, bias_data, bias_dims, stride_width,
183                 stride_height, pad_width, pad_height, depth_multiplier,
184                 output_offset, output_multiplier, output_shift,
185                 output_activation_min, output_activation_max, output_data,
186                 output_dims);
187 }
188 
189 // Legacy, for compatibility with old checked-in code.
190 template <FusedActivationFunctionType Ac>
DepthwiseConv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int depth_multiplier,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)191 void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims,
192                    int32 input_offset, const uint8* filter_data,
193                    const Dims<4>& filter_dims, int32 filter_offset,
194                    const int32* bias_data, const Dims<4>& bias_dims, int stride,
195                    int pad_width, int pad_height, int depth_multiplier,
196                    int32 output_offset, int32 output_multiplier,
197                    int output_shift, int32 output_activation_min,
198                    int32 output_activation_max, uint8* output_data,
199                    const Dims<4>& output_dims) {
200   DepthwiseConv<Ac>(input_data, input_dims, input_offset, filter_data,
201                     filter_dims, filter_offset, bias_data, bias_dims, stride,
202                     stride, pad_width, pad_height, depth_multiplier,
203                     output_offset, output_multiplier, output_shift,
204                     output_activation_min, output_activation_max, output_data,
205                     output_dims);
206 }
207 
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)208 inline void Conv(const float* input_data, const Dims<4>& input_dims,
209                  const float* filter_data, const Dims<4>& filter_dims,
210                  const float* bias_data, const Dims<4>& bias_dims,
211                  int stride_width, int stride_height, int dilation_width_factor,
212                  int dilation_height_factor, int pad_width, int pad_height,
213                  float output_activation_min, float output_activation_max,
214                  float* output_data, const Dims<4>& output_dims,
215                  float* im2col_data, const Dims<4>& im2col_dims) {
216   tflite::ConvParams op_params;
217   // Padding type is ignored, but still set.
218   op_params.padding_type = PaddingType::kSame;
219   op_params.padding_values.width = pad_width;
220   op_params.padding_values.height = pad_height;
221   op_params.stride_width = stride_width;
222   op_params.stride_height = stride_height;
223   op_params.dilation_width_factor = dilation_width_factor;
224   op_params.dilation_height_factor = dilation_height_factor;
225   op_params.float_activation_min = output_activation_min;
226   op_params.float_activation_max = output_activation_max;
227 
228   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
229        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
230        output_data, DimsToShape(im2col_dims), im2col_data);
231 }
232 
233 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)234 void Conv(const float* input_data, const Dims<4>& input_dims,
235           const float* filter_data, const Dims<4>& filter_dims,
236           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
237           int stride_height, int dilation_width_factor,
238           int dilation_height_factor, int pad_width, int pad_height,
239           float* output_data, const Dims<4>& output_dims, float* im2col_data,
240           const Dims<4>& im2col_dims) {
241   float output_activation_min, output_activation_max;
242   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
243   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
244        stride_width, stride_height, dilation_width_factor,
245        dilation_height_factor, pad_width, pad_height, output_activation_min,
246        output_activation_max, output_data, output_dims, im2col_data,
247        im2col_dims);
248 }
249 
250 // legacy, for compatibility with old checked-in code
251 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)252 void Conv(const float* input_data, const Dims<4>& input_dims,
253           const float* filter_data, const Dims<4>& filter_dims,
254           const float* bias_data, const Dims<4>& bias_dims, int stride_width,
255           int stride_height, int pad_width, int pad_height, float* output_data,
256           const Dims<4>& output_dims, float* im2col_data,
257           const Dims<4>& im2col_dims) {
258   float output_activation_min, output_activation_max;
259   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
260   Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims,
261        stride_width, stride_height, 1, 1, pad_width, pad_height,
262        output_activation_min, output_activation_max, output_data, output_dims,
263        im2col_data, im2col_dims);
264 }
265 
266 // legacy, for compatibility with old checked-in code
267 template <FusedActivationFunctionType Ac>
Conv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,const float * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)268 void Conv(const float* input_data, const Dims<4>& input_dims,
269           const float* filter_data, const Dims<4>& filter_dims,
270           const float* bias_data, const Dims<4>& bias_dims, int stride,
271           int pad_width, int pad_height, float* output_data,
272           const Dims<4>& output_dims, float* im2col_data,
273           const Dims<4>& im2col_dims) {
274   Conv<Ac>(input_data, input_dims, filter_data, filter_dims, bias_data,
275            bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data,
276            output_dims, im2col_data, im2col_dims);
277 }
278 
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)279 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
280                  int32 input_offset, const uint8* filter_data,
281                  const Dims<4>& filter_dims, int32 filter_offset,
282                  const int32* bias_data, const Dims<4>& bias_dims,
283                  int stride_width, int stride_height, int dilation_width_factor,
284                  int dilation_height_factor, int pad_width, int pad_height,
285                  int32 output_offset, int32 output_multiplier, int output_shift,
286                  int32 output_activation_min, int32 output_activation_max,
287                  uint8* output_data, const Dims<4>& output_dims,
288                  uint8* im2col_data, const Dims<4>& im2col_dims,
289                  gemmlowp::GemmContext* gemmlowp_context) {
290   tflite::ConvParams op_params;
291   // Padding type is ignored, but still set.
292   op_params.padding_type = PaddingType::kSame;
293   op_params.padding_values.width = pad_width;
294   op_params.padding_values.height = pad_height;
295   op_params.stride_width = stride_width;
296   op_params.stride_height = stride_height;
297   op_params.dilation_width_factor = dilation_width_factor;
298   op_params.dilation_height_factor = dilation_height_factor;
299   op_params.input_offset = input_offset;
300   op_params.weights_offset = filter_offset;
301   op_params.output_offset = output_offset;
302   op_params.output_multiplier = output_multiplier;
303   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
304   op_params.output_shift = kReverseShift * output_shift;
305   op_params.quantized_activation_min = output_activation_min;
306   op_params.quantized_activation_max = output_activation_max;
307 
308   Conv(op_params, DimsToShape(input_dims), input_data, DimsToShape(filter_dims),
309        filter_data, DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
310        output_data, DimsToShape(im2col_dims), im2col_data, gemmlowp_context);
311 }
312 
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)313 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
314                  int32 input_offset, const uint8* filter_data,
315                  const Dims<4>& filter_dims, int32 filter_offset,
316                  const int32* bias_data, const Dims<4>& bias_dims,
317                  int stride_width, int stride_height, int pad_width,
318                  int pad_height, int32 output_offset, int32 output_multiplier,
319                  int output_shift, int32 output_activation_min,
320                  int32 output_activation_max, uint8* output_data,
321                  const Dims<4>& output_dims, uint8* im2col_data,
322                  const Dims<4>& im2col_dims,
323                  gemmlowp::GemmContext* gemmlowp_context) {
324   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
325        filter_offset, bias_data, bias_dims, stride_width, stride_height, 1, 1,
326        pad_width, pad_height, output_offset, output_multiplier, output_shift,
327        output_activation_min, output_activation_max, output_data, output_dims,
328        im2col_data, im2col_dims, gemmlowp_context);
329 }
330 
331 // legacy, for compatibility with old checked-in code
332 template <FusedActivationFunctionType Ac>
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride_width,int stride_height,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)333 inline void Conv(const uint8* input_data, const Dims<4>& input_dims,
334                  int32 input_offset, const uint8* filter_data,
335                  const Dims<4>& filter_dims, int32 filter_offset,
336                  const int32* bias_data, const Dims<4>& bias_dims,
337                  int stride_width, int stride_height, int pad_width,
338                  int pad_height, int32 output_offset, int32 output_multiplier,
339                  int output_shift, int32 output_activation_min,
340                  int32 output_activation_max, uint8* output_data,
341                  const Dims<4>& output_dims, uint8* im2col_data,
342                  const Dims<4>& im2col_dims,
343                  gemmlowp::GemmContext* gemmlowp_context) {
344   static_assert(Ac == FusedActivationFunctionType::kNone ||
345                     Ac == FusedActivationFunctionType::kRelu ||
346                     Ac == FusedActivationFunctionType::kRelu6 ||
347                     Ac == FusedActivationFunctionType::kRelu1,
348                 "");
349   if (Ac == FusedActivationFunctionType::kNone) {
350     TFLITE_DCHECK_EQ(output_activation_min, 0);
351     TFLITE_DCHECK_EQ(output_activation_max, 255);
352   }
353   Conv(input_data, input_dims, input_offset, filter_data, filter_dims,
354        filter_offset, bias_data, bias_dims, stride_width, stride_height,
355        pad_width, pad_height, output_offset, output_multiplier, output_shift,
356        output_activation_min, output_activation_max, output_data, output_dims,
357        im2col_data, im2col_dims, gemmlowp_context);
358 }
359 
360 // legacy, for compatibility with old checked-in code
361 template <FusedActivationFunctionType Ac>
Conv(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int stride,int pad_width,int pad_height,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,uint8 * im2col_data,const Dims<4> & im2col_dims,gemmlowp::GemmContext * gemmlowp_context)362 void Conv(const uint8* input_data, const Dims<4>& input_dims,
363           int32 input_offset, const uint8* filter_data,
364           const Dims<4>& filter_dims, int32 filter_offset,
365           const int32* bias_data, const Dims<4>& bias_dims, int stride,
366           int pad_width, int pad_height, int32 output_offset,
367           int32 output_multiplier, int output_shift,
368           int32 output_activation_min, int32 output_activation_max,
369           uint8* output_data, const Dims<4>& output_dims, uint8* im2col_data,
370           const Dims<4>& im2col_dims, gemmlowp::GemmContext* gemmlowp_context) {
371   Conv<Ac>(input_data, input_dims, input_offset, filter_data, filter_dims,
372            filter_offset, bias_data, bias_dims, stride, stride, pad_width,
373            pad_height, output_offset, output_multiplier, output_shift,
374            output_activation_min, output_activation_max, output_data,
375            output_dims, im2col_data, im2col_dims, gemmlowp_context);
376 }
377 
TransposeConv(const float * input_data,const Dims<4> & input_dims,const float * filter_data,const Dims<4> & filter_dims,int stride_width,int stride_height,int pad_width,int pad_height,float * output_data,const Dims<4> & output_dims,float * im2col_data,const Dims<4> & im2col_dims)378 inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
379                           const float* filter_data, const Dims<4>& filter_dims,
380                           int stride_width, int stride_height, int pad_width,
381                           int pad_height, float* output_data,
382                           const Dims<4>& output_dims, float* im2col_data,
383                           const Dims<4>& im2col_dims) {
384   tflite::ConvParams op_params;
385   // Padding type is ignored, but still set.
386   op_params.padding_type = PaddingType::kSame;
387   op_params.padding_values.width = pad_width;
388   op_params.padding_values.height = pad_height;
389   op_params.stride_width = stride_width;
390   op_params.stride_height = stride_height;
391 
392   TransposeConv(op_params, DimsToShape(input_dims), input_data,
393                 DimsToShape(filter_dims), filter_data,
394                 /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr,
395                 DimsToShape(output_dims), output_data, DimsToShape(im2col_dims),
396                 im2col_data);
397 }
398 
TransposeConv(const ConvParams & params,const RuntimeShape & input_shape,const float * input_data,const RuntimeShape & filter_shape,const float * filter_data,const RuntimeShape & output_shape,float * output_data,const RuntimeShape & im2col_shape,float * im2col_data)399 inline void TransposeConv(
400     const ConvParams& params, const RuntimeShape& input_shape,
401     const float* input_data, const RuntimeShape& filter_shape,
402     const float* filter_data, const RuntimeShape& output_shape,
403     float* output_data, const RuntimeShape& im2col_shape, float* im2col_data) {
404   TransposeConv(params, input_shape, input_data, filter_shape, filter_data,
405                 /*bias_shape*/ RuntimeShape(), /*bias*/ nullptr, output_shape,
406                 output_data, im2col_shape, im2col_data);
407 }
408 
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)409 inline void FullyConnected(const float* input_data, const Dims<4>& input_dims,
410                            const float* weights_data,
411                            const Dims<4>& weights_dims, const float* bias_data,
412                            const Dims<4>& bias_dims,
413                            float output_activation_min,
414                            float output_activation_max, float* output_data,
415                            const Dims<4>& output_dims) {
416   tflite::FullyConnectedParams op_params;
417   op_params.float_activation_min = output_activation_min;
418   op_params.float_activation_max = output_activation_max;
419 
420   FullyConnected(op_params, DimsToShape(input_dims), input_data,
421                  DimsToShape(weights_dims), weights_data,
422                  DimsToShape(bias_dims), bias_data, DimsToShape(output_dims),
423                  output_data);
424 }
425 
426 // legacy, for compatibility with old checked-in code
427 template <FusedActivationFunctionType Ac>
FullyConnected(const float * input_data,const Dims<4> & input_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,float * output_data,const Dims<4> & output_dims)428 void FullyConnected(const float* input_data, const Dims<4>& input_dims,
429                     const float* weights_data, const Dims<4>& weights_dims,
430                     const float* bias_data, const Dims<4>& bias_dims,
431                     float* output_data, const Dims<4>& output_dims) {
432   float output_activation_min, output_activation_max;
433   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
434   FullyConnected(input_data, input_dims, weights_data, weights_dims, bias_data,
435                  bias_dims, output_activation_min, output_activation_max,
436                  output_data, output_dims);
437 }
438 
FullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & filter_shape,const uint8 * filter_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,uint8 * output_data,gemmlowp::GemmContext *)439 inline void FullyConnected(
440     const FullyConnectedParams& params, const RuntimeShape& input_shape,
441     const uint8* input_data, const RuntimeShape& filter_shape,
442     const uint8* filter_data, const RuntimeShape& bias_shape,
443     const int32* bias_data, const RuntimeShape& output_shape,
444     uint8* output_data, gemmlowp::GemmContext*) {
445   FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
446                  bias_shape, bias_data, output_shape, output_data);
447 }
448 
FullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & filter_shape,const uint8 * filter_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,int16 * output_data,gemmlowp::GemmContext *)449 inline void FullyConnected(
450     const FullyConnectedParams& params, const RuntimeShape& input_shape,
451     const uint8* input_data, const RuntimeShape& filter_shape,
452     const uint8* filter_data, const RuntimeShape& bias_shape,
453     const int32* bias_data, const RuntimeShape& output_shape,
454     int16* output_data, gemmlowp::GemmContext*) {
455   FullyConnected(params, input_shape, input_data, filter_shape, filter_data,
456                  bias_shape, bias_data, output_shape, output_data);
457 }
458 
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)459 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
460                            int32 input_offset, const uint8* filter_data,
461                            const Dims<4>& filter_dims, int32 filter_offset,
462                            const int32* bias_data, const Dims<4>& bias_dims,
463                            int32 output_offset, int32 output_multiplier,
464                            int output_shift, int32 output_activation_min,
465                            int32 output_activation_max, uint8* output_data,
466                            const Dims<4>& output_dims,
467                            gemmlowp::GemmContext* gemmlowp_context) {
468   tflite::FullyConnectedParams op_params;
469   op_params.input_offset = input_offset;
470   op_params.weights_offset = filter_offset;
471   op_params.output_offset = output_offset;
472   op_params.output_multiplier = output_multiplier;
473   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
474   op_params.output_shift = kReverseShift * output_shift;
475   op_params.quantized_activation_min = output_activation_min;
476   op_params.quantized_activation_max = output_activation_max;
477 
478   FullyConnected(op_params, DimsToShape(input_dims), input_data,
479                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
480                  bias_data, DimsToShape(output_dims), output_data,
481                  gemmlowp_context);
482 }
483 
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)484 inline void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
485                            int32 input_offset, const uint8* filter_data,
486                            const Dims<4>& filter_dims, int32 filter_offset,
487                            const int32* bias_data, const Dims<4>& bias_dims,
488                            int32 output_offset, int32 output_multiplier,
489                            int output_shift, int32 output_activation_min,
490                            int32 output_activation_max, int16* output_data,
491                            const Dims<4>& output_dims,
492                            gemmlowp::GemmContext* gemmlowp_context) {
493   tflite::FullyConnectedParams op_params;
494   op_params.input_offset = input_offset;
495   op_params.weights_offset = filter_offset;
496   op_params.output_offset = output_offset;
497   op_params.output_multiplier = output_multiplier;
498   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
499   op_params.output_shift = kReverseShift * output_shift;
500   op_params.quantized_activation_min = output_activation_min;
501   op_params.quantized_activation_max = output_activation_max;
502 
503   FullyConnected(op_params, DimsToShape(input_dims), input_data,
504                  DimsToShape(filter_dims), filter_data, DimsToShape(bias_dims),
505                  bias_data, DimsToShape(output_dims), output_data,
506                  gemmlowp_context);
507 }
508 
ShuffledFullyConnected(const FullyConnectedParams & params,const RuntimeShape & input_shape,const uint8 * input_data,const RuntimeShape & weights_shape,const uint8 * shuffled_weights_data,const RuntimeShape & bias_shape,const int32 * bias_data,const RuntimeShape & output_shape,int16 * output_data,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext *)509 inline void ShuffledFullyConnected(
510     const FullyConnectedParams& params, const RuntimeShape& input_shape,
511     const uint8* input_data, const RuntimeShape& weights_shape,
512     const uint8* shuffled_weights_data, const RuntimeShape& bias_shape,
513     const int32* bias_data, const RuntimeShape& output_shape,
514     int16* output_data, uint8* shuffled_input_workspace_data,
515     gemmlowp::GemmContext*) {
516   ShuffledFullyConnected(params, input_shape, input_data, weights_shape,
517                          shuffled_weights_data, bias_shape, bias_data,
518                          output_shape, output_data,
519                          shuffled_input_workspace_data);
520 }
521 
ShuffledFullyConnected(const uint8 * input_data,const Dims<4> & input_dims,const uint8 * shuffled_weights_data,const Dims<4> & weights_dims,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,int16 * output_data,const Dims<4> & output_dims,uint8 * shuffled_input_workspace_data,gemmlowp::GemmContext * gemmlowp_context)522 inline void ShuffledFullyConnected(
523     const uint8* input_data, const Dims<4>& input_dims,
524     const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
525     const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
526     int output_shift, int32 output_activation_min, int32 output_activation_max,
527     int16* output_data, const Dims<4>& output_dims,
528     uint8* shuffled_input_workspace_data,
529     gemmlowp::GemmContext* gemmlowp_context) {
530   tflite::FullyConnectedParams op_params;
531   op_params.output_multiplier = output_multiplier;
532   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
533   op_params.output_shift = kReverseShift * output_shift;
534   op_params.quantized_activation_min = output_activation_min;
535   op_params.quantized_activation_max = output_activation_max;
536 
537   ShuffledFullyConnected(op_params, DimsToShape(input_dims), input_data,
538                          DimsToShape(weights_dims), shuffled_weights_data,
539                          DimsToShape(bias_dims), bias_data,
540                          DimsToShape(output_dims), output_data,
541                          shuffled_input_workspace_data, gemmlowp_context);
542 }
543 
544 // legacy, for compatibility with old checked-in code
545 template <FusedActivationFunctionType Ac>
FullyConnected(const uint8 * input_data,const Dims<4> & input_dims,int32 input_offset,const uint8 * filter_data,const Dims<4> & filter_dims,int32 filter_offset,const int32 * bias_data,const Dims<4> & bias_dims,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims,gemmlowp::GemmContext * gemmlowp_context)546 void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
547                     int32 input_offset, const uint8* filter_data,
548                     const Dims<4>& filter_dims, int32 filter_offset,
549                     const int32* bias_data, const Dims<4>& bias_dims,
550                     int32 output_offset, int32 output_multiplier,
551                     int output_shift, int32 output_activation_min,
552                     int32 output_activation_max, uint8* output_data,
553                     const Dims<4>& output_dims,
554                     gemmlowp::GemmContext* gemmlowp_context) {
555   static_assert(Ac == FusedActivationFunctionType::kNone ||
556                     Ac == FusedActivationFunctionType::kRelu ||
557                     Ac == FusedActivationFunctionType::kRelu6 ||
558                     Ac == FusedActivationFunctionType::kRelu1,
559                 "");
560   if (Ac == FusedActivationFunctionType::kNone) {
561     TFLITE_DCHECK_EQ(output_activation_min, 0);
562     TFLITE_DCHECK_EQ(output_activation_max, 255);
563   }
564   FullyConnected(input_data, input_dims, input_offset, filter_data, filter_dims,
565                  filter_offset, bias_data, bias_dims, output_offset,
566                  output_multiplier, output_shift, output_activation_min,
567                  output_activation_max, output_data, output_dims,
568                  gemmlowp_context);
569 }
570 
LstmCell(const float * input_data,const Dims<4> & input_dims,const float * prev_activ_data,const Dims<4> & prev_activ_dims,const float * weights_data,const Dims<4> & weights_dims,const float * bias_data,const Dims<4> & bias_dims,const float * prev_state_data,const Dims<4> & prev_state_dims,float * output_state_data,const Dims<4> & output_state_dims,float * output_activ_data,const Dims<4> & output_activ_dims,float * concat_temp_data,const Dims<4> & concat_temp_dims,float * activ_temp_data,const Dims<4> & activ_temp_dims)571 inline void LstmCell(const float* input_data, const Dims<4>& input_dims,
572                      const float* prev_activ_data,
573                      const Dims<4>& prev_activ_dims, const float* weights_data,
574                      const Dims<4>& weights_dims, const float* bias_data,
575                      const Dims<4>& bias_dims, const float* prev_state_data,
576                      const Dims<4>& prev_state_dims, float* output_state_data,
577                      const Dims<4>& output_state_dims, float* output_activ_data,
578                      const Dims<4>& output_activ_dims, float* concat_temp_data,
579                      const Dims<4>& concat_temp_dims, float* activ_temp_data,
580                      const Dims<4>& activ_temp_dims) {
581   tflite::LstmCellParams op_params;
582   // Float LSTM cell does not need parameters to be set: leave untouched.
583 
584   LstmCell(op_params, DimsToShape(input_dims), input_data,
585            DimsToShape(prev_activ_dims), prev_activ_data,
586            DimsToShape(weights_dims), weights_data, DimsToShape(bias_dims),
587            bias_data, DimsToShape(prev_state_dims), prev_state_data,
588            DimsToShape(output_state_dims), output_state_data,
589            DimsToShape(output_activ_dims), output_activ_data,
590            DimsToShape(concat_temp_dims), concat_temp_data,
591            DimsToShape(activ_temp_dims), activ_temp_data);
592 }
593 
594 template <int StateIntegerBits>
LstmCell(const uint8 * input_data_uint8,const Dims<4> & input_dims,const uint8 * prev_activ_data_uint8,const Dims<4> & prev_activ_dims,const uint8 * weights_data_uint8,const Dims<4> & weights_dims,const int32 * bias_data_int32,const Dims<4> & bias_dims,const int16 * prev_state_data_int16,const Dims<4> & prev_state_dims,int16 * output_state_data_int16,const Dims<4> & output_state_dims,uint8 * output_activ_data_uint8,const Dims<4> & output_activ_dims,uint8 * concat_temp_data_uint8,const Dims<4> & concat_temp_dims,int16 * activ_temp_data_int16,const Dims<4> & activ_temp_dims,int32 weights_zero_point,int32 accum_multiplier,int accum_shift,gemmlowp::GemmContext * gemmlowp_context)595 void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
596               const uint8* prev_activ_data_uint8,
597               const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8,
598               const Dims<4>& weights_dims, const int32* bias_data_int32,
599               const Dims<4>& bias_dims, const int16* prev_state_data_int16,
600               const Dims<4>& prev_state_dims, int16* output_state_data_int16,
601               const Dims<4>& output_state_dims, uint8* output_activ_data_uint8,
602               const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8,
603               const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16,
604               const Dims<4>& activ_temp_dims, int32 weights_zero_point,
605               int32 accum_multiplier, int accum_shift,
606               gemmlowp::GemmContext* gemmlowp_context) {
607   tflite::LstmCellParams op_params;
608   op_params.weights_zero_point = weights_zero_point;
609   op_params.accum_multiplier = accum_multiplier;
610   op_params.accum_shift = accum_shift;
611 
612   LstmCell<StateIntegerBits>(
613       op_params, DimsToShape(input_dims), input_data_uint8,
614       DimsToShape(prev_activ_dims), prev_activ_data_uint8,
615       DimsToShape(weights_dims), weights_data_uint8, DimsToShape(bias_dims),
616       bias_data_int32, DimsToShape(prev_state_dims), prev_state_data_int16,
617       DimsToShape(output_state_dims), output_state_data_int16,
618       DimsToShape(output_activ_dims), output_activ_data_uint8,
619       DimsToShape(concat_temp_dims), concat_temp_data_uint8,
620       DimsToShape(activ_temp_dims), activ_temp_data_int16, gemmlowp_context);
621 }
622 
623 template <typename T>
BroadcastDiv(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)624 void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
625                   const T* input2_data, const Dims<4>& input2_dims,
626                   T output_activation_min, T output_activation_max,
627                   T* output_data, const Dims<4>& output_dims) {
628   tflite::ArithmeticParams op_params;
629   SetActivationParams(output_activation_min, output_activation_max, &op_params);
630 
631   BroadcastDivSlow(op_params, DimsToShape(input1_dims), input1_data,
632                    DimsToShape(input2_dims), input2_data,
633                    DimsToShape(output_dims), output_data);
634 }
635 
636 template <typename T>
Div(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)637 inline void Div(const T* input1_data, const Dims<4>& input1_dims,
638                 const T* input2_data, const Dims<4>& input2_dims,
639                 T output_activation_min, T output_activation_max,
640                 T* output_data, const Dims<4>& output_dims) {
641   tflite::ArithmeticParams op_params;
642   SetActivationParams(output_activation_min, output_activation_max, &op_params);
643 
644   Div(op_params, DimsToShape(input1_dims), input1_data,
645       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
646       output_data);
647 }
648 
649 template <FusedActivationFunctionType Ac, typename Scalar>
Concatenation(int concat_dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)650 inline void Concatenation(int concat_dim, const Scalar* const* input_data,
651                           const Dims<4>* const* input_dims, int inputs_count,
652                           Scalar* output_data, const Dims<4>& output_dims) {
653   // For now we don't have a model with a Concatenation with fused activation.
654   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
655 
656   std::vector<RuntimeShape> input_shapes(inputs_count);
657   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
658   for (int i = 0; i < inputs_count; ++i) {
659     ShapeFromDims(*input_dims[i], &input_shapes[i]);
660     input_shapes_indirect[i] = &input_shapes[i];
661   }
662   tflite::ConcatenationParams op_params;
663   op_params.axis = 3 - concat_dim;
664   op_params.inputs_count = inputs_count;
665 
666   Concatenation(op_params, input_shapes_indirect.data(), input_data,
667                 DimsToShape(output_dims), output_data);
668 }
669 
Concatenation(int concat_dim,const uint8 * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,uint8 * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)670 inline void Concatenation(int concat_dim, const uint8* const* input_data,
671                           const Dims<4>* const* input_dims,
672                           const int32* input_zeropoint,
673                           const float* input_scale, int inputs_count,
674                           uint8* output_data, const Dims<4>& output_dims,
675                           const int32 output_zeropoint,
676                           const float output_scale) {
677   std::vector<RuntimeShape> input_shapes(inputs_count);
678   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
679   for (int i = 0; i < inputs_count; ++i) {
680     ShapeFromDims(*input_dims[i], &input_shapes[i]);
681     input_shapes_indirect[i] = &input_shapes[i];
682   }
683   tflite::ConcatenationParams op_params;
684   op_params.axis = 3 - concat_dim;
685   op_params.input_zeropoint = input_zeropoint;
686   op_params.input_scale = input_scale;
687   op_params.inputs_count = inputs_count;
688   op_params.output_zeropoint = output_zeropoint;
689   op_params.output_scale = output_scale;
690 
691   ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
692                            DimsToShape(output_dims), output_data);
693 }
694 
695 template <FusedActivationFunctionType Ac, typename Scalar>
DepthConcatenation(const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)696 void DepthConcatenation(const Scalar* const* input_data,
697                         const Dims<4>* const* input_dims, int inputs_count,
698                         Scalar* output_data, const Dims<4>& output_dims) {
699   // For now we don't have a model with a Concatenation with fused activation.
700   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
701 
702   std::vector<RuntimeShape> input_shapes(inputs_count);
703   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
704   for (int i = 0; i < inputs_count; ++i) {
705     ShapeFromDims(*input_dims[i], &input_shapes[i]);
706     input_shapes_indirect[i] = &input_shapes[i];
707   }
708   tflite::ConcatenationParams op_params;
709   op_params.inputs_count = inputs_count;
710 
711   DepthConcatenation(op_params, input_shapes_indirect.data(), input_data,
712                      DimsToShape(output_dims), output_data);
713 }
714 
715 template <typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int axis,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)716 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
717                      int axis, int outputs_count, Scalar* const* output_data,
718                      const Dims<4>* const* output_dims) {
719   std::vector<RuntimeShape> output_shapes(outputs_count);
720   std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
721   for (int i = 0; i < outputs_count; ++i) {
722     ShapeFromDims(*output_dims[i], &output_shapes[i]);
723     output_shapes_indirect[i] = &output_shapes[i];
724   }
725   tflite::SplitParams op_params;
726   op_params.axis = 3 - axis;
727   op_params.num_split = outputs_count;
728 
729   Split(op_params, DimsToShape(input_dims), input_data,
730         output_shapes_indirect.data(), output_data);
731 }
732 
733 template <FusedActivationFunctionType Ac, typename Scalar>
TensorFlowSplit(const Scalar * input_data,const Dims<4> & input_dims,int outputs_count,Scalar * const * output_data,const Dims<4> * const * output_dims)734 void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
735                      int outputs_count, Scalar* const* output_data,
736                      const Dims<4>* const* output_dims) {
737   TFLITE_DCHECK_GE(outputs_count, 1);
738   for (int i = 0; i < outputs_count; i++) {
739     /* batches = */ MatchingArraySize(*output_dims[i], 3, input_dims, 3);
740     /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
741     /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
742   }
743   // For now we don't have a model with a Split with fused activation.
744   TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
745 
746   TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
747                   output_data, output_dims);
748 }
749 
Softmax(const float * input_data,const RuntimeShape & input_shape,float beta,float * output_data,const RuntimeShape & output_shape)750 inline void Softmax(const float* input_data, const RuntimeShape& input_shape,
751                     float beta, float* output_data,
752                     const RuntimeShape& output_shape) {
753   SoftmaxParams params;
754   params.beta = beta;
755   Softmax(params, input_shape, input_data, output_shape, output_data);
756 }
757 
Softmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)758 inline void Softmax(const uint8* input_data, const RuntimeShape& input_shape,
759                     int32 input_beta_multiplier, int32 input_beta_left_shift,
760                     int diff_min, uint8* output_data,
761                     const RuntimeShape& output_shape) {
762   SoftmaxParams params;
763   params.input_multiplier = input_beta_multiplier;
764   params.input_left_shift = input_beta_left_shift;
765   params.diff_min = diff_min;
766   Softmax(params, input_shape, input_data, output_shape, output_data);
767 }
768 
LogSoftmax(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)769 inline void LogSoftmax(const float* input_data, const RuntimeShape& input_shape,
770                        float* output_data, const RuntimeShape& output_shape) {
771   SoftmaxParams params;
772   // No params currently used for float LogSoftmax.
773   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
774 }
775 
LogSoftmax(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const RuntimeShape & output_shape)776 inline void LogSoftmax(const uint8* input_data, const RuntimeShape& input_shape,
777                        int32 input_multiplier, int32 input_left_shift,
778                        int32 reverse_scaling_divisor,
779                        int32 reverse_scaling_right_shift, int diff_min,
780                        uint8* output_data, const RuntimeShape& output_shape) {
781   SoftmaxParams params;
782   params.input_multiplier = input_multiplier;
783   params.input_left_shift = input_left_shift;
784   params.reverse_scaling_divisor = reverse_scaling_divisor;
785   params.reverse_scaling_right_shift = reverse_scaling_right_shift;
786   params.diff_min = diff_min;
787   LogSoftmax(params, input_shape, input_data, output_shape, output_data);
788 }
789 
// Quantized (uint8) logistic/sigmoid using gemmlowp fixed-point arithmetic.
// Each input value is centered by the input zero point; values outside
// +/-input_range_radius saturate to 0/255, and values inside are rescaled
// with (input_multiplier, input_left_shift), evaluated with
// gemmlowp::logistic, and requantized to the uint8 output.
inline void Logistic(const LogisticParams& params,
                     const RuntimeShape& input_shape, const uint8* input_data,
                     const RuntimeShape& output_shape, uint8* output_data) {
  const int32 input_zero_point = params.input_zero_point;
  const int32 input_range_radius = params.input_range_radius;
  const int32 input_multiplier = params.input_multiplier;
  const int input_left_shift = params.input_left_shift;
  // Input and output must have the same number of elements.
  const int flat_size = MatchingFlatSize(input_shape, output_shape);

  for (int i = 0; i < flat_size; i++) {
    const uint8 input_val_u8 = input_data[i];
    const int32 input_val_centered =
        static_cast<int32>(input_val_u8) - input_zero_point;
    uint8 output_val;
    if (input_val_centered <= -input_range_radius) {
      // Far negative input: sigmoid saturates to the quantized minimum.
      output_val = 0;
    } else if (input_val_centered >= input_range_radius) {
      // Far positive input: sigmoid saturates to the quantized maximum.
      output_val = 255;
    } else {
      const int32 input_val_rescaled =
          MultiplyByQuantizedMultiplierGreaterThanOne(
              input_val_centered, input_multiplier, input_left_shift);
      // Q4.27 fixed-point input, Q0.31 fixed-point output of logistic().
      using FixedPoint4 = gemmlowp::FixedPoint<int32, 4>;
      using FixedPoint0 = gemmlowp::FixedPoint<int32, 0>;
      const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled);
      const FixedPoint0 output_val_f0 = gemmlowp::logistic(input_val_f4);
      // Convert from Q0.31 to Q23.8.
      using gemmlowp::RoundingDivideByPOT;
      int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 23);
      // Rounding can push the result to 256; clamp it back into uint8 range.
      if (output_val_s32 == 256) {
        output_val_s32 = 255;
      }
      // Reinterpret as U0.8.
      TFLITE_DCHECK_GE(output_val_s32, 0);
      TFLITE_DCHECK_LE(output_val_s32, 255);
      output_val = static_cast<uint8>(output_val_s32);
    }
    output_data[i] = output_val;
  }
}
830 
Logistic(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)831 inline void Logistic(const uint8* input_data, const RuntimeShape& input_shape,
832                      int32 input_zero_point, int32 input_range_radius,
833                      int32 input_multiplier, int input_left_shift,
834                      uint8* output_data, const RuntimeShape& output_shape) {
835   LogisticParams params;
836   params.input_zero_point = input_zero_point;
837   params.input_range_radius = input_range_radius;
838   params.input_multiplier = input_multiplier;
839   params.input_left_shift = input_left_shift;
840   Logistic(params, input_shape, input_data, output_shape, output_data);
841 }
842 
Logistic(const RuntimeShape & input_shape,const int16 * input_data,const RuntimeShape & output_shape,int16 * output_data)843 inline void Logistic(const RuntimeShape& input_shape, const int16* input_data,
844                      const RuntimeShape& output_shape, int16* output_data) {
845   LogisticParams params;
846   // No params currently needed by int16 Logistic.
847   Logistic(params, input_shape, input_data, output_shape, output_data);
848 }
849 
Tanh(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const RuntimeShape & output_shape)850 inline void Tanh(const uint8* input_data, const RuntimeShape& input_shape,
851                  int32 input_zero_point, int32 input_range_radius,
852                  int32 input_multiplier, int input_left_shift,
853                  uint8* output_data, const RuntimeShape& output_shape) {
854   TanhParams params;
855   params.input_zero_point = input_zero_point;
856   params.input_range_radius = input_range_radius;
857   params.input_multiplier = input_multiplier;
858   params.input_left_shift = input_left_shift;
859   Tanh(params, input_shape, input_data, output_shape, output_data);
860 }
861 
Tanh(const int16 * input_data,const RuntimeShape & input_shape,int input_left_shift,int16 * output_data,const RuntimeShape & output_shape)862 inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
863                  int input_left_shift, int16* output_data,
864                  const RuntimeShape& output_shape) {
865   TanhParams params;
866   params.input_left_shift = input_left_shift;
867   Tanh(params, input_shape, input_data, output_shape, output_data);
868 }
869 
Dequantize(const uint8 * input_data,const Dims<4> & input_dims,int32 zero_point,double scale,float * output_data,const Dims<4> & output_dims)870 inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
871                        int32 zero_point, double scale, float* output_data,
872                        const Dims<4>& output_dims) {
873   tflite::DequantizationParams op_params;
874   op_params.zero_point = zero_point;
875   op_params.scale = scale;
876 
877   Dequantize(op_params, DimsToShape(input_dims), input_data,
878              DimsToShape(output_dims), output_data);
879 }
880 
FakeQuant(const float * input_data,const Dims<4> & input_dims,float rmin,float rmax,int num_bits,float * output_data,const Dims<4> & output_dims)881 inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
882                       float rmin, float rmax, int num_bits, float* output_data,
883                       const Dims<4>& output_dims) {
884   tflite::FakeQuantParams op_params;
885   op_params.num_bits = num_bits;
886   op_params.minmax.min = rmin;
887   op_params.minmax.max = rmax;
888 
889   FakeQuant(op_params, DimsToShape(input_dims), input_data,
890             DimsToShape(output_dims), output_data);
891 }
892 
893 template <typename T>
Gather(const T * input_data,const Dims<4> & input_dims,int input_rank,const int32 * coords_data,const Dims<4> & coords_dims,T * output_data,const Dims<4> & output_dims)894 inline void Gather(const T* input_data, const Dims<4>& input_dims,
895                    int input_rank, const int32* coords_data,
896                    const Dims<4>& coords_dims, T* output_data,
897                    const Dims<4>& output_dims) {
898   tflite::GatherParams op_params;
899   op_params.axis = 4 - input_rank;
900   op_params.batch_dims = 0;
901 
902   Gather(op_params, DimsToShape(input_dims), input_data,
903          DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
904          output_data);
905 }
906 
LegacyReverseBits32(uint32 n)907 inline uint32 LegacyReverseBits32(uint32 n) {
908   n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
909   n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
910   n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
911   return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
912           ((n & 0xFF000000) >> 24));
913 }
914 
StridedSliceReverseIndices(tflite::StridedSliceParams * p)915 inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
916   TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
917   TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
918 
919   std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
920   std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
921   std::reverse(p->strides, p->strides + p->strides_count);
922 
923   p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
924                   (32 - p->start_indices_count);
925   p->ellipsis_mask =
926       LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
927       (32 - p->start_indices_count);
928   p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
929                 (32 - p->start_indices_count);
930   p->new_axis_mask =
931       LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
932       (32 - p->start_indices_count);
933   p->shrink_axis_mask =
934       LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
935       (32 - p->start_indices_count);
936 }
937 
938 template <typename T>
StridedSlice(const T * input_data,const Dims<4> & input_dims,int begin_mask,int end_mask,int shrink_axis_mask,const std::vector<int> & start_indices,const std::vector<int> & stop_indices,const std::vector<int> & strides,T * output_data,const Dims<4> & output_dims)939 inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
940                          int begin_mask, int end_mask, int shrink_axis_mask,
941                          const std::vector<int>& start_indices,
942                          const std::vector<int>& stop_indices,
943                          const std::vector<int>& strides, T* output_data,
944                          const Dims<4>& output_dims) {
945   TFLITE_DCHECK_EQ(start_indices.size(), 4);
946   auto op_params = strided_slice::BuildStridedSliceParams(
947       begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
948       strides);
949   StridedSliceReverseIndices(&op_params);
950 
951   StridedSlice(op_params, DimsToShape(input_dims), input_data,
952                DimsToShape(output_dims), output_data);
953 }
954 
955 template <typename T>
Mean(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & reduction_indices,T * output_data,const Dims<4> & output_dims)956 inline void Mean(const T* input_data, const Dims<4>& input_dims,
957                  const std::vector<int>& reduction_indices, T* output_data,
958                  const Dims<4>& output_dims) {
959   tflite::MeanParams op_params;
960   op_params.axis_count = reduction_indices.size();
961   for (int i = 0; i < op_params.axis_count; ++i) {
962     op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
963   }
964 
965   Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
966        output_data);
967 }
968 
969 template <typename T>
Transpose(const T * input,const Dims<4> & input_dims,T * output,const Dims<4> & output_dims,const int * permuted_axes)970 void Transpose(const T* input, const Dims<4>& input_dims, T* output,
971                const Dims<4>& output_dims, const int* permuted_axes) {
972   TransposeParams params;
973   params.perm_count = 4;
974   for (int i = 0; i < 4; ++i) {
975     params.perm[i] = 3 - permuted_axes[3 - i];
976   }
977   Transpose(params, DimsToShape(input_dims), input, DimsToShape(output_dims),
978             output);
979 }
980 
981 template <typename T, ComparisonFn<T> F>
Comparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)982 inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
983                        const T* input2_data, const Dims<4>& input2_dims,
984                        bool* output_data, const Dims<4>& output_dims) {
985   ComparisonParams op_params;
986   // No parameters needed.
987   ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
988                        DimsToShape(input2_dims), input2_data,
989                        DimsToShape(output_dims), output_data);
990 }
991 
992 template <typename T, ComparisonFn<int32> F>
Comparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)993 inline void Comparison(int left_shift, const T* input1_data,
994                        const Dims<4>& input1_dims, int32 input1_offset,
995                        int32 input1_multiplier, int input1_shift,
996                        const T* input2_data, const Dims<4>& input2_dims,
997                        int32 input2_offset, int32 input2_multiplier,
998                        int input2_shift, bool* output_data,
999                        const Dims<4>& output_dims) {
1000   tflite::ComparisonParams op_params;
1001   op_params.left_shift = left_shift;
1002   op_params.input1_offset = input1_offset;
1003   op_params.input1_multiplier = input1_multiplier;
1004   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1005   op_params.input1_shift = kReverseShift * input1_shift;
1006   op_params.input2_offset = input2_offset;
1007   op_params.input2_multiplier = input2_multiplier;
1008   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1009   op_params.input2_shift = kReverseShift * input2_shift;
1010 
1011   ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
1012                               DimsToShape(input2_dims), input2_data,
1013                               DimsToShape(output_dims), output_data);
1014 }
1015 
1016 template <typename T, ComparisonFn<T> F>
BroadcastComparison(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,bool * output_data,const Dims<4> & output_dims)1017 inline void BroadcastComparison(const T* input1_data,
1018                                 const Dims<4>& input1_dims,
1019                                 const T* input2_data,
1020                                 const Dims<4>& input2_dims, bool* output_data,
1021                                 const Dims<4>& output_dims) {
1022   ComparisonParams op_params;
1023   // No parameters needed.
1024   BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
1025                                       input1_data, DimsToShape(input2_dims),
1026                                       input2_data, DimsToShape(output_dims),
1027                                       output_data);
1028 }
1029 
1030 template <typename T, ComparisonFn<int32> F>
BroadcastComparison(int left_shift,const T * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const T * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,bool * output_data,const Dims<4> & output_dims)1031 inline void BroadcastComparison(int left_shift, const T* input1_data,
1032                                 const Dims<4>& input1_dims, int32 input1_offset,
1033                                 int32 input1_multiplier, int input1_shift,
1034                                 const T* input2_data,
1035                                 const Dims<4>& input2_dims, int32 input2_offset,
1036                                 int32 input2_multiplier, int input2_shift,
1037                                 bool* output_data, const Dims<4>& output_dims) {
1038   ComparisonParams op_params;
1039 
1040   op_params.left_shift = left_shift;
1041   op_params.input1_offset = input1_offset;
1042   op_params.input1_multiplier = input1_multiplier;
1043   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1044   op_params.input1_shift = kReverseShift * input1_shift;
1045   op_params.input2_offset = input2_offset;
1046   op_params.input2_multiplier = input2_multiplier;
1047   // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
1048   op_params.input2_shift = kReverseShift * input2_shift;
1049 
1050   BroadcastComparison4DSlowWithScaling<T, F>(
1051       op_params, DimsToShape(input1_dims), input1_data,
1052       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1053       output_data);
1054 }
1055 
// Generates the four legacy Dims<4>-based entry points for one comparison op
// `name` (Equal, Less, ...):
//   - name(...)                  : unscaled, same-shape comparison.
//   - name(left_shift, ...)      : quantized comparison with rescaling params.
//   - Broadcast##name variants of both, using the slow 4D broadcast path.
// Each wrapper simply forwards to the corresponding Comparison /
// BroadcastComparison helper above, which converts Dims<4> to RuntimeShape.
// NOTE: no comments inside the macro body — line continuations would break.
#define TFLITE_LEGACY_COMPARISON_OP(name)                                     \
  template <typename T>                                                       \
  inline void name(const T* input1_data, const Dims<4>& input1_dims,          \
                   const T* input2_data, const Dims<4>& input2_dims,          \
                   bool* output_data, const Dims<4>& output_dims) {           \
    ruy::profiler::ScopeLabel label(#name);                                   \
    Comparison<T, name##Fn>(input1_data, input1_dims, input2_data,            \
                            input2_dims, output_data, output_dims);           \
  }                                                                           \
  template <typename T>                                                       \
  inline void name(                                                           \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
      int32 input2_multiplier, int input2_shift, bool* output_data,           \
      const Dims<4>& output_dims) {                                           \
    ruy::profiler::ScopeLabel label(#name "/8bit");                           \
    Comparison<T, name##Fn>(left_shift, input1_data, input1_dims,             \
                            input1_offset, input1_multiplier, input1_shift,   \
                            input2_data, input2_dims, input2_offset,          \
                            input2_multiplier, input2_shift, output_data,     \
                            output_dims);                                     \
  }                                                                           \
  template <typename T>                                                       \
  inline void Broadcast##name(                                                \
      const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
      const Dims<4>& input2_dims, bool* output_data,                          \
      const Dims<4>& output_dims) {                                           \
    ruy::profiler::ScopeLabel label("Broadcast" #name);                       \
    BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data,   \
                                     input2_dims, output_data, output_dims);  \
  }                                                                           \
  template <typename T>                                                       \
  inline void Broadcast##name(                                                \
      int left_shift, const T* input1_data, const Dims<4>& input1_dims,       \
      int32 input1_offset, int32 input1_multiplier, int input1_shift,         \
      const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset,  \
      int32 input2_multiplier, int input2_shift, bool* output_data,           \
      const Dims<4>& output_dims) {                                           \
    ruy::profiler::ScopeLabel label("Broadcast" #name "/8bit");               \
    BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims,    \
                                     input1_offset, input1_multiplier,        \
                                     input1_shift, input2_data, input2_dims,  \
                                     input2_offset, input2_multiplier,        \
                                     input2_shift, output_data, output_dims); \
  }
// Instantiate the legacy wrappers for every supported comparison function.
TFLITE_LEGACY_COMPARISON_OP(Equal);
TFLITE_LEGACY_COMPARISON_OP(NotEqual);
TFLITE_LEGACY_COMPARISON_OP(Greater);
TFLITE_LEGACY_COMPARISON_OP(GreaterEqual);
TFLITE_LEGACY_COMPARISON_OP(Less);
TFLITE_LEGACY_COMPARISON_OP(LessEqual);
#undef TFLITE_LEGACY_COMPARISON_OP
1109 
1110 template <typename D, typename T>
Select(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1111 inline void Select(const D* input_condition_data,
1112                    const Dims<4>& input_condition_dims, const T* input_x_data,
1113                    const Dims<4>& input_x_dims, const T* input_y_data,
1114                    const Dims<4>& input_y_dims, T* output_data,
1115                    const Dims<4>& output_dims) {
1116   Select(DimsToShape(input_condition_dims), input_condition_data,
1117          DimsToShape(input_x_dims), input_x_data, DimsToShape(input_y_dims),
1118          input_y_data, DimsToShape(output_dims), output_data);
1119 }
1120 
1121 template <typename D, typename T>
RankOneSelect(const D * input_condition_data,const Dims<4> & input_condition_dims,const T * input_x_data,const Dims<4> & input_x_dims,const T * input_y_data,const Dims<4> & input_y_dims,T * output_data,const Dims<4> & output_dims)1122 inline void RankOneSelect(const D* input_condition_data,
1123                           const Dims<4>& input_condition_dims,
1124                           const T* input_x_data, const Dims<4>& input_x_dims,
1125                           const T* input_y_data, const Dims<4>& input_y_dims,
1126                           T* output_data, const Dims<4>& output_dims) {
1127   RankOneSelect(DimsToShape(input_condition_dims), input_condition_data,
1128                 DimsToShape(input_x_dims), input_x_data,
1129                 DimsToShape(input_y_dims), input_y_data,
1130                 DimsToShape(output_dims), output_data);
1131 }
1132 
1133 template <typename T, typename TI>
SparseToDense(const std::vector<std::vector<TI>> & indices,const T * values,T default_value,T * output_data,const Dims<4> & output_dims,bool value_is_scalar)1134 inline void SparseToDense(const std::vector<std::vector<TI>>& indices,
1135                           const T* values, T default_value, T* output_data,
1136                           const Dims<4>& output_dims, bool value_is_scalar) {
1137   SparseToDense(indices, values, default_value, value_is_scalar,
1138                 DimsToShape(output_dims), output_data);
1139 }
1140 
1141 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,int inputs_count,Scalar * output_data,const Dims<4> & output_dims)1142 void Pack(int dim, const Scalar* const* input_data,
1143           const Dims<4>* const* input_dims, int inputs_count,
1144           Scalar* output_data, const Dims<4>& output_dims) {
1145   std::vector<RuntimeShape> input_shapes(inputs_count);
1146   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1147   for (int i = 0; i < inputs_count; ++i) {
1148     ShapeFromDims(*input_dims[i], &input_shapes[i]);
1149     input_shapes_indirect[i] = &input_shapes[i];
1150   }
1151   tflite::PackParams op_params;
1152   op_params.axis = 3 - dim;
1153   op_params.inputs_count = inputs_count;
1154 
1155   Pack(op_params, input_shapes_indirect.data(), input_data,
1156        DimsToShape(output_dims), output_data);
1157 }
1158 
1159 template <typename Scalar>
Unpack(int axis,const Scalar * input_data,const Dims<4> & input_dims,int dimensions,int outputs_count,Scalar * const * output_datas,const Dims<4> & output_dims)1160 void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
1161             int dimensions, int outputs_count, Scalar* const* output_datas,
1162             const Dims<4>& output_dims) {
1163   tflite::UnpackParams op_params;
1164   op_params.axis = 3 - axis;
1165   op_params.num_split = outputs_count;
1166 
1167   Unpack(op_params, DimsToShape(input_dims), input_data,
1168          DimsToShape(output_dims), output_datas);
1169 }
1170 
1171 template <typename Scalar>
Pack(int dim,const Scalar * const * input_data,const Dims<4> * const * input_dims,const int32 * input_zeropoint,const float * input_scale,int inputs_count,Scalar * output_data,const Dims<4> & output_dims,const int32 output_zeropoint,const float output_scale)1172 void Pack(int dim, const Scalar* const* input_data,
1173           const Dims<4>* const* input_dims, const int32* input_zeropoint,
1174           const float* input_scale, int inputs_count, Scalar* output_data,
1175           const Dims<4>& output_dims, const int32 output_zeropoint,
1176           const float output_scale) {
1177   std::vector<RuntimeShape> input_shapes(inputs_count);
1178   std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
1179   for (int i = 0; i < inputs_count; ++i) {
1180     ShapeFromDims(*input_dims[i], &input_shapes[i]);
1181     input_shapes_indirect[i] = &input_shapes[i];
1182   }
1183   tflite::PackParams op_params;
1184   op_params.axis = 3 - dim;
1185   op_params.input_zeropoint = input_zeropoint;
1186   op_params.input_scale = input_scale;
1187   op_params.inputs_count = inputs_count;
1188   op_params.output_zeropoint = output_zeropoint;
1189   op_params.output_scale = output_scale;
1190 
1191   PackWithScaling(op_params, input_shapes_indirect.data(), input_data,
1192                   DimsToShape(output_dims), output_data);
1193 }
1194 
1195 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const RuntimeShape & input_shape,float * output_data,const RuntimeShape & output_shape)1196 void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
1197                      float* output_data, const RuntimeShape& output_shape) {
1198   static_assert(Ac == FusedActivationFunctionType::kNone, "");
1199   tflite::L2NormalizationParams op_params;
1200   // No params need to be set for float.
1201 
1202   L2Normalization(op_params, input_shape, input_data, output_shape,
1203                   output_data);
1204 }
1205 
L2Normalization(const uint8 * input_data,const RuntimeShape & input_shape,int32 input_zero_point,uint8 * output_data,const RuntimeShape & output_shape)1206 inline void L2Normalization(const uint8* input_data,
1207                             const RuntimeShape& input_shape,
1208                             int32 input_zero_point, uint8* output_data,
1209                             const RuntimeShape& output_shape) {
1210   tflite::L2NormalizationParams op_params;
1211   op_params.input_zero_point = input_zero_point;
1212 
1213   L2Normalization(op_params, input_shape, input_data, output_shape,
1214                   output_data);
1215 }
1216 
1217 template <FusedActivationFunctionType Ac>
L2Normalization(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1218 void L2Normalization(const float* input_data, const Dims<4>& input_dims,
1219                      float* output_data, const Dims<4>& output_dims) {
1220   L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
1221                       DimsToShape(output_dims));
1222 }
1223 
L2Normalization(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,uint8 * output_data,const Dims<4> & output_dims)1224 inline void L2Normalization(const uint8* input_data, const Dims<4>& input_dims,
1225                             int32 input_zero_point, uint8* output_data,
1226                             const Dims<4>& output_dims) {
1227   L2Normalization(input_data, DimsToShape(input_dims), input_zero_point,
1228                   output_data, DimsToShape(output_dims));
1229 }
1230 
Relu(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1231 inline void Relu(const float* input_data, const Dims<4>& input_dims,
1232                  float* output_data, const Dims<4>& output_dims) {
1233   Relu(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1234        output_data);
1235 }
1236 
Relu1(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1237 inline void Relu1(const float* input_data, const Dims<4>& input_dims,
1238                   float* output_data, const Dims<4>& output_dims) {
1239   Relu1(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1240         output_data);
1241 }
1242 
Relu6(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1243 inline void Relu6(const float* input_data, const Dims<4>& input_dims,
1244                   float* output_data, const Dims<4>& output_dims) {
1245   Relu6(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1246         output_data);
1247 }
1248 
ReluX(uint8 min_value,uint8 max_value,const uint8 * input_data,const RuntimeShape & input_shape,uint8 * output_data,const RuntimeShape & output_shape)1249 inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
1250                   const RuntimeShape& input_shape, uint8* output_data,
1251                   const RuntimeShape& output_shape) {
1252   tflite::ActivationParams params;
1253   params.quantized_activation_max = max_value;
1254   params.quantized_activation_min = min_value;
1255   ReluX(params, input_shape, input_data, output_shape, output_data);
1256 }
1257 
1258 template <FusedActivationFunctionType Ac>
Add(int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1259 inline void Add(int left_shift, const uint8* input1_data,
1260                 const Dims<4>& input1_dims, int32 input1_offset,
1261                 int32 input1_multiplier, int input1_shift,
1262                 const uint8* input2_data, const Dims<4>& input2_dims,
1263                 int32 input2_offset, int32 input2_multiplier, int input2_shift,
1264                 int32 output_offset, int32 output_multiplier, int output_shift,
1265                 int32 output_activation_min, int32 output_activation_max,
1266                 uint8* output_data, const Dims<4>& output_dims) {
1267   constexpr int kReverseShift = -1;
1268   static_assert(Ac == FusedActivationFunctionType::kNone ||
1269                     Ac == FusedActivationFunctionType::kRelu ||
1270                     Ac == FusedActivationFunctionType::kRelu6 ||
1271                     Ac == FusedActivationFunctionType::kRelu1,
1272                 "");
1273   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1274   if (Ac == FusedActivationFunctionType::kNone) {
1275     TFLITE_DCHECK_EQ(output_activation_min, 0);
1276     TFLITE_DCHECK_EQ(output_activation_max, 255);
1277   }
1278 
1279   tflite::ArithmeticParams op_params;
1280   op_params.left_shift = left_shift;
1281   op_params.input1_offset = input1_offset;
1282   op_params.input1_multiplier = input1_multiplier;
1283   op_params.input1_shift = kReverseShift * input1_shift;
1284   op_params.input2_offset = input2_offset;
1285   op_params.input2_multiplier = input2_multiplier;
1286   op_params.input2_shift = kReverseShift * input2_shift;
1287   op_params.output_offset = output_offset;
1288   op_params.output_multiplier = output_multiplier;
1289   op_params.output_shift = kReverseShift * output_shift;
1290   op_params.quantized_activation_min = output_activation_min;
1291   op_params.quantized_activation_max = output_activation_max;
1292   Add(op_params, DimsToShape(input1_dims), input1_data,
1293       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1294       output_data);
1295 }
1296 
1297 template <FusedActivationFunctionType Ac>
Add(const int32 * input1_data,const Dims<4> & input1_dims,const int32 * input2_data,const Dims<4> & input2_dims,int32 * output_data,const Dims<4> & output_dims)1298 void Add(const int32* input1_data, const Dims<4>& input1_dims,
1299          const int32* input2_data, const Dims<4>& input2_dims,
1300          int32* output_data, const Dims<4>& output_dims) {
1301   ruy::profiler::ScopeLabel label("Add/int32");
1302   TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
1303 
1304   tflite::ArithmeticParams op_params;
1305   op_params.quantized_activation_min = std::numeric_limits<int32>::min();
1306   op_params.quantized_activation_max = std::numeric_limits<int32>::max();
1307   Add(op_params, DimsToShape(input1_dims), input1_data,
1308       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1309       output_data);
1310 }
1311 
1312 template <FusedActivationFunctionType Ac>
BroadcastAdd(int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1313 inline void BroadcastAdd(int left_shift, const uint8* input1_data,
1314                          const Dims<4>& input1_dims, int32 input1_offset,
1315                          int32 input1_multiplier, int input1_shift,
1316                          const uint8* input2_data, const Dims<4>& input2_dims,
1317                          int32 input2_offset, int32 input2_multiplier,
1318                          int input2_shift, int32 output_offset,
1319                          int32 output_multiplier, int output_shift,
1320                          int32 output_activation_min,
1321                          int32 output_activation_max, uint8* output_data,
1322                          const Dims<4>& output_dims) {
1323   constexpr int kReverseShift = -1;
1324   static_assert(Ac == FusedActivationFunctionType::kNone ||
1325                     Ac == FusedActivationFunctionType::kRelu ||
1326                     Ac == FusedActivationFunctionType::kRelu6 ||
1327                     Ac == FusedActivationFunctionType::kRelu1,
1328                 "");
1329   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1330   if (Ac == FusedActivationFunctionType::kNone) {
1331     TFLITE_DCHECK_EQ(output_activation_min, 0);
1332     TFLITE_DCHECK_EQ(output_activation_max, 255);
1333   }
1334 
1335   tflite::ArithmeticParams op_params;
1336   op_params.left_shift = left_shift;
1337   op_params.input1_offset = input1_offset;
1338   op_params.input1_multiplier = input1_multiplier;
1339   op_params.input1_shift = kReverseShift * input1_shift;
1340   op_params.input2_offset = input2_offset;
1341   op_params.input2_multiplier = input2_multiplier;
1342   op_params.input2_shift = kReverseShift * input2_shift;
1343   op_params.output_offset = output_offset;
1344   op_params.output_multiplier = output_multiplier;
1345   op_params.output_shift = kReverseShift * output_shift;
1346   op_params.quantized_activation_min = output_activation_min;
1347   op_params.quantized_activation_max = output_activation_max;
1348   BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1349                      DimsToShape(input2_dims), input2_data,
1350                      DimsToShape(output_dims), output_data);
1351 }
1352 
1353 template <FusedActivationFunctionType Ac>
Add(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1354 void Add(const float* input1_data, const Dims<4>& input1_dims,
1355          const float* input2_data, const Dims<4>& input2_dims,
1356          float* output_data, const Dims<4>& output_dims) {
1357   float output_activation_min, output_activation_max;
1358   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1359 
1360   tflite::ArithmeticParams op_params;
1361   op_params.float_activation_min = output_activation_min;
1362   op_params.float_activation_max = output_activation_max;
1363   Add(op_params, DimsToShape(input1_dims), input1_data,
1364       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1365       output_data);
1366 }
1367 
1368 template <typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1369 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1370                   const T* input2_data, const Dims<4>& input2_dims,
1371                   T output_activation_min, T output_activation_max,
1372                   T* output_data, const Dims<4>& output_dims) {
1373   tflite::ArithmeticParams op_params;
1374   op_params.float_activation_min = output_activation_min;
1375   op_params.float_activation_max = output_activation_max;
1376   BroadcastAdd4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1377                      DimsToShape(input2_dims), input2_data,
1378                      DimsToShape(output_dims), output_data);
1379 }
1380 
1381 template <FusedActivationFunctionType Ac>
BroadcastAddFivefold(int y0,int y1,int y2,int y3,int y4,int left_shift,const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,int32 input1_multiplier,int input1_shift,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 input2_multiplier,int input2_shift,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1382 inline void BroadcastAddFivefold(
1383     int y0, int y1, int y2, int y3, int y4, int left_shift,
1384     const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset,
1385     int32 input1_multiplier, int input1_shift, const uint8* input2_data,
1386     const Dims<4>& input2_dims, int32 input2_offset, int32 input2_multiplier,
1387     int input2_shift, int32 output_offset, int32 output_multiplier,
1388     int output_shift, int32 output_activation_min, int32 output_activation_max,
1389     uint8* output_data, const Dims<4>& output_dims) {
1390   constexpr int kReverseShift = -1;
1391   static_assert(Ac == FusedActivationFunctionType::kNone ||
1392                     Ac == FusedActivationFunctionType::kRelu ||
1393                     Ac == FusedActivationFunctionType::kRelu6 ||
1394                     Ac == FusedActivationFunctionType::kRelu1,
1395                 "");
1396   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1397   if (Ac == FusedActivationFunctionType::kNone) {
1398     TFLITE_DCHECK_EQ(output_activation_min, 0);
1399     TFLITE_DCHECK_EQ(output_activation_max, 255);
1400   }
1401   tflite::ArithmeticParams op_params;
1402   op_params.broadcast_category =
1403       tflite::BroadcastableOpCategory::kFirstInputBroadcastsFast;
1404   op_params.left_shift = left_shift;
1405   op_params.input1_offset = input1_offset;
1406   op_params.input1_multiplier = input1_multiplier;
1407   op_params.input1_shift = kReverseShift * input1_shift;
1408   op_params.input2_offset = input2_offset;
1409   op_params.input2_multiplier = input2_multiplier;
1410   op_params.input2_shift = kReverseShift * input2_shift;
1411   op_params.output_offset = output_offset;
1412   op_params.output_multiplier = output_multiplier;
1413   op_params.output_shift = kReverseShift * output_shift;
1414   op_params.quantized_activation_min = output_activation_min;
1415   op_params.quantized_activation_max = output_activation_max;
1416   op_params.broadcast_shape[4] = y0;
1417   op_params.broadcast_shape[3] = y1;
1418   op_params.broadcast_shape[2] = y2;
1419   op_params.broadcast_shape[1] = y3;
1420   op_params.broadcast_shape[0] = y4;
1421   BroadcastAddFivefold(op_params, DimsToShape(input1_dims), input1_data,
1422                        DimsToShape(input2_dims), input2_data,
1423                        DimsToShape(output_dims), output_data);
1424 }
1425 
1426 // legacy, for compatibility with old checked-in code
1427 template <FusedActivationFunctionType Ac, typename T>
BroadcastAdd(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1428 void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims,
1429                   const T* input2_data, const Dims<4>& input2_dims,
1430                   T* output_data, const Dims<4>& output_dims) {
1431   T output_activation_min, output_activation_max;
1432   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1433 
1434   BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims,
1435                output_activation_min, output_activation_max, output_data,
1436                output_dims);
1437 }
1438 
1439 template <FusedActivationFunctionType Ac>
Add(const int16 * input1_data,const Dims<4> & input1_dims,int input1_shift,const int16 * input2_data,const Dims<4> & input2_dims,int input2_shift,int16 output_activation_min,int16 output_activation_max,int16 * output_data,const Dims<4> & output_dims)1440 inline void Add(const int16* input1_data, const Dims<4>& input1_dims,
1441                 int input1_shift, const int16* input2_data,
1442                 const Dims<4>& input2_dims, int input2_shift,
1443                 int16 output_activation_min, int16 output_activation_max,
1444                 int16* output_data, const Dims<4>& output_dims) {
1445   static_assert(Ac == FusedActivationFunctionType::kNone ||
1446                     Ac == FusedActivationFunctionType::kRelu ||
1447                     Ac == FusedActivationFunctionType::kRelu6 ||
1448                     Ac == FusedActivationFunctionType::kRelu1,
1449                 "");
1450   TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
1451   if (Ac == FusedActivationFunctionType::kNone) {
1452     TFLITE_DCHECK_EQ(output_activation_min, -32768);
1453     TFLITE_DCHECK_EQ(output_activation_max, 32767);
1454   }
1455 
1456   tflite::ArithmeticParams op_params;
1457   op_params.input1_shift = kReverseShift * input1_shift;
1458   op_params.input2_shift = kReverseShift * input2_shift;
1459   op_params.quantized_activation_min = output_activation_min;
1460   op_params.quantized_activation_max = output_activation_max;
1461   Add(op_params, DimsToShape(input1_dims), input1_data,
1462       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1463       output_data);
1464 }
1465 
Sub(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1466 inline void Sub(const float* input1_data, const Dims<4>& input1_dims,
1467                 const float* input2_data, const Dims<4>& input2_dims,
1468                 float* output_data, const Dims<4>& output_dims) {
1469   float output_activation_min, output_activation_max;
1470   GetActivationMinMax(FusedActivationFunctionType::kNone,
1471                       &output_activation_min, &output_activation_max);
1472   tflite::ArithmeticParams op_params;
1473   op_params.float_activation_min = output_activation_min;
1474   op_params.float_activation_max = output_activation_max;
1475   Sub(op_params, DimsToShape(input1_dims), input1_data,
1476       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1477       output_data);
1478 }
1479 
1480 template <typename T>
Sub(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1481 void Sub(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data,
1482          const Dims<4>& input2_dims, T* output_data,
1483          const Dims<4>& output_dims) {
1484   tflite::ArithmeticParams op_params;
1485   op_params.quantized_activation_min = std::numeric_limits<T>::min();
1486   op_params.quantized_activation_max = std::numeric_limits<T>::max();
1487   Sub(op_params, DimsToShape(input1_dims), input1_data,
1488       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1489       output_data);
1490 }
1491 
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1492 inline bool AveragePool(const float* input_data, const Dims<4>& input_dims,
1493                         int stride_width, int stride_height, int pad_width,
1494                         int pad_height, int kwidth, int kheight,
1495                         float output_activation_min,
1496                         float output_activation_max, float* output_data,
1497                         const Dims<4>& output_dims) {
1498   tflite::PoolParams params;
1499   params.stride_height = stride_height;
1500   params.stride_width = stride_width;
1501   params.filter_height = kheight;
1502   params.filter_width = kwidth;
1503   params.padding_values.height = pad_height;
1504   params.padding_values.width = pad_width;
1505   params.float_activation_min = output_activation_min;
1506   params.float_activation_max = output_activation_max;
1507   return AveragePool(params, DimsToShape(input_dims), input_data,
1508                      DimsToShape(output_dims), output_data);
1509 }
1510 
1511 // Transitional version that will be moved shortly to legacy_reference_ops, as
1512 // part of RuntimeShape revisions.
BroadcastMul4DSlow(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1513 inline void BroadcastMul4DSlow(const uint8* input1_data,
1514                                const Dims<4>& input1_dims, int32 input1_offset,
1515                                const uint8* input2_data,
1516                                const Dims<4>& input2_dims, int32 input2_offset,
1517                                int32 output_offset, int32 output_multiplier,
1518                                int output_shift, int32 output_activation_min,
1519                                int32 output_activation_max, uint8* output_data,
1520                                const Dims<4>& output_dims) {
1521   tflite::ArithmeticParams op_params;
1522   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1523   op_params.input1_offset = input1_offset;
1524   op_params.input2_offset = input2_offset;
1525   op_params.output_offset = output_offset;
1526   op_params.output_multiplier = output_multiplier;
1527   op_params.output_shift = output_shift;
1528 
1529   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1530                      DimsToShape(input2_dims), input2_data,
1531                      DimsToShape(output_dims), output_data);
1532 }
1533 
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1534 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1535                          int32 input1_offset, const uint8* input2_data,
1536                          const Dims<4>& input2_dims, int32 input2_offset,
1537                          int32 output_offset, int32 output_multiplier,
1538                          int output_shift, int32 output_activation_min,
1539                          int32 output_activation_max, uint8* output_data,
1540                          const Dims<4>& output_dims) {
1541   BroadcastMul4DSlow(
1542       input1_data, input1_dims, input1_offset, input2_data, input2_dims,
1543       input2_offset, output_offset, output_multiplier,
1544       //
1545       kReverseShift * output_shift,
1546       //
1547       output_activation_min, output_activation_max, output_data, output_dims);
1548 }
1549 
1550 // legacy, for compatibility with old checked-in code
1551 template <FusedActivationFunctionType Ac>
BroadcastMul(const uint8 * input1_data,const Dims<4> & input1_dims,int32 input1_offset,const uint8 * input2_data,const Dims<4> & input2_dims,int32 input2_offset,int32 output_offset,int32 output_multiplier,int output_shift,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1552 inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
1553                          int32 input1_offset, const uint8* input2_data,
1554                          const Dims<4>& input2_dims, int32 input2_offset,
1555                          int32 output_offset, int32 output_multiplier,
1556                          int output_shift, int32 output_activation_min,
1557                          int32 output_activation_max, uint8* output_data,
1558                          const Dims<4>& output_dims) {
1559   BroadcastMul(input1_data, input1_dims, input1_offset, input2_data,
1560                input2_dims, input2_offset, output_offset, output_multiplier,
1561                output_shift, output_activation_min, output_activation_max,
1562                output_data, output_dims);
1563 }
1564 
1565 // legacy, for compatibility with old checked-in code
1566 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1567 bool AveragePool(const float* input_data, const Dims<4>& input_dims,
1568                  int stride_width, int stride_height, int pad_width,
1569                  int pad_height, int kwidth, int kheight, float* output_data,
1570                  const Dims<4>& output_dims) {
1571   float output_activation_min, output_activation_max;
1572   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1573 
1574   return AveragePool(input_data, input_dims, stride_width, stride_height,
1575                      pad_width, pad_height, kwidth, kheight,
1576                      output_activation_min, output_activation_max, output_data,
1577                      output_dims);
1578 }
1579 
1580 // legacy, for compatibility with old checked-in code
1581 template <FusedActivationFunctionType Ac>
AveragePool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1582 bool AveragePool(const float* input_data, const Dims<4>& input_dims, int stride,
1583                  int pad_width, int pad_height, int filter_width,
1584                  int filter_height, float* output_data,
1585                  const Dims<4>& output_dims) {
1586   return AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width,
1587                          pad_height, filter_width, filter_height, output_data,
1588                          output_dims);
1589 }
1590 
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1591 inline bool AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1592                         int stride_width, int stride_height, int pad_width,
1593                         int pad_height, int filter_width, int filter_height,
1594                         int32 output_activation_min,
1595                         int32 output_activation_max, uint8* output_data,
1596                         const Dims<4>& output_dims) {
1597   tflite::PoolParams params;
1598   params.stride_height = stride_height;
1599   params.stride_width = stride_width;
1600   params.filter_height = filter_height;
1601   params.filter_width = filter_width;
1602   params.padding_values.height = pad_height;
1603   params.padding_values.width = pad_width;
1604   params.quantized_activation_min = output_activation_min;
1605   params.quantized_activation_max = output_activation_max;
1606   return AveragePool(params, DimsToShape(input_dims), input_data,
1607                      DimsToShape(output_dims), output_data);
1608 }
1609 
1610 // legacy, for compatibility with old checked-in code
1611 template <FusedActivationFunctionType Ac>
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1612 bool AveragePool(const uint8* input_data, const Dims<4>& input_dims,
1613                  int stride_width, int stride_height, int pad_width,
1614                  int pad_height, int filter_width, int filter_height,
1615                  int32 output_activation_min, int32 output_activation_max,
1616                  uint8* output_data, const Dims<4>& output_dims) {
1617   static_assert(Ac == FusedActivationFunctionType::kNone ||
1618                     Ac == FusedActivationFunctionType::kRelu ||
1619                     Ac == FusedActivationFunctionType::kRelu6 ||
1620                     Ac == FusedActivationFunctionType::kRelu1,
1621                 "");
1622   if (Ac == FusedActivationFunctionType::kNone) {
1623     TFLITE_DCHECK_EQ(output_activation_min, 0);
1624     TFLITE_DCHECK_EQ(output_activation_max, 255);
1625   }
1626   return AveragePool(input_data, input_dims, stride_width, stride_height,
1627                      pad_width, pad_height, filter_width, filter_height,
1628                      output_activation_min, output_activation_max, output_data,
1629                      output_dims);
1630 }
1631 
1632 // legacy, for compatibility with old checked-in code
1633 template <FusedActivationFunctionType Ac>
AveragePool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1634 bool AveragePool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1635                  int pad_width, int pad_height, int filter_width,
1636                  int filter_height, int32 output_activation_min,
1637                  int32 output_activation_max, uint8* output_data,
1638                  const Dims<4>& output_dims) {
1639   return AveragePool<Ac>(input_data, input_dims, stride, stride, pad_width,
1640                          pad_height, filter_width, filter_height,
1641                          output_activation_min, output_activation_max,
1642                          output_data, output_dims);
1643 }
1644 
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1645 inline void MaxPool(const float* input_data, const Dims<4>& input_dims,
1646                     int stride_width, int stride_height, int pad_width,
1647                     int pad_height, int kwidth, int kheight,
1648                     float output_activation_min, float output_activation_max,
1649                     float* output_data, const Dims<4>& output_dims) {
1650   tflite::PoolParams params;
1651   params.stride_height = stride_height;
1652   params.stride_width = stride_width;
1653   params.filter_height = kheight;
1654   params.filter_width = kwidth;
1655   params.padding_values.height = pad_height;
1656   params.padding_values.width = pad_width;
1657   params.float_activation_min = output_activation_min;
1658   params.float_activation_max = output_activation_max;
1659   MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1660           output_data);
1661 }
1662 
1663 // legacy, for compatibility with old checked-in code
1664 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int kwidth,int kheight,float * output_data,const Dims<4> & output_dims)1665 void MaxPool(const float* input_data, const Dims<4>& input_dims,
1666              int stride_width, int stride_height, int pad_width, int pad_height,
1667              int kwidth, int kheight, float* output_data,
1668              const Dims<4>& output_dims) {
1669   float output_activation_min, output_activation_max;
1670   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1671   MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1672           pad_height, kwidth, kheight, output_activation_min,
1673           output_activation_max, output_data, output_dims);
1674 }
1675 
1676 // legacy, for compatibility with old checked-in code
1677 template <FusedActivationFunctionType Ac>
MaxPool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1678 void MaxPool(const float* input_data, const Dims<4>& input_dims, int stride,
1679              int pad_width, int pad_height, int filter_width, int filter_height,
1680              float* output_data, const Dims<4>& output_dims) {
1681   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1682               filter_width, filter_height, output_data, output_dims);
1683 }
1684 
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1685 inline void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1686                     int stride_width, int stride_height, int pad_width,
1687                     int pad_height, int filter_width, int filter_height,
1688                     int32 output_activation_min, int32 output_activation_max,
1689                     uint8* output_data, const Dims<4>& output_dims) {
1690   PoolParams params;
1691   params.stride_height = stride_height;
1692   params.stride_width = stride_width;
1693   params.filter_height = filter_height;
1694   params.filter_width = filter_width;
1695   params.padding_values.height = pad_height;
1696   params.padding_values.width = pad_width;
1697   params.quantized_activation_min = output_activation_min;
1698   params.quantized_activation_max = output_activation_max;
1699   MaxPool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1700           output_data);
1701 }
1702 
1703 // legacy, for compatibility with old checked-in code
1704 template <FusedActivationFunctionType Ac>
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1705 void MaxPool(const uint8* input_data, const Dims<4>& input_dims,
1706              int stride_width, int stride_height, int pad_width, int pad_height,
1707              int filter_width, int filter_height, int32 output_activation_min,
1708              int32 output_activation_max, uint8* output_data,
1709              const Dims<4>& output_dims) {
1710   static_assert(Ac == FusedActivationFunctionType::kNone ||
1711                     Ac == FusedActivationFunctionType::kRelu ||
1712                     Ac == FusedActivationFunctionType::kRelu6 ||
1713                     Ac == FusedActivationFunctionType::kRelu1,
1714                 "");
1715   if (Ac == FusedActivationFunctionType::kNone) {
1716     TFLITE_DCHECK_EQ(output_activation_min, 0);
1717     TFLITE_DCHECK_EQ(output_activation_max, 255);
1718   }
1719   MaxPool(input_data, input_dims, stride_width, stride_height, pad_width,
1720           pad_height, filter_width, filter_height, output_activation_min,
1721           output_activation_max, output_data, output_dims);
1722 }
1723 
1724 // legacy, for compatibility with old checked-in code
1725 template <FusedActivationFunctionType Ac>
MaxPool(const uint8 * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1726 void MaxPool(const uint8* input_data, const Dims<4>& input_dims, int stride,
1727              int pad_width, int pad_height, int filter_width, int filter_height,
1728              int32 output_activation_min, int32 output_activation_max,
1729              uint8* output_data, const Dims<4>& output_dims) {
1730   MaxPool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1731               filter_width, filter_height, output_activation_min,
1732               output_activation_max, output_data, output_dims);
1733 }
1734 
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float output_activation_min,float output_activation_max,float * output_data,const Dims<4> & output_dims)1735 inline void L2Pool(const float* input_data, const Dims<4>& input_dims,
1736                    int stride_width, int stride_height, int pad_width,
1737                    int pad_height, int filter_width, int filter_height,
1738                    float output_activation_min, float output_activation_max,
1739                    float* output_data, const Dims<4>& output_dims) {
1740   PoolParams params;
1741   params.stride_height = stride_height;
1742   params.stride_width = stride_width;
1743   params.filter_height = filter_height;
1744   params.filter_width = filter_width;
1745   params.padding_values.height = pad_height;
1746   params.padding_values.width = pad_width;
1747   params.float_activation_min = output_activation_min;
1748   params.float_activation_max = output_activation_max;
1749   L2Pool(params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1750          output_data);
1751 }
1752 
1753 // legacy, for compatibility with old checked-in code
1754 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride_width,int stride_height,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1755 void L2Pool(const float* input_data, const Dims<4>& input_dims,
1756             int stride_width, int stride_height, int pad_width, int pad_height,
1757             int filter_width, int filter_height, float* output_data,
1758             const Dims<4>& output_dims) {
1759   float output_activation_min, output_activation_max;
1760   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1761   L2Pool(input_data, input_dims, stride_width, stride_height, pad_width,
1762          pad_height, filter_width, filter_height, output_activation_min,
1763          output_activation_max, output_data, output_dims);
1764 }
1765 
1766 // legacy, for compatibility with old checked-in code
1767 template <FusedActivationFunctionType Ac>
L2Pool(const float * input_data,const Dims<4> & input_dims,int stride,int pad_width,int pad_height,int filter_width,int filter_height,float * output_data,const Dims<4> & output_dims)1768 void L2Pool(const float* input_data, const Dims<4>& input_dims, int stride,
1769             int pad_width, int pad_height, int filter_width, int filter_height,
1770             float* output_data, const Dims<4>& output_dims) {
1771   L2Pool<Ac>(input_data, input_dims, stride, stride, pad_width, pad_height,
1772              filter_width, filter_height, output_data, output_dims);
1773 }
1774 
Softmax(const float * input_data,const Dims<4> & input_dims,float beta,float * output_data,const Dims<4> & output_dims)1775 inline void Softmax(const float* input_data, const Dims<4>& input_dims,
1776                     float beta, float* output_data,
1777                     const Dims<4>& output_dims) {
1778   Softmax(input_data, DimsToShape(input_dims), beta, output_data,
1779           DimsToShape(output_dims));
1780 }
1781 
Softmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_beta_multiplier,int32 input_beta_left_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1782 inline void Softmax(const uint8* input_data, const Dims<4>& input_dims,
1783                     int32 input_beta_multiplier, int32 input_beta_left_shift,
1784                     int diff_min, uint8* output_data,
1785                     const Dims<4>& output_dims) {
1786   Softmax(input_data, DimsToShape(input_dims), input_beta_multiplier,
1787           input_beta_left_shift, diff_min, output_data,
1788           DimsToShape(output_dims));
1789 }
1790 
LogSoftmax(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1791 inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims,
1792                        float* output_data, const Dims<4>& output_dims) {
1793   LogSoftmax(input_data, DimsToShape(input_dims), output_data,
1794              DimsToShape(output_dims));
1795 }
1796 
LogSoftmax(const uint8 * input_data,const Dims<4> & input_dims,int32 input_multiplier,int32 input_left_shift,int32 reverse_scaling_divisor,int32 reverse_scaling_right_shift,int diff_min,uint8 * output_data,const Dims<4> & output_dims)1797 inline void LogSoftmax(const uint8* input_data, const Dims<4>& input_dims,
1798                        int32 input_multiplier, int32 input_left_shift,
1799                        int32 reverse_scaling_divisor,
1800                        int32 reverse_scaling_right_shift, int diff_min,
1801                        uint8* output_data, const Dims<4>& output_dims) {
1802   LogSoftmax(input_data, DimsToShape(input_dims), input_multiplier,
1803              input_left_shift, reverse_scaling_divisor,
1804              reverse_scaling_right_shift, diff_min, output_data,
1805              DimsToShape(output_dims));
1806 }
1807 
Logistic(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1808 inline void Logistic(const float* input_data, const Dims<4>& input_dims,
1809                      float* output_data, const Dims<4>& output_dims) {
1810   Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1811            output_data);
1812 }
1813 
Logistic(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1814 inline void Logistic(const uint8* input_data, const Dims<4>& input_dims,
1815                      int32 input_zero_point, int32 input_range_radius,
1816                      int32 input_multiplier, int input_left_shift,
1817                      uint8* output_data, const Dims<4>& output_dims) {
1818   Logistic(input_data, DimsToShape(input_dims), input_zero_point,
1819            input_range_radius, input_multiplier, input_left_shift, output_data,
1820            DimsToShape(output_dims));
1821 }
1822 
Logistic(const int16 * input_data,const Dims<4> & input_dims,int16 * output_data,const Dims<4> & output_dims)1823 inline void Logistic(const int16* input_data, const Dims<4>& input_dims,
1824                      int16* output_data, const Dims<4>& output_dims) {
1825   Logistic(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1826            output_data);
1827 }
1828 
Tanh(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1829 inline void Tanh(const float* input_data, const Dims<4>& input_dims,
1830                  float* output_data, const Dims<4>& output_dims) {
1831   Tanh(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1832        output_data);
1833 }
1834 
Tanh(const uint8 * input_data,const Dims<4> & input_dims,int32 input_zero_point,int32 input_range_radius,int32 input_multiplier,int input_left_shift,uint8 * output_data,const Dims<4> & output_dims)1835 inline void Tanh(const uint8* input_data, const Dims<4>& input_dims,
1836                  int32 input_zero_point, int32 input_range_radius,
1837                  int32 input_multiplier, int input_left_shift,
1838                  uint8* output_data, const Dims<4>& output_dims) {
1839   Tanh(input_data, DimsToShape(input_dims), input_zero_point,
1840        input_range_radius, input_multiplier, input_left_shift, output_data,
1841        DimsToShape(output_dims));
1842 }
1843 
Tanh(const int16 * input_data,const Dims<4> & input_dims,int input_left_shift,int16 * output_data,const Dims<4> & output_dims)1844 inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
1845                  int input_left_shift, int16* output_data,
1846                  const Dims<4>& output_dims) {
1847   Tanh(input_data, DimsToShape(input_dims), input_left_shift, output_data,
1848        DimsToShape(output_dims));
1849 }
1850 
1851 template <typename T>
DepthToSpace(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1852 inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
1853                          int block_size, T* output_data,
1854                          const Dims<4>& output_dims) {
1855   tflite::DepthToSpaceParams op_params;
1856   op_params.block_size = block_size;
1857 
1858   DepthToSpace(op_params, DimsToShape(input_dims), input_data,
1859                DimsToShape(output_dims), output_data);
1860 }
1861 
1862 template <typename T>
SpaceToDepth(const T * input_data,const Dims<4> & input_dims,int block_size,T * output_data,const Dims<4> & output_dims)1863 inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
1864                          int block_size, T* output_data,
1865                          const Dims<4>& output_dims) {
1866   tflite::SpaceToDepthParams op_params;
1867   op_params.block_size = block_size;
1868 
1869   SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
1870                DimsToShape(output_dims), output_data);
1871 }
1872 
1873 template <typename T>
Mul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1874 inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
1875                 const T* input2_data, const Dims<4>& input2_dims,
1876                 T output_activation_min, T output_activation_max,
1877                 T* output_data, const Dims<4>& output_dims) {
1878   tflite::ArithmeticParams op_params;
1879   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1880 
1881   Mul(op_params, DimsToShape(input1_dims), input1_data,
1882       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1883       output_data);
1884 }
1885 
1886 // legacy, for compatibility with old checked-in code
1887 template <FusedActivationFunctionType Ac>
Mul(const float * input1_data,const Dims<4> & input1_dims,const float * input2_data,const Dims<4> & input2_dims,float * output_data,const Dims<4> & output_dims)1888 void Mul(const float* input1_data, const Dims<4>& input1_dims,
1889          const float* input2_data, const Dims<4>& input2_dims,
1890          float* output_data, const Dims<4>& output_dims) {
1891   float output_activation_min, output_activation_max;
1892   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1893 
1894   tflite::ArithmeticParams op_params;
1895   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1896 
1897   Mul(op_params, DimsToShape(input1_dims), input1_data,
1898       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1899       output_data);
1900 }
1901 
1902 template <typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T output_activation_min,T output_activation_max,T * output_data,const Dims<4> & output_dims)1903 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1904                   const T* input2_data, const Dims<4>& input2_dims,
1905                   T output_activation_min, T output_activation_max,
1906                   T* output_data, const Dims<4>& output_dims) {
1907   tflite::ArithmeticParams op_params;
1908   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1909 
1910   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1911                      DimsToShape(input2_dims), input2_data,
1912                      DimsToShape(output_dims), output_data);
1913 }
1914 
1915 // legacy, for compatibility with old checked-in code
1916 template <FusedActivationFunctionType Ac, typename T>
BroadcastMul(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)1917 void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
1918                   const T* input2_data, const Dims<4>& input2_dims,
1919                   T* output_data, const Dims<4>& output_dims) {
1920   T output_activation_min, output_activation_max;
1921   GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
1922 
1923   tflite::ArithmeticParams op_params;
1924   SetActivationParams(output_activation_min, output_activation_max, &op_params);
1925 
1926   BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
1927                      DimsToShape(input2_dims), input2_data,
1928                      DimsToShape(output_dims), output_data);
1929 }
1930 
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int16 * output_data,const Dims<4> & output_dims)1931 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1932                 const int16* input2_data, const Dims<4>& input2_dims,
1933                 int16* output_data, const Dims<4>& output_dims) {
1934   tflite::ArithmeticParams op_params;
1935   // No params in this version.
1936 
1937   Mul(op_params, DimsToShape(input1_dims), input1_data,
1938       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1939       output_data);
1940 }
1941 
Mul(const int16 * input1_data,const Dims<4> & input1_dims,const int16 * input2_data,const Dims<4> & input2_dims,int32 output_offset,int32 output_activation_min,int32 output_activation_max,uint8 * output_data,const Dims<4> & output_dims)1942 inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
1943                 const int16* input2_data, const Dims<4>& input2_dims,
1944                 int32 output_offset, int32 output_activation_min,
1945                 int32 output_activation_max, uint8* output_data,
1946                 const Dims<4>& output_dims) {
1947   tflite::ArithmeticParams op_params;
1948   op_params.quantized_activation_min = output_activation_min;
1949   op_params.quantized_activation_max = output_activation_max;
1950   op_params.output_offset = output_offset;
1951 
1952   Mul(op_params, DimsToShape(input1_dims), input1_data,
1953       DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
1954       output_data);
1955 }
1956 
LocalResponseNormalization(const float * input_data,const Dims<4> & input_dims,int range,float bias,float alpha,float beta,float * output_data,const Dims<4> & output_dims)1957 inline void LocalResponseNormalization(const float* input_data,
1958                                        const Dims<4>& input_dims, int range,
1959                                        float bias, float alpha, float beta,
1960                                        float* output_data,
1961                                        const Dims<4>& output_dims) {
1962   tflite::LocalResponseNormalizationParams op_params;
1963   op_params.range = range;
1964   op_params.bias = bias;
1965   op_params.alpha = alpha;
1966   op_params.beta = beta;
1967 
1968   LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
1969                              DimsToShape(output_dims), output_data);
1970 }
1971 
1972 template <typename SrcT, typename DstT>
Cast(const SrcT * input_data,const Dims<4> & input_dims,DstT * output_data,const Dims<4> & output_dims)1973 void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
1974           const Dims<4>& output_dims) {
1975   Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1976        output_data);
1977 }
1978 
Floor(const float * input_data,const Dims<4> & input_dims,float * output_data,const Dims<4> & output_dims)1979 inline void Floor(const float* input_data, const Dims<4>& input_dims,
1980                   float* output_data, const Dims<4>& output_dims) {
1981   Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
1982         output_data);
1983 }
1984 
1985 template <typename T>
ResizeBilinear(const T * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,T * output_data,const Dims<4> & output_dims,bool align_corners)1986 inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
1987                            const int32* output_size_data,
1988                            const Dims<4>& output_size_dims, T* output_data,
1989                            const Dims<4>& output_dims, bool align_corners) {
1990   tflite::ResizeBilinearParams op_params;
1991   op_params.align_corners = align_corners;
1992   op_params.half_pixel_centers = false;
1993   ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
1994                  DimsToShape(output_size_dims), output_size_data,
1995                  DimsToShape(output_dims), output_data);
1996 }
1997 
1998 // legacy, for compatibility with old checked-in code
ResizeBilinear(const float * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,float * output_data,const Dims<4> & output_dims)1999 inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
2000                            const int32* output_size_data,
2001                            const Dims<4>& output_size_dims, float* output_data,
2002                            const Dims<4>& output_dims) {
2003   ResizeBilinear<float>(input_data, input_dims, output_size_data,
2004                         output_size_dims, output_data, output_dims,
2005                         /*align_corners=*/false);
2006 }
2007 
ResizeBilinear(const uint8 * input_data,const Dims<4> & input_dims,const int32 * output_size_data,const Dims<4> & output_size_dims,uint8 * output_data,const Dims<4> & output_dims)2008 inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
2009                            const int32* output_size_data,
2010                            const Dims<4>& output_size_dims, uint8* output_data,
2011                            const Dims<4>& output_dims) {
2012   ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
2013                         output_size_dims, output_data, output_dims,
2014                         /*align_corners=*/false);
2015 }
2016 
2017 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)2018 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
2019                            const int32* block_shape_data,
2020                            const Dims<4>& block_shape_dims,
2021                            const int32* paddings_data,
2022                            const Dims<4>& paddings_dims, T* output_data,
2023                            const Dims<4>& output_dims,
2024                            const int32_t pad_value) {
2025   tflite::SpaceToBatchParams op_params;
2026   op_params.output_offset = pad_value;
2027 
2028   SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
2029                  DimsToShape(block_shape_dims), block_shape_data,
2030                  DimsToShape(paddings_dims), paddings_data,
2031                  DimsToShape(output_dims), output_data);
2032 }
2033 
2034 template <typename T>
SpaceToBatchND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * paddings_data,const Dims<4> & paddings_dims,T * output_data,const Dims<4> & output_dims)2035 inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
2036                            const int32* block_shape_data,
2037                            const Dims<4>& block_shape_dims,
2038                            const int32* paddings_data,
2039                            const Dims<4>& paddings_dims, T* output_data,
2040                            const Dims<4>& output_dims) {
2041   tflite::SpaceToBatchParams op_params;
2042   op_params.output_offset = 0;
2043 
2044   SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
2045                  DimsToShape(block_shape_dims), block_shape_data,
2046                  DimsToShape(paddings_dims), paddings_data,
2047                  DimsToShape(output_dims), output_data);
2048 }
2049 
2050 template <typename T>
BatchToSpaceND(const T * input_data,const Dims<4> & input_dims,const int32 * block_shape_data,const Dims<4> & block_shape_dims,const int32 * crops_data,const Dims<4> & crops_dims,T * output_data,const Dims<4> & output_dims)2051 inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
2052                            const int32* block_shape_data,
2053                            const Dims<4>& block_shape_dims,
2054                            const int32* crops_data, const Dims<4>& crops_dims,
2055                            T* output_data, const Dims<4>& output_dims) {
2056   BatchToSpaceND(DimsToShape(input_dims), input_data,
2057                  DimsToShape(block_shape_dims), block_shape_data,
2058                  DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
2059                  output_data);
2060 }
2061 
2062 // Legacy signature, function covered both Pad and PadV2.
2063 template <typename T>
PadV2(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const T pad_value)2064 inline void PadV2(const T* input_data, const Dims<4>& input_dims,
2065                   const std::vector<int>& left_paddings,
2066                   const std::vector<int>& right_paddings, T* output_data,
2067                   const Dims<4>& output_dims, const T pad_value) {
2068   TFLITE_DCHECK_EQ(left_paddings.size(), 4);
2069   TFLITE_DCHECK_EQ(right_paddings.size(), 4);
2070   tflite::PadParams op_params;
2071   op_params.left_padding_count = 4;
2072   op_params.right_padding_count = 4;
2073   for (int i = 0; i < 4; ++i) {
2074     op_params.left_padding[i] = left_paddings[3 - i];
2075     op_params.right_padding[i] = right_paddings[3 - i];
2076   }
2077   // SetFloatOrInt(pad_value, &op_params.pad_value);
2078   const T pad_value_copy = pad_value;
2079 
2080   Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
2081       DimsToShape(output_dims), output_data);
2082 }
2083 
2084 // Old Pad that calls legacy PadV2.
2085 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims,const int32_t pad_value)2086 inline void Pad(const T* input_data, const Dims<4>& input_dims,
2087                 const std::vector<int>& left_paddings,
2088                 const std::vector<int>& right_paddings, T* output_data,
2089                 const Dims<4>& output_dims, const int32_t pad_value) {
2090   const T converted_pad_value = static_cast<T>(pad_value);
2091   PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2092            output_dims, converted_pad_value);
2093 }
2094 
2095 // Old Pad that only padded with 0.
2096 template <typename T>
Pad(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & left_paddings,const std::vector<int> & right_paddings,T * output_data,const Dims<4> & output_dims)2097 inline void Pad(const T* input_data, const Dims<4>& input_dims,
2098                 const std::vector<int>& left_paddings,
2099                 const std::vector<int>& right_paddings, T* output_data,
2100                 const Dims<4>& output_dims) {
2101   const T pad_value = static_cast<T>(0);
2102   PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
2103            output_dims, pad_value);
2104 }
2105 
2106 template <typename T>
TensorFlowMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2107 void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
2108                        const T* input2_data, T* output_data,
2109                        const Dims<4>& output_dims) {
2110   Minimum(DimsToShape(input1_dims), input1_data, input2_data,
2111           DimsToShape(output_dims), output_data);
2112 }
2113 
2114 template <typename T>
TensorFlowMaximum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,T * output_data,const Dims<4> & output_dims)2115 void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
2116                        const T* input2_data, T* output_data,
2117                        const Dims<4>& output_dims) {
2118   Maximum(DimsToShape(input1_dims), input1_data, input2_data,
2119           DimsToShape(output_dims), output_data);
2120 }
2121 
2122 template <typename T, typename Op>
TensorFlowMaximumMinimum(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims,Op op)2123 void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
2124                               const T* input2_data, const Dims<4>& input2_dims,
2125                               T* output_data, const Dims<4>& output_dims,
2126                               Op op) {
2127   MaximumMinimumBroadcastSlow(DimsToShape(input1_dims), input1_data,
2128                               DimsToShape(input2_dims), input2_data,
2129                               DimsToShape(output_dims), output_data, op);
2130 }
2131 
// Legacy ArgMax over the last axis of a 4D input. `axis` must point to the
// value 3 (the only axis the legacy path supported); output holds indices of
// the maxima along that axis.
template <typename T1, typename T2, typename T3>
void ArgMax(const T3* axis, const T1* input_data,
            const tflite::Dims<4>& input_dims, T2* output_data,
            const tflite::Dims<4>& output_dims) {
  // Assumes the input always has 4 dimensions, and therefore,
  // output always has three dimensions.
  auto output_shape = RuntimeShape(
      {output_dims.sizes[2], output_dims.sizes[1], output_dims.sizes[0]});
  // Another way to interpret this is that output_dims.sizes[3] is always 1
  // (Dims<4>::sizes only has indices 0..3), so dropping it preserves the
  // element count checked below.
  TFLITE_DCHECK_EQ(output_shape.FlatSize(),
                   DimsToShape(output_dims).FlatSize());
  // Legacy path only supported this.
  TFLITE_DCHECK_EQ(axis[0], 3);
  // std::greater selects the maximum, making this ArgMax rather than ArgMin.
  ArgMinMax(DimsToShape(input_dims), input_data, axis, output_shape,
            output_data, std::greater<T1>());
}
2148 
2149 template <typename T1, typename T2, typename T3, typename Cmp>
ArgMinMax(const T3 * axis,const T1 * input_data,const Dims<4> & input_dims,T2 * output_data,const Dims<4> & output_dims,const Cmp & cmp)2150 void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
2151                T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
2152   ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
2153             output_data, cmp);
2154 }
2155 
2156 template <typename T>
Pow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2157 inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
2158                 const T* input2_data, const Dims<4>& input2_dims,
2159                 T* output_data, const Dims<4>& output_dims) {
2160   Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
2161       input2_data, DimsToShape(output_dims), output_data);
2162 }
2163 
2164 template <typename T>
BroadcastPow(const T * input1_data,const Dims<4> & input1_dims,const T * input2_data,const Dims<4> & input2_dims,T * output_data,const Dims<4> & output_dims)2165 inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
2166                          const T* input2_data, const Dims<4>& input2_dims,
2167                          T* output_data, const Dims<4>& output_dims) {
2168   BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
2169                      DimsToShape(input2_dims), input2_data,
2170                      DimsToShape(output_dims), output_data);
2171 }
2172 
2173 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2174 template <typename R, typename T1, typename T2>
BroadcastBinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2175 inline void BroadcastBinaryFunction(const T1* input1_data,
2176                                     const Dims<4>& input1_dims,
2177                                     const T2* input2_data,
2178                                     const Dims<4>& input2_dims, R* output_data,
2179                                     const Dims<4>& output_dims,
2180                                     R (*func)(T1, T2)) {
2181   BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
2182                           DimsToShape(input2_dims), input2_data,
2183                           DimsToShape(output_dims), output_data, func);
2184 }
2185 
2186 // R: Result type. T1: Input 1 type. T2: Input 2 type.
2187 template <typename R, typename T1, typename T2>
BinaryFunction(const T1 * input1_data,const Dims<4> & input1_dims,const T2 * input2_data,const Dims<4> & input2_dims,R * output_data,const Dims<4> & output_dims,R (* func)(T1,T2))2188 inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
2189                            const T2* input2_data, const Dims<4>& input2_dims,
2190                            R* output_data, const Dims<4>& output_dims,
2191                            R (*func)(T1, T2)) {
2192   BinaryFunction(DimsToShape(input1_dims), input1_data,
2193                  DimsToShape(input2_dims), input2_data,
2194                  DimsToShape(output_dims), output_data, func);
2195 }
2196 
2197 template <typename T>
Slice(const T * input_data,const Dims<4> & input_dims,const std::vector<int> & begin,const std::vector<int> & size,T * output_data,const Dims<4> & output_dims)2198 inline void Slice(const T* input_data, const Dims<4>& input_dims,
2199                   const std::vector<int>& begin, const std::vector<int>& size,
2200                   T* output_data, const Dims<4>& output_dims) {
2201   tflite::SliceParams op_params;
2202   op_params.begin_count = 4;
2203   op_params.size_count = 4;
2204   for (int i = 0; i < 4; ++i) {
2205     op_params.begin[i] = begin[3 - i];
2206     op_params.size[i] = size[3 - i];
2207   }
2208 
2209   Slice(op_params, DimsToShape(input_dims), input_data,
2210         DimsToShape(output_dims), output_data);
2211 }
2212 
2213 }  // namespace reference_ops
2214 }  // namespace tflite
2215 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
2216