xref: /aosp_15_r20/external/XNNPACK/src/normalization.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #include <stdbool.h>
7*4bdc9457SAndroid Build Coastguard Worker #include <stddef.h>
8*4bdc9457SAndroid Build Coastguard Worker #include <string.h>
9*4bdc9457SAndroid Build Coastguard Worker 
10*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/math.h>
11*4bdc9457SAndroid Build Coastguard Worker 
// Returns true if input dimension `dim` can be dropped from, or folded into its
// outer neighbour in, both the input and the output layout without changing the
// memory locations that are addressed.
//
// Index conventions (matching the rest of this file):
//   - perm[pos] is the input dimension placed at output position pos.
//   - shape and input_stride are indexed by input dimension.
//   - output_stride is indexed by output position.
//
// remove_dimension() discards input_stride[dim - 1] and output_stride[pos - 1]
// (where perm[pos] == dim) and keeps the inner strides. Each of those pairs must
// therefore be contiguous:
//   input_stride[dim - 1]  == input_stride[dim]  * shape[dim]
//   output_stride[pos - 1] == output_stride[pos] * shape[dim]
// A NULL stride array means "dense", so it imposes no constraint.
//
// Precondition: perm is a permutation that contains `dim` (both callers pass
// either perm[output_pos] or an index with perm[index] == index).
static bool can_dimension_be_removed(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* shape,
    const size_t* perm,
    size_t dim) {
  if (dim == 0 && perm[dim] == 0) {
    // Input dimension 0 sits at output position 0, so it has no outer
    // neighbour in either layout and no stride relation has to hold.
    return true;
  }
  if (input_stride != NULL && dim > 0) {
    if (input_stride[dim - 1] != input_stride[dim] * shape[dim]) {
      return false;
    }
  }
  if (output_stride != NULL) {
    // Output strides are indexed by output position, so find where input
    // dimension `dim` lands. Terminates because perm contains `dim`.
    size_t pos = 0;
    while (perm[pos] != dim) {
      pos += 1;
    }
    if (pos > 0 && output_stride[pos - 1] != output_stride[pos] * shape[dim]) {
      return false;
    }
  }
  return true;
}
34*4bdc9457SAndroid Build Coastguard Worker 
// Moves array[index + 1 .. count) down by one slot, overwriting array[index].
// Does nothing when index + 1 >= count.
static void shift_entries_down(size_t* array, size_t index, size_t count) {
  for (size_t src = index + 1; src < count; src++) {
    array[src - 1] = array[src];
  }
}

