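/*
 * Reduce-mean kernel for float32 tensors (up to 4D) on Cadence HiFi.
 * Computes a reduce-sum over the requested axes into a caller-provided
 * scratch buffer, then scales by the reciprocal of the reduced element
 * count.
 */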
#include "xa_nnlib_common.h"
#include <string.h>
//#include "xa_nn_basic_state.h"
#include "xa_nnlib_common_macros.h"

#define ALIGNMENT_8   8

#define ALIGN_PTR(x, bytes)     ((((unsigned)(x))+(bytes-1))&(~(bytes-1)))

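/*
 * vecmean16_inpx3: elementwise accumulation of two input rows into the
 * running sum held in scratch, i.e. for N floats:
 *
 *   p_dst[i] = p_src1[i] + p_src2[i] + p_src3[i];   // scalar reference
 *
 * The main loop consumes 4 floats per iteration via unaligned 2-way SIMD
 * loads/stores; the tail loop handles the last (N & 3) elements.
 */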
static void vecmean16_inpx3(const xtfloatx2 *p_src1, const xtfloat* p_src2, const xtfloat* p_src3, xtfloatx2 *p_dst, int N){
  int i = 0;
  ae_valign align_src1, align_dst;
  ae_valign align_src2, align_src3;
  align_src1 = AE_LA64_PP(p_src1);
  align_src2 = AE_LA64_PP(p_src2);
  align_src3 = AE_LA64_PP(p_src3);
  align_dst = AE_ZALIGN64();

  for(i=0; i < (N >> 2); i++)
  {
    xtfloatx2 j1_h, j1_l, j2_h, j2_l;

    xtfloatx2 wout1, wout2;
    XT_LASX2IP(wout1, align_src1, p_src1);
    XT_LASX2IP(wout2, align_src1, p_src1);

    XT_LASX2IP(j1_h, align_src2, (xtfloatx2 *)p_src2);
    XT_LASX2IP(j1_l, align_src2, (xtfloatx2 *)p_src2);
    XT_LASX2IP(j2_h, align_src3, (xtfloatx2 *)p_src3);
    XT_LASX2IP(j2_l, align_src3, (xtfloatx2 *)p_src3);

    j1_h = XT_ADD_SX2(j1_h, j2_h);
    j1_l = XT_ADD_SX2(j1_l, j2_l);
    wout1 = XT_ADD_SX2(wout1, j1_h);
    wout2 = XT_ADD_SX2(wout2, j1_l);

    XT_SASX2IP(wout1, align_dst, p_dst);
    XT_SASX2IP(wout2, align_dst, p_dst);
  }
  AE_SA64POS_FP(align_dst, p_dst); // finalize the stream

  //Remainder Loop
  for(i=0; i < (N & 3); i++)
  {
    xtfloat j1, j2;
    xtfloat wout1;
    XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat));
    j1 = (xtfloat) *(p_src2 + i);
    j2 = (xtfloat) *(p_src3 + i);

    j1 = XT_ADD_S(j1, j2);
    wout1 = XT_ADD_S(wout1, j1);
    XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat));
  }
}

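/*
 * vecmean16_inpx2: two-input variant of the accumulation above,
 *
 *   p_dst[i] = p_src1[i] + p_src2[i];   // scalar reference
 *
 * used when an odd row is left over after pairing rows in the caller.
 */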
static void vecmean16_inpx2(const xtfloatx2 *p_src1, const xtfloat* p_src2, xtfloatx2 *p_dst, int N){
  ae_valign align_src1, align_dst;
  ae_valign align_src2;
  align_src1 = AE_LA64_PP(p_src1);
  align_src2 = AE_LA64_PP(p_src2);
  align_dst = AE_ZALIGN64();

  int i = 0;
  for(i=0; i < (N >> 2); i++)
  {
    xtfloatx2 j1, j2;
    xtfloatx2 wout1, wout2;
    XT_LASX2IP(wout1, align_src1, p_src1);
    XT_LASX2IP(wout2, align_src1, p_src1);

    XT_LASX2IP(j1, align_src2, (xtfloatx2 *)p_src2);
    XT_LASX2IP(j2, align_src2, (xtfloatx2 *)p_src2);

    wout1 = XT_ADD_SX2(wout1, j1);
    wout2 = XT_ADD_SX2(wout2, j2);

    XT_SASX2IP(wout1, align_dst, p_dst);
    XT_SASX2IP(wout2, align_dst, p_dst);
  }
  AE_SA64POS_FP(align_dst, p_dst); // finalize the stream

  //Remainder Loop
  for(i=0; i < (N & 3); i++)
  {
    xtfloat j1;
    xtfloat wout1;
    XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat));
    j1 = (xtfloat) *(p_src2 + i);
    wout1 = XT_ADD_S(wout1, j1);
    XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat));
  }
}

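/*
 * vecmean32_inpx3: same three-way accumulation, but the second and third
 * operands also come from the (already partially reduced) scratch buffer,
 * so all sources are read through unaligned 2-way SIMD streams:
 *
 *   p_dst[i] = p_src1[i] + p_wsrc2[i] + p_wsrc3[i];   // scalar reference
 */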
static void vecmean32_inpx3(const xtfloatx2* p_src1, const xtfloatx2* p_wsrc2, const xtfloatx2* p_wsrc3, xtfloatx2 *p_dst, int N){
  ae_valign align_src1, align_src2, align_src3, align_dst;
  align_src1 = AE_LA64_PP(p_src1);
  align_src2 = AE_LA64_PP(p_wsrc2);
  align_src3 = AE_LA64_PP(p_wsrc3);
  align_dst = AE_ZALIGN64();

  int i = 0;
  for(i=0; i < (N >> 2); i++)
  {
    xtfloatx2 j1, j2, j3, j4;
    xtfloatx2 wj1, wj2;
    xtfloatx2 wout1, wout2;
    XT_LASX2IP(wout1, align_src1, p_src1);
    XT_LASX2IP(wout2, align_src1, p_src1);
    XT_LASX2IP(j1, align_src2, p_wsrc2);
    XT_LASX2IP(j2, align_src3, p_wsrc3);
    XT_LASX2IP(j3, align_src2, p_wsrc2);
    XT_LASX2IP(j4, align_src3, p_wsrc3);

    wj1 = XT_ADD_SX2(j1, j2);
    wj2 = XT_ADD_SX2(j3, j4);
    wout1 = XT_ADD_SX2(wout1, wj1);
    wout2 = XT_ADD_SX2(wout2, wj2);
    XT_SASX2IP(wout1, align_dst, p_dst);
    XT_SASX2IP(wout2, align_dst, p_dst);
  }
  AE_SA64POS_FP(align_dst, p_dst); // finalize the stream

  //Remainder Loop
  for(i=0; i < (N & 3); i++)
  {
    xtfloat j1, j2;
    xtfloat wj1;
    xtfloat wout1;
    XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat));
    XT_LSXP(j1, (xtfloat *)p_wsrc2, sizeof(xtfloat));
    XT_LSXP(j2, (xtfloat *)p_wsrc3, sizeof(xtfloat));
    wj1 = XT_ADD_S(j1, j2);
    wout1 = XT_ADD_S(wout1, wj1);
    XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat));
  }
}

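/*
 * vecmean32_inpx2: two-input variant of vecmean32_inpx3,
 *
 *   p_dst[i] = p_src1[i] + p_wsrc2[i];   // scalar reference
 */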
static void vecmean32_inpx2(const xtfloatx2* p_src1, const xtfloatx2* p_wsrc2, xtfloatx2 *p_dst, int N){
  ae_valign align_src1, align_src2, align_dst;
  align_src1 = AE_LA64_PP(p_src1);
  align_src2 = AE_LA64_PP(p_wsrc2);
  align_dst = AE_ZALIGN64();

  int i = 0;
  for(i=0; i < (N >> 2); i++)
  {
    xtfloatx2 j1, j2;
    xtfloatx2 wout1, wout2;
    XT_LASX2IP(wout1, align_src1, p_src1);
    XT_LASX2IP(wout2, align_src1, p_src1);
    XT_LASX2IP(j1, align_src2, p_wsrc2);
    XT_LASX2IP(j2, align_src2, p_wsrc2);
    wout1 = XT_ADD_SX2(wout1, j1);
    wout2 = XT_ADD_SX2(wout2, j2);
    XT_SASX2IP(wout1, align_dst, p_dst);
    XT_SASX2IP(wout2, align_dst, p_dst);
  }
  AE_SA64POS_FP(align_dst, p_dst); // finalize the stream

  //Remainder Loop
  for(i=0; i < (N & 3); i++)
  {
    xtfloat j1;
    xtfloat wout1;
    XT_LSXP(wout1, (xtfloat *)p_src1, sizeof(xtfloat));
    XT_LSXP(j1, (xtfloat *)p_wsrc2, sizeof(xtfloat));
    wout1 = XT_ADD_S(wout1, j1);
    XT_SSXP(wout1, (xtfloat *)p_dst, sizeof(xtfloat));
  }
}

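/*
 * xa_nn_reduce_sum_4D_f32_f32: sums the 4D input over the axes listed in
 * p_axis_data, accumulating into p_scratch_in (which the caller zeroes).
 * The first axis is reduced by streaming rows straight from p_inp; any
 * remaining axes are then reduced in place on the scratch buffer by the
 * while loop below, with the corresponding temp_inp_* dimension collapsed
 * to 1 after each pass.
 */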
static inline void xa_nn_reduce_sum_4D_f32_f32(const FLOAT32 * __restrict__ p_inp
                                               ,const WORD32 *const p_4D_inp_shape
                                               ,const WORD32 * __restrict__ p_axis_data
                                               ,WORD32 num_inp_dims
                                               ,WORD32 num_axis_dims
                                               ,pVOID p_scratch_in)
{
  xtfloat *p_in = (xtfloat *)(p_inp);
  xtfloat *p_scratch = (xtfloat *)(p_scratch_in);

  int temp_inp_n = p_4D_inp_shape[0];
  int temp_inp_h = p_4D_inp_shape[1];
  int temp_inp_w = p_4D_inp_shape[2];
  int temp_inp_c = p_4D_inp_shape[3];

  int itr_axis = 0, itr_n = 0, itr_h = 0, itr_w = 0, itr_c = 0;
  xtfloat *p_src2, *p_src3;
  xtfloatx2 *p_src1;
  xtfloatx2 *p_dst;
  ae_valign align_src2;

  int axis_dims_count = num_axis_dims;
  if(axis_dims_count)
  {
    switch(p_axis_data[itr_axis])
    {
      case 0: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        for(itr_n=0; itr_n < (temp_inp_n & ~(2 - 1)); itr_n += 2)
        {
          p_src1 = (xtfloatx2 *)p_scratch;
          p_src2 = p_in + itr_n * plane_size;
          p_src3 = p_in + (itr_n + 1) * plane_size;
          p_dst  = (xtfloatx2 *)p_scratch;
          vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, plane_size);
        }

        if(temp_inp_n & 1)
        {
          p_src1 = (xtfloatx2 *)p_scratch;
          p_src2 = (p_in + itr_n * plane_size);
          p_dst  = (xtfloatx2 *)p_scratch;
          vecmean16_inpx2(p_src1, p_src2, p_dst, plane_size);
        }
        temp_inp_n = 1;
        }break;
      case 1: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
          for(itr_h=0; itr_h < (temp_inp_h & ~(2 - 1)); itr_h += 2)
          {
            p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size);
            p_src3 = p_in + (itr_n * plane_size) + ((itr_h + 1) * wc_plane_size);
            p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
            vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, wc_plane_size);
            p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
          }

          if(temp_inp_h & 1)
          {
            p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size);
            p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
            vecmean16_inpx2(p_src1, p_src2, p_dst, wc_plane_size);
          }
        }
        temp_inp_h = 1;
        }break;
      case 2:{
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        int hc_plane_size = temp_inp_h * temp_inp_c;

        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          for(itr_h=0; itr_h < (temp_inp_h); itr_h++)
          {
            p_src1 = (xtfloatx2 *)(p_scratch + (((itr_n * hc_plane_size) + itr_h * temp_inp_c)));
            for(itr_w=0; itr_w < (temp_inp_w & ~(2 - 1)); itr_w += 2)
            {
              p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c);
              p_src3 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + ((itr_w + 1) * temp_inp_c);
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c);
              vecmean16_inpx3(p_src1, p_src2, p_src3, p_dst, temp_inp_c);
              p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + (itr_h * temp_inp_c));
            }

            if(temp_inp_w & 1)
            {
              p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c);
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c);
              vecmean16_inpx2(p_src1, p_src2, p_dst, temp_inp_c);
            }
          }
        }
        temp_inp_w = 1;
        }break;
      case 3: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        int hw_plane_size = temp_inp_h * temp_inp_w;
        int rem_c = (temp_inp_c & 7);

        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          for(itr_h=0; itr_h < (temp_inp_h); itr_h++)
          {
            for(itr_w=0; itr_w < (temp_inp_w); itr_w++)
            {
              p_src1 = (xtfloatx2 *)(p_scratch + (((itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w)));
              p_src2 = p_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c);
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w);
              align_src2 = AE_LA64_PP(p_src2);

              for(itr_c=0; itr_c < (temp_inp_c >> 3); itr_c++)
              {
                xtfloatx2 j11, j12, j21, j22, i1;
                i1 = XT_LSX((xtfloat *)p_src1, 0);
                XT_LASX2IP(j11, align_src2, (xtfloatx2 *)p_src2);
                XT_LASX2IP(j12, align_src2, (xtfloatx2 *)p_src2);
                XT_LASX2IP(j21, align_src2, (xtfloatx2 *)p_src2);
                XT_LASX2IP(j22, align_src2, (xtfloatx2 *)p_src2);

                j11 = XT_ADD_SX2(j11, j12);
                j21 = XT_ADD_SX2(j21, j22);

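                // Fold the SIMD lanes of the pairwise sums so that the lane
                // stored below by XT_SSX carries the full running total for
                // this output element.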
                xtfloatx2 t1 = XT_SEL32_HH_SX2(j11, j11);
                xtfloatx2 t2 = XT_SEL32_HH_SX2(j21, j21);

                j11 = XT_ADD_SX2(j11, t1);
                j21 = XT_ADD_SX2(j21, t2);

                j11 = XT_ADD_SX2(j11, j21);
                i1 = XT_ADD_SX2(i1, j11);

                XT_SSX(i1, (xtfloat *)p_dst, 0);

                p_src1 = p_dst;
              }
              //Remainder Loop
              for(itr_c=0; itr_c < rem_c ; itr_c++)
              {
                xtfloat j1;
                xtfloat i1;
                i1 = XT_LSX((xtfloat *)p_src1, 0);
                j1 = *p_src2++;

                i1 = XT_ADD_S(i1, j1);
                XT_SSX(i1, (xtfloat *)p_dst, 0);
              }
            }
          }
        }
        temp_inp_c = 1;
        }break;
      default:
        break;
    }

    axis_dims_count--;
    itr_axis++;
  }

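  /*
   * Any remaining axes are reduced in place on the scratch buffer: the
   * sources below (p_scr_in / p_wsrc2 / p_wsrc3) now point into partially
   * reduced data rather than the original input, which is why the
   * scratch-to-scratch vecmean32_* helpers are used from here on.
   */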
  while(axis_dims_count)
  {
    ae_valign align_src;
    xtfloat *p_scr_in = p_scratch;
    xtfloatx2 *p_wsrc2, *p_wsrc3;
    switch(p_axis_data[itr_axis])
    {
      case 0: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        for(itr_n=1; itr_n < ((temp_inp_n - 1) & ~(2 - 1)); itr_n += 2)
        {
          p_src1 = (xtfloatx2 *)p_scratch;
          p_wsrc2 = (xtfloatx2 *)(p_scr_in + itr_n * plane_size);
          p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n + 1) * plane_size);
          p_dst  = (xtfloatx2 *)p_scratch;
          vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, plane_size);
        }

        if((temp_inp_n - 1) & 1)
        {
          p_src1 = (xtfloatx2 *)p_scratch;
          p_wsrc2 = (xtfloatx2 *)(p_scr_in + itr_n * plane_size);
          p_dst  = (xtfloatx2 *)p_scratch;
          vecmean32_inpx2(p_src1, p_wsrc2, p_dst, plane_size);
        }
        temp_inp_n = 1;
        }break;
      case 1: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * plane_size));
          for(itr_h = 1; itr_h < ((temp_inp_h - 1) & ~(2 - 1)); itr_h += 2)
          {
            p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size));
            p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + ((itr_h + 1) * wc_plane_size));
            p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
            vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, wc_plane_size);
            p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
          }

          if((temp_inp_h - 1) & 1)
          {
            p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size));
            p_dst = (xtfloatx2 *)(p_scratch + (itr_n * wc_plane_size));
            vecmean32_inpx2(p_src1, p_wsrc2, p_dst, wc_plane_size);
          }
        }
        temp_inp_h = 1;
        }break;
      case 2:{
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        int hc_plane_size = temp_inp_h * temp_inp_c;
        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          for(itr_h=0; itr_h < (temp_inp_h); itr_h++)
          {
            p_src1 = (xtfloatx2 *)(p_scratch + ((itr_n * plane_size) + (itr_h * wc_plane_size)));
            for(itr_w = 1; itr_w < ((temp_inp_w - 1) & ~(2 - 1)); itr_w += 2)
            {
              p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c));
              p_wsrc3 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + ((itr_w + 1) * temp_inp_c));
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c);
              vecmean32_inpx3(p_src1, p_wsrc2, p_wsrc3, p_dst, temp_inp_c);
              p_src1 = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + (itr_h * temp_inp_c));
            }

            if((temp_inp_w - 1) & 1)
            {
              p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c));
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hc_plane_size) + itr_h * temp_inp_c);
              vecmean32_inpx2(p_src1, p_wsrc2, p_dst, temp_inp_c);
            }
          }
        }
        temp_inp_w = 1;
        }break;
      case 3: {
        int plane_size = temp_inp_h * temp_inp_w * temp_inp_c;
        int wc_plane_size = temp_inp_w * temp_inp_c;
        int hw_plane_size = temp_inp_h * temp_inp_w;
        int rem_c = ((temp_inp_c) & 3);
        for(itr_n=0; itr_n < (temp_inp_n); itr_n++)
        {
          for(itr_h=0; itr_h < (temp_inp_h); itr_h++)
          {
            for(itr_w=0; itr_w < (temp_inp_w); itr_w++)
            {
              p_wsrc2 = (xtfloatx2 *)(p_scr_in + (itr_n * plane_size) + (itr_h * wc_plane_size) + (itr_w * temp_inp_c));
              p_dst = (xtfloatx2 *)(p_scratch + (itr_n * hw_plane_size) + (itr_h * temp_inp_w) + itr_w);
              align_src = AE_LA64_PP(p_wsrc2);
              xtfloatx2 i1 = AE_MOVXTFLOATX2_FROMF32X2(AE_MOVDA32(0));
              for(itr_c = 0; itr_c < (temp_inp_c >> 2); itr_c++)
              {
                xtfloatx2 j1, j2;
                XT_LASX2IP(j1, align_src, p_wsrc2);
                XT_LASX2IP(j2, align_src, p_wsrc2);

                xtfloatx2 t1 = XT_SEL32_HH_SX2(j1, j1);
                xtfloatx2 t2 = XT_SEL32_HH_SX2(j2, j2);

                j1 = XT_ADD_SX2(t1, j1);
                j2 = XT_ADD_SX2(t2, j2);

                i1 = XT_ADD_SX2(i1, j1);
                i1 = XT_ADD_SX2(i1, j2);
              }

              //Remainder Loop
              for(itr_c=0; itr_c < rem_c; itr_c++)
              {
                xtfloat j1;
                XT_LSXP(j1, (xtfloat *)p_wsrc2, sizeof(xtfloat));
                i1 = XT_ADD_S(i1, j1);
              }
              XT_SSX(i1, (xtfloat *)p_dst, 0);
            }
          }
        }
        temp_inp_c = 1;
        }break;
      default:
        break;
    }
    axis_dims_count--;
    itr_axis++;
  }
}

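/*
 * xa_nn_reduce_mean_4D_f32_f32: mean of p_inp over the axes in p_axis,
 * computed as a reduce-sum into scratch followed by a multiply with the
 * reciprocal of the reduced element count. Returns 0 on success, -1 on
 * argument errors.
 *
 * Illustrative call (hypothetical shapes and sizes, not from a test suite):
 *
 *   FLOAT32 inp[2 * 3 * 4], out[2 * 4];
 *   WORD32  inp_shape[3] = {2, 3, 4};
 *   WORD32  out_shape[2] = {2, 4};
 *   WORD32  axis[1]      = {1};    // reduce the middle dim
 *   WORD8   scratch[256];          // caller-provided; aligned internally
 *   xa_nn_reduce_mean_4D_f32_f32(out, out_shape, inp, inp_shape,
 *                                axis, 2, 3, 1, scratch);
 */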
WORD32 xa_nn_reduce_mean_4D_f32_f32(
                                    FLOAT32 * __restrict__ p_out,
                                    const WORD32 *const p_out_shape,
                                    const FLOAT32 * __restrict__ p_inp,
                                    const WORD32 *const p_inp_shape,
                                    const WORD32 * __restrict__ p_axis,
                                    WORD32 num_out_dims,
                                    WORD32 num_inp_dims,
                                    WORD32 num_axis_dims,
                                    void * __restrict__ p_scratch_in)
{
  /* NULL pointer checks */
  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
  XA_NNLIB_ARG_CHK_PTR(p_axis, -1);
  XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1);
  XA_NNLIB_ARG_CHK_PTR(p_inp_shape, -1);

  /* Invalid input checks */
  XA_NNLIB_ARG_CHK_COND(((num_inp_dims <= 0) || (num_inp_dims > 4)), -1);
  XA_NNLIB_ARG_CHK_COND(((num_out_dims <= 0) || (num_out_dims > 4)), -1);
  XA_NNLIB_ARG_CHK_COND(((num_axis_dims < 0) || (num_axis_dims > 4)), -1);

  int axis_itr = 0, inp_itr = 0, out_itr = 0;
  int num_elm_in_axis = 1;
  int current, past = -1;
  for(axis_itr=0; axis_itr < num_axis_dims; axis_itr++)
  {
    current = p_axis[axis_itr];
    XA_NNLIB_ARG_CHK_COND(((current < 0) || (current > (num_inp_dims - 1))), -1);
    XA_NNLIB_ARG_CHK_COND((p_inp_shape[current] > 1024), -1);

    /* Avoid calculation in case of repeated axis dims */
    if(current != past)
    {
      num_elm_in_axis *= p_inp_shape[current];
      past = current;
    }
  }

  for(inp_itr=0; inp_itr < num_inp_dims; inp_itr++)
  {
    XA_NNLIB_ARG_CHK_COND((p_inp_shape[inp_itr] <= 0), -1);
  }

  int out_length = 1;
  for(out_itr=0; out_itr < num_out_dims; out_itr++)
  {
    XA_NNLIB_ARG_CHK_COND((p_out_shape[out_itr] <= 0), -1);
    out_length *= p_out_shape[out_itr];
  }

  /* Pointer alignment checks */
  XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(FLOAT32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_axis, sizeof(WORD32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1);
  XA_NNLIB_ARG_CHK_ALIGN(p_inp_shape, sizeof(WORD32), -1);

  FLOAT32 *p_in = (FLOAT32 *)(p_inp);
  WORD32 *p_scratch = (WORD32 *)(ALIGN_PTR(p_scratch_in, ALIGNMENT_8));

  // Reordering the axis data so that the reduce-sum is first computed
  // across the largest input shape dim present in the axis list. This is
  // required to minimize the scratch usage.
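  // For example (hypothetical shapes): inp_shape = {2, 100, 4, 8} with
  // axis = {0, 1} is reordered to {1, 0}, so the first pass reduces the
  // 100-element dim and the scratch only needs inp_length / 100 elements.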
  int inp_length = 1, p_axis_data[4] = {0}, inp_shape_max;
  if(num_axis_dims)
  {
    inp_shape_max = p_inp_shape[p_axis[0]];
    axis_itr = 1;
    int max_axis_itr = 0;
    int temp_p_axis_0 = p_axis[0];
    for(axis_itr = 0; axis_itr < num_axis_dims; axis_itr++)
    {
      p_axis_data[axis_itr] = p_axis[axis_itr];
    }
    for(axis_itr = 1; axis_itr < num_axis_dims; axis_itr++)
    {
      if(p_inp_shape[p_axis[axis_itr]] > inp_shape_max)
      {
        inp_shape_max = p_inp_shape[p_axis[axis_itr]];
        max_axis_itr = axis_itr;
      }
    }
    p_axis_data[0] = p_axis_data[max_axis_itr];
    p_axis_data[max_axis_itr] = temp_p_axis_0;

    inp_itr = 0;
    for(inp_itr=0; inp_itr < num_inp_dims; inp_itr++)
    {
      inp_length *= p_inp_shape[inp_itr];
    }

    memset(p_scratch, 0, ((inp_length / inp_shape_max) * sizeof(WORD32))); //TODO: Alternate approach for memset?
  }

  // Promoting lesser dim tensors to 4D tensors. Also modifying axis
  // data accordingly.
  int p_4D_inp_shape[4] = {1, 1, 1, 1};
  int itr = num_inp_dims - 1;
  int count = 3;
  while(itr >= 0)
  {
    p_4D_inp_shape[count] = p_inp_shape[itr];
    itr--;
    count--;
  }
  for(itr = 0; itr < num_axis_dims; itr++)
  {
    p_axis_data[itr] = p_axis_data[itr] + (4 - num_inp_dims);
  }
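  // E.g. a 2D input of shape {5, 6} becomes {1, 1, 5, 6}, and an axis
  // value of 1 (the 6-dim) is shifted to 3 in the promoted layout.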
  ae_valign align_out = AE_ZALIGN64();

  if(num_axis_dims)
  {
    if(num_elm_in_axis > 1)
    {
      xa_nn_reduce_sum_4D_f32_f32(p_in,
                                  p_4D_inp_shape,
                                  p_axis_data,
                                  num_inp_dims,
                                  num_axis_dims,
                                  p_scratch);
      itr = 0;
      xtfloatx2 *p_src1 = (xtfloatx2 *)(p_scratch);

      float div = 1;

      for(int i = 0; i < num_axis_dims; i++)
      {
        div = div * (float)p_4D_inp_shape[p_axis_data[i]];
      }

      float mul = 1.0f / div;

      xtfloatx2 multiplier = XT_LSX((xtfloat *)&mul, 0);

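      // Scale the accumulated sums by the reciprocal: one multiply per
      // element instead of a divide. The main loop emits 8 outputs per
      // iteration (four 2-way SIMD multiplies); the tail handles the
      // remaining out_length & 7 elements.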
      for(itr = 0; itr < (out_length >> 3); itr++)
      {
        xtfloatx2 temp1, temp2, temp3, temp4;

        temp2 = XT_LSX2X(p_src1, 8);
        temp3 = XT_LSX2X(p_src1, 16);
        temp4 = XT_LSX2X(p_src1, 24);
        XT_LSX2XP(temp1, p_src1, 32);

        temp1 = XT_MUL_SX2(temp1, multiplier);
        temp2 = XT_MUL_SX2(temp2, multiplier);
        temp3 = XT_MUL_SX2(temp3, multiplier);
        temp4 = XT_MUL_SX2(temp4, multiplier);

        XT_SASX2IP(temp1, align_out, (xtfloatx2 *)p_out);
        XT_SASX2IP(temp2, align_out, (xtfloatx2 *)p_out);
        XT_SASX2IP(temp3, align_out, (xtfloatx2 *)p_out);
        XT_SASX2IP(temp4, align_out, (xtfloatx2 *)p_out);
      }
      AE_SA64POS_FP(align_out, p_out);

      for(itr = 0; itr < (out_length & 7); itr++)
      {
        xtfloat temp1;
        XT_LSXP(temp1, (xtfloat *)p_src1, 4);
        temp1 = XT_MUL_S(temp1, multiplier);
        XT_SSXP(temp1, (xtfloat *)p_out, 4);
      }
    }
    else
    {
      memcpy(p_out, p_inp, inp_length * sizeof(FLOAT32));
    }
  }
  else
  {
    /* No reduction requested: inp_length is not computed on this path, so
       copy out_length elements (expected to match the input size here). */
    memcpy(p_out, p_inp, out_length * sizeof(FLOAT32));
  }

  return 0;
}