// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker
6*4bdc9457SAndroid Build Coastguard Worker #include <stdbool.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
8*4bdc9457SAndroid Build Coastguard Worker #include <string.h>
9*4bdc9457SAndroid Build Coastguard Worker
10*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
11*4bdc9457SAndroid Build Coastguard Worker
// Reports whether dimension `dim` can be dropped (or folded into the dimension
// before it) without contradicting the optional stride arrays.
//
//   input_stride  - may be NULL; when NULL the input side imposes no constraint.
//   output_stride - may be NULL; when NULL the output side imposes no constraint.
//   shape, perm   - current shape and permutation arrays.
//   dim           - indexes input_stride and shape directly; perm[dim] is the
//                   index used into output_stride.
//
// Returns true when dim and perm[dim] are both 0, or when every supplied
// stride array satisfies stride[k - 1] == stride[k] * shape[dim] at its index
// k (the check on a side is skipped when k is 0). Returns false otherwise.
//
// NOTE(review): the output-side check multiplies by shape[dim] while indexing
// output_stride with perm[dim]; confirm against callers that this is the
// intended extent for permutations other than the identity.
static bool can_dimension_be_removed(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* shape,
    const size_t* perm,
    size_t dim)
{
  // The leading dimension that also stays leading has no predecessor to
  // compare strides against.
  if (dim == 0 && perm[dim] == 0) {
    return true;
  }
  if (input_stride != NULL && dim > 0 &&
      input_stride[dim - 1] != input_stride[dim] * shape[dim]) {
    return false;
  }
  if (output_stride != NULL && perm[dim] > 0 &&
      output_stride[perm[dim] - 1] != output_stride[perm[dim]] * shape[dim]) {
    return false;
  }
  return true;
}
34*4bdc9457SAndroid Build Coastguard Worker
// Removes the dimension at output position `dim` (input dimension perm[dim])
// from shape, perm and, when they are non-NULL, the stride arrays.
//
//   shape         - entry perm[dim] is removed; later entries shift down by one.
//   perm          - entries greater than perm[dim] are decremented, then entry
//                   `dim` is removed and later entries shift down by one.
//   input_stride  - may be NULL. Otherwise entry perm[dim] - 1 is removed
//                   (entry 0 when perm[dim] is 0), so the stride that survives
//                   for the merged pair is the one of the removed dimension.
//   output_stride - may be NULL. Otherwise entry dim - 1 is removed (entry 0
//                   when dim is 0).
//   num_dims      - number of valid entries in each array before the call.
//
// All arrays are edited in place; the last of the num_dims entries of each
// array is left holding its old value and should be ignored by the caller.
static void remove_dimension(
    size_t* shape,
    size_t* perm,
    size_t* input_stride,
    size_t* output_stride,
    size_t num_dims,
    size_t dim)
{
  // perm[dim] is not modified until the final shift, so read it once.
  const size_t removed = perm[dim];

  for (size_t j = removed; j + 1 < num_dims; ++j) {
    shape[j] = shape[j + 1];
  }
  if (input_stride != NULL) {
    // Start one slot earlier than the removed dimension, clamped at 0.
    const size_t first = removed > 0 ? removed - 1 : 0;
    for (size_t j = first; j + 1 < num_dims; ++j) {
      input_stride[j] = input_stride[j + 1];
    }
  }
  if (output_stride != NULL) {
    const size_t first = dim > 0 ? dim - 1 : 0;
    for (size_t j = first; j + 1 < num_dims; ++j) {
      output_stride[j] = output_stride[j + 1];
    }
  }
  // Renumber the input dimensions that came after the removed one.
  for (size_t j = 0; j < num_dims; ++j) {
    if (perm[j] > removed) {
      perm[j] -= 1;
    }
  }
  for (size_t j = dim; j + 1 < num_dims; ++j) {
    perm[j] = perm[j + 1];
  }
}
// Reduces a transpose description to a simpler equivalent one by removing
// dimensions of size 1, folding dimensions that remain adjacent under the
// permutation, and treating an unpermuted last dimension as one large element.
//
// Inputs:
//   num_dims      - number of entries in perm, shape and any supplied strides.
//   element_size  - size of one element; supplied strides are multiplied by it
//                   (so they are presumably expressed in elements - confirm
//                   against callers).
//   perm, shape   - permutation and shape, num_dims entries each.
//   input_stride, output_stride - optional (may be NULL), num_dims entries.
//
// Outputs (all must be non-NULL; the arrays need room for num_dims entries and
// at least one entry):
//   normalized_num_dims         - resulting number of dimensions (1 when every
//                                 dimension was removed).
//   normalized_element_size_out - element_size, multiplied by the extent of
//                                 the last dimension when it is not permuted.
//   normalized_perm, normalized_shape - reduced permutation and shape.
//   normalized_input_stride, normalized_output_stride - when the matching
//       input array is NULL, strides are derived from the normalized shape
//       starting at the normalized element size; otherwise the reduced caller
//       strides multiplied by element_size.
void xnn_normalize_transpose_permutation(
    const size_t num_dims,
    const size_t element_size,
    const size_t* perm,
    const size_t* shape,
    const size_t* input_stride,
    const size_t* output_stride,
    size_t* normalized_num_dims,
    size_t* normalized_element_size_out,
    size_t* normalized_perm,
    size_t* normalized_shape,
    size_t* normalized_input_stride,
    size_t* normalized_output_stride)
{
  size_t output_dims = num_dims;
  memcpy(normalized_perm, perm, num_dims * sizeof(size_t));
  memcpy(normalized_shape, shape, num_dims * sizeof(size_t));

  // The *_ptr aliases are NULL exactly when the caller supplied no strides;
  // the helpers use that to skip stride checks and stride compaction.
  size_t* normalized_input_stride_ptr = NULL;
  size_t* normalized_output_stride_ptr = NULL;
  if (input_stride != NULL) {
    memcpy(normalized_input_stride, input_stride, num_dims * sizeof(size_t));
    normalized_input_stride_ptr = normalized_input_stride;
  }
  if (output_stride != NULL) {
    memcpy(normalized_output_stride, output_stride, num_dims * sizeof(size_t));
    normalized_output_stride_ptr = normalized_output_stride;
  }

  // Remove dimensions of size 1 and fold dimensions which are adjacent in both
  // input and output tensors.
  size_t output_pos = 0;
  while (output_pos < output_dims) {
    const bool removable = can_dimension_be_removed(
        normalized_input_stride_ptr, normalized_output_stride_ptr,
        normalized_shape, normalized_perm, normalized_perm[output_pos]);
    const bool is_unit = normalized_shape[normalized_perm[output_pos]] == 1;
    const bool follows_previous =
        output_pos > 0 &&
        normalized_perm[output_pos] == normalized_perm[output_pos - 1] + 1;

    if (removable && (is_unit || follows_previous)) {
      if (output_pos > 0) {
        // Fold the extent into the preceding output dimension (a no-op
        // multiply by 1 for unit dimensions).
        normalized_shape[normalized_perm[output_pos - 1]] *=
            normalized_shape[normalized_perm[output_pos]];
      }
      remove_dimension(normalized_shape, normalized_perm,
                       normalized_input_stride_ptr,
                       normalized_output_stride_ptr, output_dims, output_pos);
      output_dims -= 1;
      // When a dimension has been removed, new folds may be possible so check
      // the previous position again.
      if (output_pos > 0) {
        output_pos -= 1;
      }
    } else {
      output_pos += 1;
    }
  }

  // Every dimension was removed (or num_dims was 0): describe a single element.
  if (output_pos == 0) {
    *normalized_num_dims = 1;
    *normalized_element_size_out = element_size;
    normalized_perm[0] = 0;
    normalized_shape[0] = 1;
    normalized_input_stride[0] = element_size;
    normalized_output_stride[0] = element_size;
    return;
  }

  // If the last input and output dimensions are the same, treat it as one
  // large element.
  size_t normalized_element_size = element_size;
  const size_t last = output_dims - 1;
  if (normalized_perm[last] == last) {
    normalized_element_size = element_size * normalized_shape[last];
    if (output_dims > 1 &&
        can_dimension_be_removed(normalized_input_stride_ptr,
                                 normalized_output_stride_ptr,
                                 normalized_shape, normalized_perm, last)) {
      output_dims -= 1;
    } else {
      // The dimension is kept with extent 1. Only caller-supplied strides are
      // rescaled here: when none were supplied the output buffers have not
      // been written yet and are filled in from scratch below.
      if (normalized_input_stride_ptr != NULL) {
        normalized_input_stride_ptr[last] *= normalized_shape[last];
      }
      if (normalized_output_stride_ptr != NULL) {
        normalized_output_stride_ptr[normalized_perm[last]] *= normalized_shape[last];
      }
      normalized_shape[last] = 1;
    }
  }

  // If input_stride is not provided, calculate it using normalized_shape and
  // normalized_element_size.
  if (input_stride == NULL) {
    normalized_input_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_input_stride[i - 1] = normalized_input_stride[i] * normalized_shape[i];
    }
  } else {
    // Scale input_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_input_stride[i] *= element_size;
    }
  }

  // If output_stride is not provided, calculate it using normalized_shape
  // (taken in output order via normalized_perm) and normalized_element_size.
  if (output_stride == NULL) {
    normalized_output_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_output_stride[i - 1] =
          normalized_output_stride[i] * normalized_shape[normalized_perm[i]];
    }
  } else {
    // Scale output_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_output_stride[i] *= element_size;
    }
  }

  *normalized_element_size_out = normalized_element_size;
  *normalized_num_dims = output_dims;
}
172