xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/optimized/reduce_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_
17 
#include <stdint.h>

#include <algorithm>
#include <cstring>
21 
22 namespace tflite {
23 namespace reduce_utils {
24 
// Compacts `shape_out` in place by dropping every dimension of size one and
// keeps `axis_out` consistent with the new indexing: an axis entry that
// referred to a dropped dimension is removed, and every axis entry pointing
// past it is shifted down by one. `out_num_dims` and `out_num_axis` are
// updated to the new lengths.
inline void RemoveSize1Dims(int* shape_out, int& out_num_dims, int* axis_out,
                            int& out_num_axis) {
  int dim = 0;
  while (dim < out_num_dims) {
    if (shape_out[dim] != 1) {
      ++dim;
      continue;
    }
    // Close the gap left by the size-1 dimension.
    std::copy(shape_out + dim + 1, shape_out + out_num_dims, shape_out + dim);
    --out_num_dims;
    // Erase the (at most one) axis entry that named the removed dimension.
    for (int a = 0; a < out_num_axis; ++a) {
      if (axis_out[a] == dim) {
        std::copy(axis_out + a + 1, axis_out + out_num_axis, axis_out + a);
        --out_num_axis;
        break;
      }
    }
    // Axes that referred to later dimensions slide down by one.
    for (int a = 0; a < out_num_axis; ++a) {
      if (axis_out[a] > dim) {
        --axis_out[a];
      }
    }
    // Do not advance `dim`: the element shifted into this slot may be 1 too.
  }
}
52 
53 // This method parses the input 'axis' to remove duplicates, handle negative
54 // values and remove redundant dimensions. It returns a valid 'axis_out' and
55 // 'shape_out' contains the flattened input shape. 'out_num_dims' contains the
56 // reduced number of dimensions.
ResolveAxis(const int num_dims,const int * axis,const int64_t num_axis,int * axis_out,int & out_num_axis,const int * shape_in,int * shape_out,int & out_num_dims)57 inline bool ResolveAxis(const int num_dims, const int* axis,
58                         const int64_t num_axis, int* axis_out,
59                         int& out_num_axis, const int* shape_in, int* shape_out,
60                         int& out_num_dims) {
61   // Short-circuit axis resolution for scalars; the axis will go unused.
62   if (num_dims == 0) {
63     out_num_axis = 0;
64     out_num_dims = 0;
65     return true;
66   }
67   out_num_axis = 0;
68   out_num_dims = num_dims;
69   // o(n^2) is fine since out_num_axis should be really small, mostly <= 4
70   for (int64_t idx = 0; idx < num_axis; ++idx) {
71     // Handle negative index. A positive index 'p_idx' can be represented as a
72     // negative index 'n_idx' as: n_idx = p_idx-num_dims
73     // eg: For num_dims=3, [0, 1, 2] is the same as [-3, -2, -1]  */
74     int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
75     if (current < 0 || current >= num_dims) {
76       return false;
77     }
78     bool is_dup = false;
79     for (int j = 0; j < out_num_axis; ++j) {
80       if (axis_out[j] == current) {
81         is_dup = true;
82         break;
83       }
84     }
85     if (!is_dup) {
86       axis_out[out_num_axis] = current;
87       out_num_axis += 1;
88     }
89   }
90   // If two or more adjacent dimensions are either reduced
91   // over or not, then the second and subsequent dimensions may be flattened.
92   memcpy(shape_out, shape_in, num_dims * sizeof(int));
93   std::sort(&axis_out[0], &axis_out[out_num_axis]);
94 
95   RemoveSize1Dims(shape_out, out_num_dims, axis_out, out_num_axis);
96   if (out_num_axis > 0) {
97     int64_t j = out_num_axis - 1;
98     // true if the previous index is present in axis_out.
99     bool previous_here = (axis_out[j] == out_num_dims - 1);
100     if (previous_here) {
101       j -= 1;
102     }
103 
104     for (int64_t i = out_num_dims - 2; i >= 0; --i) {
105       // true if the current index is present in axis_out.
106       bool current_here = j >= 0 ? (axis_out[j] == i) : false;
107       if (current_here == previous_here) {
108         shape_out[i] *= shape_out[i + 1];
109         for (int64_t k = i + 1; k + 1 < out_num_dims; ++k) {
110           shape_out[k] = shape_out[k + 1];
111         }
112         // All axis bigger than this need to be reduced by 1.
113         for (int64_t k = 0; k < out_num_axis; ++k) {
114           if (axis_out[k] > i) {
115             axis_out[k] -= 1;
116           }
117         }
118         if (current_here) {
119           for (int64_t k = j + 1; k + 1 < out_num_axis; ++k) {
120             axis_out[k] = axis_out[k + 1];
121           }
122           out_num_axis -= 1;
123         }
124         out_num_dims -= 1;
125       }
126       if (current_here) {
127         j -= 1;
128       }
129       previous_here = current_here;
130     }
131   }
132   return true;
133 }
134 }  // namespace reduce_utils
135 }  // namespace tflite
136 
137 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_
138