#include "xa_nnlib_common.h"
#include <string.h>
//#include "xa_nn_basic_state.h"
#include "xa_nnlib_common_macros.h"

#define ALIGNMENT_8 8

#define ALIGN_PTR(x, bytes) ((((unsigned)(x))+(bytes-1))&(~(bytes-1)))

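/* Accumulate two rows into the running sum: p_dst[i] = p_src1[i] + p_src2[i] + p_src3[i]
 * for N float32 elements. p_src1 and p_dst normally point at the same scratch buffer.
 * (The "16"/"32" suffixes appear to be carried over from the fixed-point variants of this
 * kernel; both flavours here operate on float32 data.) */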
static void vecmean16_inpx3(const xtfloatx2 *p_src1, const xtfloat* p_src2, const xtfloat* p_src3, xtfloatx2 *p_dst, int N){
  int i = 0;
  ae_valign align_src1, align_dst;
  ae_valign align_src2, align_src3;
  align_src1 = AE_LA64_PP(p_src1);
  align_src2 = AE_LA64_PP(p_src2);
  align_src3 = AE_LA64_PP(p_src3);
  align_dst = AE_ZALIGN64();

  for(i=0; i < (N >> 2); i++)
  {
    xtfloatx2 j1_h, j1_l, j2_h, j2_l;

    xtfloatx2 wout1, wout2;
    XT_LASX2IP(wout1, align_src1, p_src1);
    XT_LASX2IP(wout2, align_src1, p_src1);

    XT_LASX2IP(j1_h, align_src2, (xtfloatx2 *)p_src2);
    XT_LASX2IP(j1_l, align_src2, (xtfloatx2 *)p_src2);
    XT_LASX2IP(j2_h, align_src3, (xtfloatx2 *)p_src3);
    XT_LASX2IP(j2_l, align_src3, (xtfloatx2 *)p_src3);

    j1_h = XT_ADD_SX2(j1_h, j2_h);
    j1_l = XT_ADD_SX2(j1_l, j2_l);
    wout1 = XT_ADD_SX2(wout1, j1_h);
    wout2 = XT_ADD_SX2(wout2, j1_l);

    XT_SASX2IP(wout1, align_dst, p_dst);
    XT_SASX2IP(wout2, align_dst, p_dst);
  }
  AE_SA64POS_FP(align_dst, p_dst); // finalize the stream

  //Remainder Loop
  for(i=0; i < (N & 3); i++)
  {
    xtfloat j1, j2;
    xtfloat wout1;
    XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat));
    j1 = (xtfloat) *(p_src2 + i);
    j2 = (xtfloat) *(p_src3 + i);

    j1 = XT_ADD_S(j1, j2);
    wout1 = XT_ADD_S(wout1, j1);
    XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat));
  }
}

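/* Accumulate one row into the running sum: p_dst[i] = p_src1[i] + p_src2[i] for N float32
 * elements. Handles the odd trailing row when the caller consumes rows in pairs. */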
static void vecmean16_inpx2(const xtfloatx2 *p_src1, const xtfloat* p_src2, xtfloatx2 *p_dst, int N){
  ae_valign align_src1, align_dst;
  ae_valign align_src2;
  align_src1 = AE_LA64_PP(p_src1);
  align_src2 = AE_LA64_PP(p_src2);
  align_dst = AE_ZALIGN64();

  int i = 0;
  for(i=0; i < (N >> 2); i++)
  {
    xtfloatx2 j1, j2;
    xtfloatx2 wout1, wout2;
    XT_LASX2IP(wout1, align_src1, p_src1);
    XT_LASX2IP(wout2, align_src1, p_src1);

    XT_LASX2IP(j1, align_src2, (xtfloatx2 *)p_src2);
    XT_LASX2IP(j2, align_src2, (xtfloatx2 *)p_src2);

    wout1 = XT_ADD_SX2(wout1, j1);
    wout2 = XT_ADD_SX2(wout2, j2);

    XT_SASX2IP(wout1, align_dst, p_dst);
    XT_SASX2IP(wout2, align_dst, p_dst);
  }
  AE_SA64POS_FP(align_dst, p_dst); // finalize the stream

  //Remainder Loop
  for(i=0; i < (N & 3); i++)
  {
    xtfloat j1;
    xtfloat wout1;
    XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat));
    j1 = (xtfloat) *(p_src2 + i);
    wout1 = XT_ADD_S(wout1, j1);
    XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat));
  }
}

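/* Same as vecmean16_inpx3, but both addend rows are read as xtfloatx2 streams; used by the
 * later reduction passes, where every operand already lives in the float32 scratch buffer. */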
static void vecmean32_inpx3(const xtfloatx2* p_src1, const xtfloatx2* p_wsrc2, const xtfloatx2* p_wsrc3, xtfloatx2 *p_dst, int N){
  ae_valign align_src1, align_src2, align_src3, align_dst;
  align_src1 = AE_LA64_PP(p_src1);
  align_src2 = AE_LA64_PP(p_wsrc2);
  align_src3 = AE_LA64_PP(p_wsrc3);
  align_dst = AE_ZALIGN64();

  int i = 0;
  for(i=0; i < (N >> 2); i++)
  {
    xtfloatx2 j1, j2, j3, j4;
    xtfloatx2 wj1, wj2;
    xtfloatx2 wout1, wout2;
    XT_LASX2IP(wout1, align_src1, p_src1);
    XT_LASX2IP(wout2, align_src1, p_src1);
    XT_LASX2IP(j1, align_src2, p_wsrc2);
    XT_LASX2IP(j2, align_src3, p_wsrc3);
    XT_LASX2IP(j3, align_src2, p_wsrc2);
    XT_LASX2IP(j4, align_src3, p_wsrc3);

    wj1 = XT_ADD_SX2(j1, j2);
    wj2 = XT_ADD_SX2(j3, j4);
    wout1 = XT_ADD_SX2(wout1, wj1);
    wout2 = XT_ADD_SX2(wout2, wj2);
    XT_SASX2IP(wout1, align_dst, p_dst);
    XT_SASX2IP(wout2, align_dst, p_dst);
  }
  AE_SA64POS_FP(align_dst, p_dst); // finalize the stream

  //Remainder Loop
  for(i=0; i < (N & 3); i++)
  {
    xtfloat j1, j2;
    xtfloat wj1;
    xtfloat wout1;
    XT_LSXP(wout1, (xtfloat *)p_src1, 4);
    XT_LSXP(j1, (xtfloat *)p_wsrc2, 4);
    XT_LSXP(j2, (xtfloat *)p_wsrc3, 4);
    wj1 = XT_ADD_S(j1, j2);
    wout1 = XT_ADD_S(wout1, wj1);
    XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat));
  }
}

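/* Same as vecmean16_inpx2 with the addend row read as an xtfloatx2 stream; handles the odd
 * trailing row in the later reduction passes. */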
static void vecmean32_inpx2(const xtfloatx2* p_src1, const xtfloatx2* p_wsrc2, xtfloatx2 *p_dst, int N){
  ae_valign align_src1, align_src2, align_dst;
  align_src1 = AE_LA64_PP(p_src1);
  align_src2 = AE_LA64_PP(p_wsrc2);
  align_dst = AE_ZALIGN64();

  int i = 0;
  for(i=0; i < (N >> 2); i++)
  {
    xtfloatx2 j1, j2;
    xtfloatx2 wout1, wout2;
    XT_LASX2IP(wout1, align_src1, p_src1);
    XT_LASX2IP(wout2, align_src1, p_src1);
    XT_LASX2IP(j1, align_src2, p_wsrc2);
    XT_LASX2IP(j2, align_src2, p_wsrc2);
    wout1 = XT_ADD_SX2(wout1, j1);
    wout2 = XT_ADD_SX2(wout2, j2);
    XT_SASX2IP(wout1, align_dst, p_dst);
    XT_SASX2IP(wout2, align_dst, p_dst);
  }
  AE_SA64POS_FP(align_dst, p_dst); // finalize the stream

  //Remainder Loop
  for(i=0; i < (N & 3); i++)
  {
    xtfloat j1;
    xtfloat wout1;
    XT_LSXP(wout1, (xtfloat *)p_src1, 4);
    XT_LSXP(j1, (xtfloat *)p_wsrc2, 4);
    wout1 = XT_ADD_S(wout1, j1);
    XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(WORD32));
  }
}

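/* Sum-reduce a 4D float32 tensor over the axes listed in p_axis_data, accumulating the
 * result in p_scratch_in. The first requested axis is reduced straight from p_inp; any
 * remaining axes are then reduced in place within the scratch buffer. */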
static inline void xa_nn_reduce_sum_4D_f32_f32(const FLOAT32 * __restrict__ p_inp
                                              ,const WORD32 *const p_4D_inp_shape
                                              ,const WORD32 * __restrict__ p_axis_data
                                              ,WORD32 num_inp_dims
                                              ,WORD32 num_axis_dims
                                              ,pVOID p_scratch_in)
{
  xtfloat *p_in = (xtfloat *)(p_inp);
  xtfloat *p_scratch = (xtfloat *)(p_scratch_in);

  int temp_inp_n = p_4D_inp_shape[0];
  int temp_inp_h = p_4D_inp_shape[1];
  int temp_inp_w = p_4D_inp_shape[2];
  int temp_inp_c = p_4D_inp_shape[3];

  int itr_axis = 0, itr_n = 0, itr_h = 0, itr_w = 0, itr_c = 0;
  xtfloat *p_src2, *p_src3;
  xtfloatx2 *p_src1;
  xtfloatx2 * p_dst;
  ae_valign align_src2;

  int axis_dims_count = num_axis_dims;
  if(axis_dims_count)
  {
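    /* First reduction pass: accumulate directly from the float32 input into the scratch
       buffer along the first requested axis, consuming two slices per iteration where
       possible and falling back to the *_inpx2 helpers for an odd trailing slice. */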
    switch(p_axis_data[itr_axis])
    {
      case 0: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        for(itr_n=0; itr_n < (temp_inp_n & ~(2 - 1)); itr_n += 2)
        {
          p_src1 = (xtfloatx2 *)p_scratch;
          p_src2 = p_in + itr_n * plane_size;
          p_src3 = p_in + (itr_n + 1) * plane_size;
          p_dst = (xtfloatx2 *)p_scratch;
          vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, plane_size);
        }

        if(temp_inp_n & 1)
        {
          p_src1 = (xtfloatx2 *)p_scratch;
          p_src2 = (p_in + itr_n * plane_size);
          p_dst = (xtfloatx2 *)p_scratch;
          vecmean16_inpx2(p_src1, p_src2, p_dst, plane_size);
        }
        temp_inp_n = 1;
      }break;
      case 1: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
          for(itr_h=0; itr_h < (temp_inp_h & ~(2 - 1)); itr_h += 2)
          {
            p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size);
            p_src3 = p_in + (itr_n * plane_size) + ((itr_h + 1) * wc_plane_size);
            p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
            vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, wc_plane_size);
            p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
          }

          if(temp_inp_h & 1)
          {
            p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size);
            p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
            vecmean16_inpx2(p_src1, p_src2, p_dst, wc_plane_size);
          }
        }
        temp_inp_h = 1;
      }break;
      case 2:{
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        int hc_plane_size = temp_inp_h * temp_inp_c;

        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          for(itr_h=0; itr_h < (temp_inp_h); itr_h++)
          {
            p_src1 = (xtfloatx2 *)(p_scratch + (((itr_n * hc_plane_size) + itr_h * temp_inp_c)));
            for(itr_w=0; itr_w < (temp_inp_w & ~(2 - 1)); itr_w += 2)
            {
              p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c);
              p_src3 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + ((itr_w + 1) * temp_inp_c);
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c);
              vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, temp_inp_c);
              p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + (itr_h * temp_inp_c));
            }

            if(temp_inp_w & 1)
            {
              p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c);
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c);
              vecmean16_inpx2(p_src1, p_src2, p_dst, temp_inp_c);
            }
          }
        }
        temp_inp_w = 1;
      }break;
      case 3: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        int hw_plane_size = temp_inp_h * temp_inp_w;
        int rem_c = (temp_inp_c & 7);

        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          for(itr_h=0; itr_h < (temp_inp_h); itr_h++)
          {
            for(itr_w=0; itr_w < (temp_inp_w); itr_w++)
            {
              p_src1 = (xtfloatx2 *)(p_scratch + (((itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w)));
              p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c);
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w);
              align_src2 = AE_LA64_PP(p_src2);

              for(itr_c=0; itr_c < (temp_inp_c >> 3); itr_c++)
              {
                xtfloatx2 j11, j12, j21, j22, i1;
                i1 = XT_LSX((xtfloat *)p_src1, 0);
                XT_LASX2IP(j11, align_src2, (xtfloatx2 *)p_src2);
                XT_LASX2IP(j12, align_src2, (xtfloatx2 *)p_src2);
                XT_LASX2IP(j21, align_src2, (xtfloatx2 *)p_src2);
                XT_LASX2IP(j22, align_src2, (xtfloatx2 *)p_src2);

                j11 = XT_ADD_SX2(j11, j12);
                j21 = XT_ADD_SX2(j21, j22);

                xtfloatx2 t1 = XT_SEL32_HH_SX2(j11, j11);
                xtfloatx2 t2 = XT_SEL32_HH_SX2(j21, j21);

                j11 = XT_ADD_SX2(j11, t1);
                j21 = XT_ADD_SX2(j21, t2);

                j11 = XT_ADD_SX2(j11, j21);
                i1 = XT_ADD_SX2(i1, j11);

                XT_SSX(i1, (xtfloat *)p_dst, 0);

                p_src1 = p_dst;
              }
              //Remainder Loop
              for(itr_c=0; itr_c < rem_c ; itr_c++)
              {
                xtfloat j1;
                xtfloat i1;
                i1 = XT_LSX((xtfloat *)p_src1, 0);
                j1 = *p_src2++;

                i1 = XT_ADD_S(i1, j1);
                XT_SSX(i1, (xtfloat *)p_dst, 0);
              }
            }
          }
        }
        temp_inp_c = 1;
      }break;
      default:
        break;
    }

    axis_dims_count--;
    itr_axis++;
  }

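  /* Remaining reduction passes: the scratch buffer now holds the partially reduced tensor,
     so each further axis is reduced in place within the scratch buffer. */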
  while(axis_dims_count)
  {
    ae_valign align_src;
    xtfloat *p_scr_in = p_scratch;
    xtfloatx2 *p_wsrc2, *p_wsrc3;
    switch(p_axis_data[itr_axis])
    {
      case 0: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        for(itr_n=1; itr_n < ((temp_inp_n -1) & ~(2 - 1)); itr_n += 2)
        {
          p_src1 = (xtfloatx2 *)p_scratch;
          p_wsrc2 = (xtfloatx2 *)(p_scr_in + itr_n * plane_size);
          p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n + 1) * plane_size);
          p_dst = (xtfloatx2 *)p_scratch;
          vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, plane_size);
        }

        if((temp_inp_n - 1) & 1)
        {
          p_src1 = (xtfloatx2 *)p_scratch;
          p_wsrc2 = (xtfloatx2 *)(p_scr_in + itr_n * plane_size);
          p_dst = (xtfloatx2 *)p_scratch;
          vecmean32_inpx2(p_src1, p_wsrc2, p_dst, plane_size);
        }
        temp_inp_n = 1;
      }break;
      case 1: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * plane_size));
          for(itr_h = 1; itr_h < ((temp_inp_h - 1) & ~(2 - 1)); itr_h += 2)
          {
            p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size));
            p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + ((itr_h + 1) * wc_plane_size));
            p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
            vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, wc_plane_size);
            p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
          }

          if((temp_inp_h - 1) & 1)
          {
            p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size));
            p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
            vecmean32_inpx2(p_src1, p_wsrc2, p_dst, wc_plane_size);
          }
        }
        temp_inp_h = 1;
      }break;
      case 2:{
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        int hc_plane_size = temp_inp_h * temp_inp_c;
        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          for(itr_h=0; itr_h < (temp_inp_h); itr_h++)
          {
            p_src1 = (xtfloatx2 *)(p_scratch + ((itr_n * plane_size) + (itr_h * wc_plane_size)));
            for(itr_w = 1; itr_w < ((temp_inp_w - 1) & ~(2 - 1)); itr_w += 2)
            {
              p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c));
              p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + ((itr_w + 1) * temp_inp_c));
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c);
              vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, temp_inp_c);
              p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + (itr_h * temp_inp_c));
            }

            if((temp_inp_w - 1) & 1)
            {
              p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c));
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c);
              vecmean32_inpx2(p_src1, p_wsrc2, p_dst, temp_inp_c);
            }
          }
        }
        temp_inp_w = 1;
      }break;
      case 3: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        int hw_plane_size = temp_inp_h * temp_inp_w;
        int rem_c = ((temp_inp_c) & 3);
        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          for(itr_h=0; itr_h < (temp_inp_h); itr_h++)
          {
            for(itr_w=0; itr_w < (temp_inp_w); itr_w++)
            {
              p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c));
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w);
              align_src = AE_LA64_PP(p_wsrc2);
              xtfloatx2 i1 = AE_MOVXTFLOATX2_FROMF32X2(AE_MOVDA32(0));
              for(itr_c = 0; itr_c < (temp_inp_c >> 2); itr_c++)
              {
                xtfloatx2 j1, j2;
                XT_LASX2IP(j1, align_src, p_wsrc2);
                XT_LASX2IP(j2, align_src, p_wsrc2);

                xtfloatx2 t1 = XT_SEL32_HH_SX2(j1, j1);
                xtfloatx2 t2 = XT_SEL32_HH_SX2(j2, j2);

                j1 = XT_ADD_SX2(t1, j1);
                j2 = XT_ADD_SX2(t2, j2);

                i1 = XT_ADD_SX2(i1, j1);
                i1 = XT_ADD_SX2(i1, j2);
              }

              //Remainder Loop
              for(itr_c=0; itr_c < rem_c; itr_c++)
              {
                xtfloat j1;
                XT_LSXP(j1, (xtfloat *)p_wsrc2, sizeof(xtfloat));
                i1 = XT_ADD_S(i1, j1);
              }
              XT_SSX(i1, (xtfloat *)p_dst, 0);
            }
          }
        }
        temp_inp_c = 1;
      }break;
      default:
        break;
    }
    axis_dims_count--;
    itr_axis++;
  }
}

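/* Reduce-mean of a 4D float32 tensor (lower-rank inputs are promoted to 4D) over the axes in
 * p_axis: the sums are accumulated into the 8-byte-aligned scratch buffer by
 * xa_nn_reduce_sum_4D_f32_f32, then scaled by the reciprocal of the number of reduced
 * elements while being written to p_out. Returns 0 on success, -1 on invalid arguments. */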
WORD32 xa_nn_reduce_mean_4D_f32_f32(
  FLOAT32 * __restrict__ p_out,
  const WORD32 *const p_out_shape,
  const FLOAT32 * __restrict__ p_inp,
  const WORD32 *const p_inp_shape,
  const WORD32 * __restrict__ p_axis,
  WORD32 num_out_dims,
  WORD32 num_inp_dims,
  WORD32 num_axis_dims,
  void * __restrict__ p_scratch_in)
{
  /* NULL pointer checks */
  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
  XA_NNLIB_ARG_CHK_PTR(p_axis, -1);
  XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp_shape, -1);

  /* Invalid input checks */
  XA_NNLIB_ARG_CHK_COND(((num_inp_dims <= 0) || (num_inp_dims > 4)), -1);
  XA_NNLIB_ARG_CHK_COND(((num_out_dims <= 0) || (num_out_dims > 4)), -1);
  XA_NNLIB_ARG_CHK_COND(((num_axis_dims < 0) || (num_axis_dims > 4)), -1);

  int axis_itr = 0, inp_itr = 0, out_itr = 0;
  int num_elm_in_axis = 1;
  int current, past = -1;
  for(axis_itr=0; axis_itr < num_axis_dims; axis_itr++)
  {
    current = p_axis[axis_itr];
    XA_NNLIB_ARG_CHK_COND(((current < 0) || (current > (num_inp_dims - 1))), -1);
    XA_NNLIB_ARG_CHK_COND((p_inp_shape[current] > 1024), -1);

    /* Avoid calculation in case of repeated axis dims*/
    if(current != past)
    {
      num_elm_in_axis *= p_inp_shape[current];
      past = current;
    }
  }

  for(inp_itr=0; inp_itr < num_inp_dims; inp_itr++)
  {
    XA_NNLIB_ARG_CHK_COND((p_inp_shape[inp_itr] <= 0), -1);
  }

  int out_length = 1;
  for(out_itr=0; out_itr < num_out_dims; out_itr++)
  {
    XA_NNLIB_ARG_CHK_COND((p_out_shape[out_itr] <= 0), -1);
    out_length *= p_out_shape[out_itr];
  }

  /* Pointer alignment checks */
  XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_axis, sizeof(WORD32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp_shape, sizeof(WORD32), -1);

  FLOAT32 *p_in = (FLOAT32 *)(p_inp);
  WORD32 *p_scratch = (WORD32 *)(ALIGN_PTR(p_scratch_in, ALIGNMENT_8));

  // Changing the order of the axis data so that the reduce sum is computed first
  // across the largest input shape dim in axis. This is required to
  // minimize the scratch usage.
  int inp_length = 1, p_axis_data[4] = {0}, inp_shape_max;
  if(num_axis_dims)
  {
    inp_shape_max = p_inp_shape[p_axis[0]];
    axis_itr = 1;
    int max_axis_itr = 0;
    int temp_p_axis_0 = p_axis[0];
    for(axis_itr = 0; axis_itr < num_axis_dims; axis_itr++)
    {
      p_axis_data[axis_itr] = p_axis[axis_itr];
    }
    for(axis_itr = 1; axis_itr < num_axis_dims; axis_itr++)
    {
      if(p_inp_shape[p_axis[axis_itr]] > inp_shape_max)
      {
        inp_shape_max = p_inp_shape[p_axis[axis_itr]];
        max_axis_itr = axis_itr;
      }
    }
    p_axis_data[0] = p_axis_data[max_axis_itr];
    p_axis_data[max_axis_itr] = temp_p_axis_0;

    inp_itr = 0;
    for(inp_itr=0; inp_itr < num_inp_dims; inp_itr++)
    {
      inp_length *= p_inp_shape[inp_itr];
    }

    memset(p_scratch, 0, ((inp_length / inp_shape_max) * sizeof(WORD32))); //TODO: Alternate approach for memset?
  }

  // Promoting lesser dim tensors to 4D tensors. Also modifying axis
  // data accordingly.
  int p_4D_inp_shape[4] = {1, 1, 1, 1};
  int itr = num_inp_dims - 1;
  int count = 3;
  while(itr >= 0)
  {
    p_4D_inp_shape[count] = p_inp_shape[itr];
    itr--;
    count--;
  }
  for(itr = 0; itr < num_axis_dims; itr++)
  {
    p_axis_data[itr] = p_axis_data[itr] + (4 - num_inp_dims);
  }
  ae_valign align_out = AE_ZALIGN64();

  if(num_axis_dims)
  {
    if(num_elm_in_axis > 1)
    {
      xa_nn_reduce_sum_4D_f32_f32(p_in,
                                  p_4D_inp_shape,
                                  p_axis_data,
                                  num_inp_dims,
                                  num_axis_dims,
                                  p_scratch);
      itr = 0;
      xtfloatx2 *p_src1 = (xtfloatx2 *)(p_scratch);

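      /* mean = sum * (1 / number of reduced elements): compute the reciprocal once and
         broadcast it so the output loop below is a pure multiply. */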
      float div = 1;

      for(int i = 0; i < num_axis_dims; i++)
      {
        div = div * (float)p_4D_inp_shape[p_axis_data[i]];
      }

      float mul = 1 / div;

      xtfloatx2 multiplier = XT_LSX((xtfloat *)&mul, 0);

      for(itr = 0; itr < (out_length >> 3); itr++)
      {
        xtfloatx2 temp1, temp2, temp3, temp4;

        temp2 = XT_LSX2X(p_src1, 8);
        temp3 = XT_LSX2X(p_src1, 16);
        temp4 = XT_LSX2X(p_src1, 24);
        XT_LSX2XP(temp1, p_src1, 32);

        temp1 = XT_MUL_SX2(temp1, multiplier);
        temp2 = XT_MUL_SX2(temp2, multiplier);
        temp3 = XT_MUL_SX2(temp3, multiplier);
        temp4 = XT_MUL_SX2(temp4, multiplier);

        XT_SASX2IP(temp1, align_out, (xtfloatx2 *)p_out);
        XT_SASX2IP(temp2, align_out, (xtfloatx2 *)p_out);
        XT_SASX2IP(temp3, align_out, (xtfloatx2 *)p_out);
        XT_SASX2IP(temp4, align_out, (xtfloatx2 *)p_out);
      }
      AE_SA64POS_FP(align_out, p_out);

      for(itr = 0; itr < (out_length & 7); itr++)
      {
        xtfloat temp1;
        XT_LSXP(temp1, (xtfloat *)p_src1, 4);
        temp1 = XT_MUL_S(temp1, multiplier);
        XT_SSXP(temp1, (xtfloat *)p_out, 4);
      }
    }
    else
    {
      memcpy(p_out, p_inp, inp_length * sizeof(FLOAT32));
    }
  }
  else
  {
    memcpy(p_out, p_inp, out_length * sizeof(FLOAT32));
  }

  return 0;
}