// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
$import math
$assert IN_PTRS in ["MULTI", "REUSE"]
$assert OUT_PTRS in ["MULTI", "SWITCH", "MOV"]
$assert SIZE in [8, 16, 32, 64]
$TILE_SIZE = int(128/SIZE)
$NUM_ITERS = int(math.log2(TILE_SIZE))
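// A 128-bit register holds ${TILE_SIZE} ${SIZE}-bit elements, so tiles are ${TILE_SIZE}x${TILE_SIZE} and the
// in-register transpose below takes ${NUM_ITERS} unpack stage(s).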

#include <immintrin.h>

#include <assert.h>

#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/transpose.h>
#include <xnnpack/unaligned.h>


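// Transposes a block_height x block_width block of ${SIZE}-bit elements from
// `input` (rows `input_stride` bytes apart) into a block_width x block_height
// block at `output` (rows `output_stride` bytes apart), one ${TILE_SIZE}x${TILE_SIZE} tile
// at a time.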
void xnn_x${SIZE}_transposec_ukernel__${TILE_SIZE}x${TILE_SIZE}_${IN_PTRS.lower()}_${OUT_PTRS.lower()}_sse2(
    const uint${SIZE}_t* input,
    uint${SIZE}_t* output,
    size_t input_stride,
    size_t output_stride,
    size_t block_width,
    size_t block_height) XNN_OOB_READS
{
  assert(output_stride >= block_height * sizeof(uint${SIZE}_t));
  assert(input_stride >= block_width * sizeof(uint${SIZE}_t));

  const size_t tile_height = ${TILE_SIZE};
  const size_t tile_width = ${TILE_SIZE};
  const size_t tile_hbytes = tile_height * sizeof(uint${SIZE}_t);
  const size_t tile_wbytes = tile_width * sizeof(uint${SIZE}_t);
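  // input_reset and output_reset step the pointers from the bottom of one
  // column of tiles to the top of the next column.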
  const size_t input_reset = tile_wbytes - round_down_po2(block_height, tile_height) * input_stride;
  $if IN_PTRS == "MULTI":
    const size_t input_offset = tile_height * input_stride;
  $if OUT_PTRS == "MOV":
    const size_t output_reset = tile_width * output_stride - round_down_po2(block_height, 2) * sizeof(uint${SIZE}_t) - tile_hbytes;
  $else:
    const size_t output_reset = tile_width * output_stride - round_down_po2(block_height, 2) * sizeof(uint${SIZE}_t);

  $if IN_PTRS == "MULTI":
    const uint${SIZE}_t* i0 = input;
    $for N in range(1, TILE_SIZE):
      const uint${SIZE}_t* i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
  $else:
    const uint${SIZE}_t* i0 = input;
  $if OUT_PTRS == "MULTI":
    uint${SIZE}_t* o0 = (uint${SIZE}_t*) output;
    $for N in range(1, TILE_SIZE):
      uint${SIZE}_t* o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N-1} + output_stride);
  $elif OUT_PTRS == "SWITCH":
    uint${SIZE}_t* o = (uint${SIZE}_t*) output;
  $else:
    uint${SIZE}_t* o = (uint${SIZE}_t*) ((uintptr_t) output - tile_hbytes);
  $if OUT_PTRS == "MOV":
    const size_t minus_output_stride = -output_stride;
  $elif OUT_PTRS == "SWITCH" and SIZE != 64:
    const size_t minus_output_stride = -output_stride;

  do {
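    // Clamp the output pointers (MULTI) or the valid-column count `rem`
    // (SWITCH/MOV) so a tile narrower than ${TILE_SIZE} columns never writes past
    // block_width.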
    $if OUT_PTRS == "MULTI":
      if XNN_UNPREDICTABLE(block_width < 2) {
        o1 = o0;
      }
      $for N in range(2, TILE_SIZE, 2):
        if XNN_UNPREDICTABLE(block_width <= ${N}) {
          o${N} = o0;
        }
        if XNN_UNPREDICTABLE(block_width < ${N+2}) {
          o${N+1} = o0;
        }
    $elif OUT_PTRS == "MOV":
      const size_t rem = min(block_width - 1, ${TILE_SIZE-1});
      const size_t oN_stride = rem * output_stride;
      const size_t oN_offset = oN_stride + tile_hbytes;
    $elif OUT_PTRS == "SWITCH":
      const size_t rem = min(block_width - 1, ${TILE_SIZE-1});
      const size_t oN_stride = rem * output_stride;
    size_t bh = block_height;
    for (; bh >= ${TILE_SIZE}; bh -= ${TILE_SIZE}) {
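      // Load ${TILE_SIZE} input rows of the tile, one 128-bit register per row.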
      $for N in range(TILE_SIZE):
        $if IN_PTRS == "REUSE":
          const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i0);
          i0 = (uint${SIZE}_t*) ((uintptr_t) i0 + input_stride);
        $else:
          const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
          i${N} = (uint${SIZE}_t*) ((uintptr_t) i${N} + input_offset);

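      // Transpose the tile in-register: ${NUM_ITERS} unpacklo/unpackhi stage(s),
      // doubling the width of the interleaved elements at each stage.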
      $for N in range(TILE_SIZE >> 1):
        const __m128i v${NUM_ITERS-1}_${N*2} = _mm_unpacklo_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});
        const __m128i v${NUM_ITERS-1}_${N*2+1} = _mm_unpackhi_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});

      $if NUM_ITERS>=2:
        $for N in range(0, TILE_SIZE, 4):
          const __m128i v${NUM_ITERS-2}_${N} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
          const __m128i v${NUM_ITERS-2}_${N+1} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
          const __m128i v${NUM_ITERS-2}_${N+2} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});
          const __m128i v${NUM_ITERS-2}_${N+3} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});

      $if NUM_ITERS>=3:
        $for M in range(0, TILE_SIZE, 8):
          $for N in range(0, 4):
            const __m128i v${NUM_ITERS-3}_${M+2*N} = _mm_unpacklo_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});
            const __m128i v${NUM_ITERS-3}_${M+2*N+1} = _mm_unpackhi_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});

      $if NUM_ITERS>=4:
        $for N in range(TILE_SIZE >> 1):
          const __m128i v0_${N*2} = _mm_unpacklo_epi64(v1_${N}, v1_${N+8});
          const __m128i v0_${N*2+1} = _mm_unpackhi_epi64(v1_${N}, v1_${N+8});

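      // Each v0_N now holds output row N of the tile; write the rows out.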
      $if OUT_PTRS == "SWITCH":
        uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
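        // Intentional fall-through from case `rem` down to case 0: exactly
        // rem + 1 output rows are stored.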
        switch (rem) {
          $for N in reversed(range(2, TILE_SIZE)):
            case ${N}:
              _mm_storeu_si128((__m128i*) oN, v0_${N});
              oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
          case 1:
            _mm_storeu_si128((__m128i*) oN, v0_1);
          case 0:
            _mm_storeu_si128((__m128i*) o, v0_0);
            o = (uint${SIZE}_t*) ((uintptr_t) o + tile_hbytes);
            break;
          default:
            XNN_UNREACHABLE;
        }
      $elif OUT_PTRS == "MOV":
        o = (uint${SIZE}_t*) ((uintptr_t) o + oN_offset);
        _mm_storeu_si128((__m128i*) o, v0_${TILE_SIZE-1});
        uint${SIZE}_t *oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
        $for N in reversed(range(2, TILE_SIZE, 2)):
          if XNN_UNPREDICTABLE(block_width > ${N+1}) {
            o = oN;
          }
          _mm_storeu_si128((__m128i*) o, v0_${N});
          oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
          if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
            o = oN;
          }
          _mm_storeu_si128((__m128i*) o, v0_${N-1});
          oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
        if XNN_UNPREDICTABLE(block_width > 1) {
          o = oN;
        }
        _mm_storeu_si128((__m128i*) o, v0_0);
      $else:
        $for N in reversed(range(TILE_SIZE)):
          _mm_storeu_si128((__m128i*) o${N}, v0_${N});
          o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N} + tile_hbytes);
    }
    $if OUT_PTRS == "MOV":
      o = (uint${SIZE}_t*) ((uintptr_t) o + tile_hbytes);
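    // Fewer than ${TILE_SIZE} rows remain in this tile column; handle them separately.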
    if (bh != 0) {
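      // Row pointers beyond bh are clamped to an earlier valid row; their lanes
      // are never written by the partial stores below.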
      $if IN_PTRS == "REUSE":
        const __m128i v${NUM_ITERS}_0 = _mm_loadu_si128((const __m128i*) i0);
        $for N in range(1, TILE_SIZE - 1, 2):
          const uint${SIZE}_t *i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
          if XNN_UNPREDICTABLE(bh < ${N+1}) {
            i${N} = i${N-1};
          }
          const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
          const uint${SIZE}_t *i${N+1} = (const uint${SIZE}_t*) ((uintptr_t) i${N} + input_stride);
          if XNN_UNPREDICTABLE(bh <= ${N+1}) {
            i${N+1} = i${N};
          }
          const __m128i v${NUM_ITERS}_${N+1} = _mm_loadu_si128((const __m128i*) i${N+1});
      $else:
        const __m128i v${NUM_ITERS}_0 = _mm_loadu_si128((const __m128i*) i0);
        $for N in range(1, TILE_SIZE - 1, 2):
          if XNN_UNPREDICTABLE(bh < ${N+1}) {
            i${N} = i0;
          }
          const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
          if XNN_UNPREDICTABLE(bh <= ${N+1}) {
            i${N+1} = i0;
          }
          const __m128i v${NUM_ITERS}_${N+1} = _mm_loadu_si128((const __m128i*) i${N+1});
      const __m128i v${NUM_ITERS}_${TILE_SIZE-1} = _mm_undefined_si128();

      $CONST = "const "
      $if NUM_ITERS == 1:
        $CONST = ""
      $for N in range(TILE_SIZE >> 1):
        ${CONST}__m128i v${NUM_ITERS-1}_${N*2} = _mm_unpacklo_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});
        ${CONST}__m128i v${NUM_ITERS-1}_${N*2+1} = _mm_unpackhi_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});

      $if NUM_ITERS == 2:
        $CONST = ""
      $if NUM_ITERS>=2:
        $for N in range(0, TILE_SIZE, 4):
          ${CONST}__m128i v${NUM_ITERS-2}_${N} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
          ${CONST}__m128i v${NUM_ITERS-2}_${N+1} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
          ${CONST}__m128i v${NUM_ITERS-2}_${N+2} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});
          ${CONST}__m128i v${NUM_ITERS-2}_${N+3} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});

      $if NUM_ITERS == 3:
        $CONST = ""
      $if NUM_ITERS>=3:
        $for M in range(0, TILE_SIZE, 8):
          $for N in range(0, 4):
            ${CONST}__m128i v${NUM_ITERS-3}_${M+2*N} = _mm_unpacklo_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});
            ${CONST}__m128i v${NUM_ITERS-3}_${M+2*N+1} = _mm_unpackhi_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});

      $if NUM_ITERS>=4:
        $for N in range(TILE_SIZE >> 1):
          __m128i v0_${N*2} = _mm_unpacklo_epi64(v1_${N}, v1_${N+8});
          __m128i v0_${N*2+1} = _mm_unpackhi_epi64(v1_${N}, v1_${N+8});

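      // Store the remaining bh elements of each output row in power-of-two
      // chunks, shifting the stored lanes out of the registers after each chunk.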
      if (bh & ${TILE_SIZE>>1}) {
        $if OUT_PTRS == "SWITCH":
          uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
          switch (rem) {
            $for N in reversed(range(2, TILE_SIZE)):
              case ${N}:
                _mm_storel_epi64((__m128i*) oN, v0_${N});
                oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
            case 1:
              _mm_storel_epi64((__m128i*) oN, v0_1);
            case 0:
              _mm_storel_epi64((__m128i*) o, v0_0);
              break;
            default:
              XNN_UNREACHABLE;
          }
          $if NUM_ITERS > 1:
            o += ${TILE_SIZE>>1};
        $elif OUT_PTRS == "MOV":
          o = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
          _mm_storel_epi64((__m128i*) o, v0_${TILE_SIZE-1});
          uint${SIZE}_t *oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
          $for N in reversed(range(2, TILE_SIZE, 2)):
            if XNN_UNPREDICTABLE(block_width > ${N+1}) {
              o = oN;
            }
            _mm_storel_epi64((__m128i*) o, v0_${N});
            oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
              o = oN;
            }
            _mm_storel_epi64((__m128i*) o, v0_${N-1});
            oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
          if XNN_UNPREDICTABLE(block_width > 1) {
            o = oN;
          }
          _mm_storel_epi64((__m128i*) o, v0_0);
          $if NUM_ITERS > 1:
            o += ${TILE_SIZE>>1};
        $else:
          $for N in reversed(range(TILE_SIZE)):
            _mm_storel_epi64((__m128i*) o${N}, v0_${N});
            $if NUM_ITERS>1:
              o${N} += ${TILE_SIZE>>1};
        $if NUM_ITERS > 1:
          $for N in range(TILE_SIZE):
            v0_${N} = _mm_unpackhi_epi64(v0_${N}, v0_${N});
      }

      $if NUM_ITERS>1:
        if (bh & ${TILE_SIZE>>2}) {
          $if OUT_PTRS == "SWITCH":
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            switch (rem) {
              $for N in reversed(range(2, TILE_SIZE)):
                case ${N}:
                  unaligned_store_u32(oN, (uint32_t) _mm_cvtsi128_si32(v0_${N}));
                  oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
              case 1:
                unaligned_store_u32(oN, (uint32_t) _mm_cvtsi128_si32(v0_1));
              case 0:
                unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_0));
                break;
              default:
                XNN_UNREACHABLE;
            }
            $if NUM_ITERS > 2:
              o += ${TILE_SIZE>>2};
          $elif OUT_PTRS == "MOV":
            o = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_${TILE_SIZE-1}));
            uint${SIZE}_t *oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            $for N in reversed(range(2, TILE_SIZE, 2)):
              if XNN_UNPREDICTABLE(block_width > ${N+1}) {
                o = oN;
              }
              unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_${N}));
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
              if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
                o = oN;
              }
              unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_${N-1}));
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            if XNN_UNPREDICTABLE(block_width > 1) {
              o = oN;
            }
            unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_0));
            $if NUM_ITERS > 2:
              o += ${TILE_SIZE>>2};
          $else:
            $for N in reversed(range(TILE_SIZE)):
              unaligned_store_u32(o${N}, (uint32_t) _mm_cvtsi128_si32(v0_${N}));
              $if NUM_ITERS>2:
                o${N} += ${TILE_SIZE>>2};
          $if NUM_ITERS > 2:
            $for N in range(TILE_SIZE):
              v0_${N} = _mm_srli_epi64(v0_${N}, 32);
        }
      $if NUM_ITERS>2:
        if (bh & ${TILE_SIZE>>3}) {
          $if OUT_PTRS == "SWITCH":
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            switch (rem) {
              $for N in reversed(range(2, TILE_SIZE)):
                case ${N}:
                  unaligned_store_u16(oN, (uint16_t) _mm_cvtsi128_si32(v0_${N}));
                  oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
              case 1:
                unaligned_store_u16(oN, (uint16_t) _mm_cvtsi128_si32(v0_1));
              case 0:
                unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_0));
                break;
              default:
                XNN_UNREACHABLE;
            }
            $if NUM_ITERS>3:
              o += ${TILE_SIZE>>3};
          $elif OUT_PTRS == "MOV":
            o = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_${TILE_SIZE-1}));
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            $for N in reversed(range(2, TILE_SIZE, 2)):
              if XNN_UNPREDICTABLE(block_width > ${N+1}) {
                o = oN;
              }
              unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_${N}));
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
              if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
                o = oN;
              }
              unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_${N-1}));
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            if XNN_UNPREDICTABLE(block_width > 1) {
              o = oN;
            }
            unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_0));
            $if NUM_ITERS > 3:
              o += ${TILE_SIZE>>3};
          $else:
            $for N in reversed(range(TILE_SIZE)):
              unaligned_store_u16(o${N}, (uint16_t) _mm_cvtsi128_si32(v0_${N}));
              $if NUM_ITERS>3:
                o${N} += ${TILE_SIZE>>3};
          $if NUM_ITERS>3:
            $for N in range(TILE_SIZE):
              v0_${N} = _mm_srli_epi32(v0_${N}, 16);
        }
      $if SIZE == 8:
        if (bh & 1) {
          $if OUT_PTRS == "SWITCH":
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            switch (rem) {
              $for N in reversed(range(2, TILE_SIZE)):
                case ${N}:
                  *oN = (uint8_t) _mm_cvtsi128_si32(v0_${N});
                  oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
              case 1:
                *oN = (uint8_t) _mm_cvtsi128_si32(v0_1);
              case 0:
                *o = (uint8_t) _mm_cvtsi128_si32(v0_0);
                break;
              default:
                XNN_UNREACHABLE;
            }
          $elif OUT_PTRS == "MOV":
            o = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            *o = (uint8_t) _mm_cvtsi128_si32(v0_${TILE_SIZE-1});
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            $for N in reversed(range(2, TILE_SIZE, 2)):
              if XNN_UNPREDICTABLE(block_width > ${N+1}) {
                o = oN;
              }
              *o = (uint8_t) _mm_cvtsi128_si32(v0_${N});
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
              if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
                o = oN;
              }
              *o = (uint8_t) _mm_cvtsi128_si32(v0_${N-1});
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            if XNN_UNPREDICTABLE(block_width > 1) {
              o = oN;
            }
            *o = (uint8_t) _mm_cvtsi128_si32(v0_0);
        }
    }

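    // Advance to the next ${TILE_SIZE}-column strip of the input (the next
    // ${TILE_SIZE} rows of the output).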
    $if IN_PTRS == "MULTI":
      i0 = (const uint${SIZE}_t*) ((uintptr_t) i0 + input_reset);
      $for N in range(1, TILE_SIZE):
        i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
    $else:
      i0 = (const uint${SIZE}_t*) ((uintptr_t) i0 + input_reset);
    $if OUT_PTRS == "MULTI":
      o0 = (uint${SIZE}_t*) ((uintptr_t) o0 + output_reset);
      $for N in range(1, TILE_SIZE):
        o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N} + output_reset);
    $else:
      o = (uint${SIZE}_t*) ((uintptr_t) o + output_reset);
    block_width = doz(block_width, tile_width);
  } while (block_width != 0);
}
