/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_

#include <stdint.h>

#include <algorithm>
#include <cstring>  // for memcpy

namespace tflite {
namespace reduce_utils {

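// Removes dimensions of size 1 from 'shape_out' in place, dropping the
// corresponding entries from 'axis_out' and shifting the remaining axis
// values so that they keep referring to the same dimensions.
// 'out_num_dims' and 'out_num_axis' are updated accordingly.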
inline void RemoveSize1Dims(int* shape_out, int& out_num_dims, int* axis_out,
                            int& out_num_axis) {
  for (int64_t i = 0; i < out_num_dims;) {
    if (shape_out[i] == 1) {
      for (int64_t j = i + 1; j < out_num_dims; ++j) {
        shape_out[j - 1] = shape_out[j];
      }
      for (int64_t j = 0; j < out_num_axis; ++j) {
        if (axis_out[j] == i) {
          for (int64_t k = j + 1; k < out_num_axis; ++k) {
            axis_out[k - 1] = axis_out[k];
          }
          out_num_axis -= 1;
          break;
        }
      }
      for (int64_t j = 0; j < out_num_axis; ++j) {
        if (axis_out[j] > i) {
          axis_out[j] -= 1;
        }
      }
      --out_num_dims;
    } else {
      ++i;
    }
  }
}

// This method parses the input 'axis' to remove duplicates, handle negative
// values, and remove redundant dimensions. It returns a valid 'axis_out';
// 'shape_out' contains the flattened input shape and 'out_num_dims' the
// reduced number of dimensions.
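// For example (illustrative trace only): with shape_in = {2, 1, 3, 4} and
// axis = {-2, 2}, both entries resolve to dimension 2; after the size-1
// dimension is removed, the result is shape_out = {2, 3, 4} with
// axis_out = {1}, out_num_dims = 3 and out_num_axis = 1.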
inline bool ResolveAxis(const int num_dims, const int* axis,
                        const int64_t num_axis, int* axis_out,
                        int& out_num_axis, const int* shape_in, int* shape_out,
                        int& out_num_dims) {
  // Short-circuit axis resolution for scalars; the axis will go unused.
  if (num_dims == 0) {
    out_num_axis = 0;
    out_num_dims = 0;
    return true;
  }
  out_num_axis = 0;
  out_num_dims = num_dims;
  // O(n^2) is fine since out_num_axis should be really small, mostly <= 4.
  for (int64_t idx = 0; idx < num_axis; ++idx) {
    // Handle negative index. A positive index 'p_idx' can be represented as a
    // negative index 'n_idx' as: n_idx = p_idx - num_dims
    // e.g.: for num_dims = 3, [0, 1, 2] is the same as [-3, -2, -1].
    int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
    if (current < 0 || current >= num_dims) {
      return false;
    }
    bool is_dup = false;
    for (int j = 0; j < out_num_axis; ++j) {
      if (axis_out[j] == current) {
        is_dup = true;
        break;
      }
    }
    if (!is_dup) {
      axis_out[out_num_axis] = current;
      out_num_axis += 1;
    }
  }
  // If two or more adjacent dimensions are all reduced over, or all not
  // reduced over, then the second and subsequent of those dimensions can be
  // flattened into the first.
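  // For example (illustrative, assuming no size-1 dims): shape {2, 3, 4, 5}
  // with axis {1, 2} collapses to shape {2, 12, 5} with axis {1}, since
  // dimensions 1 and 2 are both reduced over.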
  memcpy(shape_out, shape_in, num_dims * sizeof(int));
  std::sort(&axis_out[0], &axis_out[out_num_axis]);

  RemoveSize1Dims(shape_out, out_num_dims, axis_out, out_num_axis);
  if (out_num_axis > 0) {
    int64_t j = out_num_axis - 1;
    // true if the previous index is present in axis_out.
    bool previous_here = (axis_out[j] == out_num_dims - 1);
    if (previous_here) {
      j -= 1;
    }

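    // Walk the dimensions from innermost to outermost. Whenever dimension i
    // and dimension i + 1 have the same status (both reduced over or both
    // kept), fold dimension i + 1 into dimension i and drop it, adjusting
    // axis_out to match.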
    for (int64_t i = out_num_dims - 2; i >= 0; --i) {
      // true if the current index is present in axis_out.
      bool current_here = j >= 0 ? (axis_out[j] == i) : false;
      if (current_here == previous_here) {
        shape_out[i] *= shape_out[i + 1];
        for (int64_t k = i + 1; k + 1 < out_num_dims; ++k) {
          shape_out[k] = shape_out[k + 1];
        }
        // All axes bigger than this need to be reduced by 1.
        for (int64_t k = 0; k < out_num_axis; ++k) {
          if (axis_out[k] > i) {
            axis_out[k] -= 1;
          }
        }
        if (current_here) {
          for (int64_t k = j + 1; k + 1 < out_num_axis; ++k) {
            axis_out[k] = axis_out[k + 1];
          }
          out_num_axis -= 1;
        }
        out_num_dims -= 1;
      }
      if (current_here) {
        j -= 1;
      }
      previous_here = current_here;
    }
  }
  return true;
}
}  // namespace reduce_utils
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_REDUCE_UTILS_H_