// Erases the dimension held at output position `dim` (input dimension
// perm[dim]) from every normalization array. Each array holds `num_dims`
// entries on entry; afterwards the first num_dims - 1 entries are meaningful.
//   - shape:         entry perm[dim] is erased.
//   - input_stride:  entry perm[dim] - 1 (entry 0 if perm[dim] == 0) is erased,
//                    so the inner stride survives for a folded dimension.
//                    Skipped when NULL.
//   - output_stride: entry dim - 1 (entry 0 if dim == 0) is erased, keeping the
//                    inner stride. Skipped when NULL.
//   - perm:          input indices above perm[dim] are renumbered down by one,
//                    then entry dim is erased.
static void remove_dimension(
    size_t* shape,
    size_t* perm,
    size_t* input_stride,
    size_t* output_stride,
    size_t num_dims,
    size_t dim)
{
  const size_t removed_input_dim = perm[dim];

  shift_entries_down(shape, removed_input_dim, num_dims);

  if (input_stride != NULL) {
    const size_t first = removed_input_dim > 0 ? removed_input_dim - 1 : 0;
    shift_entries_down(input_stride, first, num_dims);
  }

  if (output_stride != NULL) {
    const size_t first = dim > 0 ? dim - 1 : 0;
    shift_entries_down(output_stride, first, num_dims);
  }

  // Keep perm a permutation of the surviving input dimensions.
  for (size_t k = 0; k < num_dims; k++) {
    if (removed_input_dim < perm[k]) {
      perm[k]--;
    }
  }
  shift_entries_down(perm, dim, num_dims);
}
// Simplifies an N-dimensional transpose into an equivalent description with as
// few dimensions as possible.
//
// Conventions used below:
//   - perm[i] is the input dimension placed at output position i, so output
//     extent i is shape[perm[i]].
//   - shape and input_stride are indexed by input dimension; output_stride is
//     indexed by output position.
//   - input_stride / output_stride, when non-NULL, are measured in elements.
//     Either may be NULL, meaning "dense in the order implied by shape/perm".
//
// Steps:
//   1. Copy perm, shape and any supplied strides into the normalized_* buffers.
//   2. Drop size-1 dimensions and fold dimensions that are adjacent in both the
//      input and the output order, as far as the supplied strides allow.
//   3. If the innermost dimension is not permuted, absorb it into the element
//      size.
//   4. Emit strides in units of element_size. They are computed densely when the
//      caller passed NULL; otherwise the supplied strides are scaled by
//      element_size.
//
// Outputs: *normalized_num_dims, *normalized_element_size_out, and the first
// *normalized_num_dims entries of normalized_perm / normalized_shape /
// normalized_input_stride / normalized_output_stride. Each normalized_* array
// needs room for num_dims entries, and for at least one entry even when
// num_dims == 0, because entry 0 is written on the all-size-1 path.
void xnn_normalize_transpose_permutation(
    const size_t num_dims,
    const size_t element_size,
    const size_t* perm,
    const size_t* shape,
    const size_t* input_stride,
    const size_t* output_stride,
    size_t* normalized_num_dims,
    size_t* normalized_element_size_out,
    size_t* normalized_perm,
    size_t* normalized_shape,
    size_t* normalized_input_stride,
    size_t* normalized_output_stride)
{
  size_t output_dims = num_dims;
  memcpy(normalized_perm, perm, num_dims * sizeof(size_t));
  memcpy(normalized_shape, shape, num_dims * sizeof(size_t));

  // The *_ptr aliases are non-NULL only when the caller supplied the matching
  // strides. They tell the helpers whether stride constraints apply, and they
  // keep us from touching the (uninitialized) output buffers otherwise.
  size_t* normalized_input_stride_ptr = NULL;
  size_t* normalized_output_stride_ptr = NULL;
  if (input_stride != NULL) {
    memcpy(normalized_input_stride, input_stride, num_dims * sizeof(size_t));
    normalized_input_stride_ptr = normalized_input_stride;
  }
  if (output_stride != NULL) {
    memcpy(normalized_output_stride, output_stride, num_dims * sizeof(size_t));
    normalized_output_stride_ptr = normalized_output_stride;
  }

  // Remove dimensions of size 1 and fold dimensions which are adjacent in both
  // input and output tensors.
  size_t output_pos = 0;
  while (output_pos < output_dims) {
    const size_t input_dim = normalized_perm[output_pos];
    if (can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr, normalized_shape,
                                 normalized_perm, input_dim)
        && ((normalized_shape[input_dim] == 1)
            || (output_pos > 0 && input_dim == normalized_perm[output_pos - 1] + 1))) {
      if (output_pos > 0) {
        // Merge into the dimension at the previous output position (multiplying
        // by 1 when the removed dimension has size 1).
        normalized_shape[normalized_perm[output_pos - 1]] *= normalized_shape[input_dim];
      }
      remove_dimension(normalized_shape, normalized_perm, normalized_input_stride_ptr, normalized_output_stride_ptr,
                       output_dims, output_pos);
      output_dims -= 1;
      // When a dimension has been removed, new folds may be possible, so step
      // back and check the previous position again.
      if (output_pos > 0) {
        output_pos -= 1;
      }
    } else {
      output_pos += 1;
    }
  }

  // All dimensions had size 1: describe the transpose as a single element.
  if (output_dims == 0) {
    *normalized_num_dims = 1;
    *normalized_element_size_out = element_size;
    normalized_perm[0] = 0;
    normalized_shape[0] = 1;
    normalized_input_stride[0] = element_size;
    normalized_output_stride[0] = element_size;
    return;
  }

  // If the last input and output dimensions are the same, treat it as one large
  // element.
  // NOTE(review): when strides are supplied, this assumes the innermost
  // dimension is contiguous (stride 1); that is not verified here. Confirm that
  // callers guarantee it.
  size_t normalized_element_size = element_size;
  const size_t last = output_dims - 1;
  if (normalized_perm[last] == last) {
    normalized_element_size = element_size * normalized_shape[last];
    if (output_dims > 1 && can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr,
                                                    normalized_shape, normalized_perm, last)) {
      output_dims -= 1;
    } else {
      // Keep a size-1 placeholder dimension whose strides span the enlarged
      // element.
      if (normalized_input_stride_ptr != NULL) {
        normalized_input_stride_ptr[last] *= normalized_shape[last];
      }
      if (normalized_output_stride_ptr != NULL) {
        normalized_output_stride_ptr[last] *= normalized_shape[last];
      }
      normalized_shape[last] = 1;
    }
  }

  // If input_stride is not provided, compute dense strides from
  // normalized_shape and normalized_element_size.
  if (input_stride == NULL) {
    normalized_input_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_input_stride[i - 1] = normalized_input_stride[i] * normalized_shape[i];
    }
  } else {
    // Supplied strides are in units of the original element, so scale by
    // element_size (not normalized_element_size).
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_input_stride[i] *= element_size;
    }
  }

  // If output_stride is not provided, compute dense strides in output order
  // (output extent i is normalized_shape[normalized_perm[i]]).
  if (output_stride == NULL) {
    normalized_output_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_output_stride[i - 1] = normalized_output_stride[i] * normalized_shape[normalized_perm[i]];
    }
  } else {
    // Scale supplied output strides by the original element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_output_stride[i] *= element_size;
    }
  }

  *normalized_element_size_out = normalized_element_size;
  *normalized_num_dims = output_dims;
}
172