xref: /aosp_15_r20/external/XNNPACK/src/s8-ibilinear/sse.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert SSE in [2, 4]
7$assert not XOP or AVX
8$assert not AVX or SSE == 4
9$assert DATATYPE in ["S8", "U8"]
10$assert CHANNEL_TILE % 8 == 0
11$assert CHANNEL_TILE >= 8
12$assert PIXEL_TILE == 1
13$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
14#include <assert.h>
15
16$if XOP:
17  #if defined(__GNUC__) || defined(__clang__)
18    #include <x86intrin.h>
19  #else
20    #include <immintrin.h>
21    #include <ammintrin.h>
22  #endif
23$else:
24  $SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
25  #include <${SSE_HEADER}>
26
27#include <xnnpack/common.h>
28#include <xnnpack/ibilinear.h>
29#include <xnnpack/unaligned.h>
30
31
32$XINT8_T = {"S8": "int8_t", "U8": "uint8_t"}[DATATYPE]
33$_MM_CVTEPX8_EPI16 = {"S8": "_mm_cvtepi8_epi16", "U8": "_mm_cvtepu8_epi16"}[DATATYPE]
34$_MM_SRXI_EPI32 = {"S8": "_mm_srai_epi32", "U8": "_mm_srli_epi32"}[DATATYPE]
35$_MM_SRXI_EPI16 = {"S8": "_mm_srai_epi16", "U8": "_mm_srli_epi16"}[DATATYPE]
36$_MM_PACKXS_EPI16 = {"S8": "_mm_packs_epi16", "U8": "_mm_packus_epi16"}[DATATYPE]
37$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE]
// 2D bilinear interpolation micro-kernel for ${DATATYPE} pixels (${ISA} build).
//
// For each of `output_pixels` output pixels, blends the four neighboring input
// pixels (top-left, top-right, bottom-left, bottom-right) across `channels`
// interleaved 8-bit channels:
//   out = tl*(1-ah)*(1-av) + tr*ah*(1-av) + bl*(1-ah)*av + br*ah*av
// The weights `ah` (horizontal) and `av` (vertical) are 16-bit Q11 fixed-point
// values (2048 == 1.0), read pairwise from `weights`, one pair per output
// pixel. The product of two Q11 weights yields a Q22 accumulator, which is
// rescaled with a rounding shift right by 22 before repacking to 8 bits.
//
// `input` points to 4 corner-row pointers per output pixel; `input_offset` is
// added (in bytes) to each. `output_increment` is added (in bytes) to `output`
// after each pixel's `channels` bytes are written. The XNN_OOB_READS
// annotation marks the 8-byte tail loads below as intentionally allowed to
// read past the end of the channel run.
//
// NOTE(review): this is an xngen template; `$`-prefixed lines are template
// directives, not C. Comments describe the generated code.
38void xnn_${DATATYPE.lower()}_ibilinear_ukernel__${ISA}_c${CHANNEL_TILE}${"" if PIXEL_TILE == 1 else "x%d" % PIXEL_TILE}(
39    size_t output_pixels,
40    size_t channels,
41    const ${XINT8_T}**restrict input,
42    size_t input_offset,
43    const int16_t*restrict weights,
44    ${XINT8_T}*restrict output,
45    size_t output_increment) XNN_OOB_READS
46{
47  assert(output_pixels != 0);
48  assert(channels != 0);
49
50  do {
    // Pointers to the four corner pixels' channel runs for this output pixel:
    // i0 = top-left, i1 = top-right, i2 = bottom-left, i3 = bottom-right.
51    const ${XINT8_T}* i0 = (const ${XINT8_T}*) ((uintptr_t) input[0] + input_offset);
52    const ${XINT8_T}* i1 = (const ${XINT8_T}*) ((uintptr_t) input[1] + input_offset);
53    const ${XINT8_T}* i2 = (const ${XINT8_T}*) ((uintptr_t) input[2] + input_offset);
54    const ${XINT8_T}* i3 = (const ${XINT8_T}*) ((uintptr_t) input[3] + input_offset);
55    input += 4;

    // Load this pixel's (alphah, alphav) int16 weight pair with one 32-bit load,
    // then broadcast alphah (lane 0) and alphav (lane 1) across all lanes.
56
57    const __m128i valpha = _mm_cvtsi32_si128(*((const int*) weights));
58    weights += 2;
59    __m128i valphah = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(0, 0, 0, 0));
60    valphah = _mm_unpacklo_epi64(valphah, valphah);
61    $if SSE == 2:
62      __m128i valphav = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(1, 1, 1, 1));
63      valphav = _mm_unpacklo_epi64(valphav, valphav);
64    $else:
65      __m128i valphav = _mm_srli_epi32(valpha, 16);
66      valphav = _mm_shuffle_epi32(valphav, _MM_SHUFFLE(0, 0, 0, 0));
67
    // Turn valphah into interleaved lanes [ah, 2048 - ah, ah, 2048 - ah, ...]
    // so that _mm_madd_epi16 on (right, left)-interleaved pixels computes
    // right*ah + left*(2048 - ah) per channel. The SSE2 path gets the same
    // odd lanes via (ah ^ 0xFFFF) + 0x0801 == 2048 - ah (mod 2^16).
68    $if SSE == 4:
69      valphah = _mm_blend_epi16(valphah, _mm_sub_epi16(_mm_set1_epi32(0x08000000), valphah), 0xAA);
70    $else:
71      valphah = _mm_xor_si128(valphah, _mm_set1_epi32(0xFFFF0000));
72      valphah = _mm_add_epi16(valphah, _mm_set1_epi32(0x08010000));
73
    // 1 << 21: rounding bias for the final shift right by 22 (two Q11 weights
    // are applied, so the accumulator is in Q22).
74    const __m128i vrounding = _mm_set1_epi32(0x00200000);
75
76    size_t c = channels;
    // Main loop: process CHANNEL_TILE channels per iteration (only generated
    // when CHANNEL_TILE > 8; otherwise the 8-channel loop below is the main one).
77    $if CHANNEL_TILE > 8:
78      for (; c >= ${CHANNEL_TILE} * sizeof(${XINT8_T}); c -= ${CHANNEL_TILE} * sizeof(${XINT8_T})) {
        // Load 8 corner pixels per group; SSE4.1 widens to 16 bits during the
        // load via pmovsxbw/pmovzxbw, SSE2 widens separately below.
79        $if SSE == 4:
80          const __m128i vtl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
81          const __m128i vtr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
82          const __m128i vbl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
83          const __m128i vbr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
84          $for C in range(8, CHANNEL_TILE, 8):
85            const __m128i vtl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i0 + ${C})));
86            const __m128i vtr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i1 + ${C})));
87            const __m128i vbl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i2 + ${C})));
88            const __m128i vbr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i3 + ${C})));
89        $else:
90          __m128i vtl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i0);
91          __m128i vtr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i1);
92          __m128i vbl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i2);
93          __m128i vbr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i3);
94          $for C in range(8, CHANNEL_TILE, 8):
95            __m128i vtl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i0 + ${C}));
96            __m128i vtr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i1 + ${C}));
97            __m128i vbl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i2 + ${C}));
98            __m128i vbr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i3 + ${C}));
99        i0 += ${CHANNEL_TILE};
100        i1 += ${CHANNEL_TILE};
101        i2 += ${CHANNEL_TILE};
102        i3 += ${CHANNEL_TILE};
103
        // SSE2 widening: zero-extend for U8 (unpack with zero), sign-extend
        // for S8 (duplicate the byte into both halves, then arithmetic
        // shift right by 8).
104        $if SSE != 4:
105          $if DATATYPE == "U8":
106            __m128i vzero = _mm_setzero_si128();
107            $for C in range(0, CHANNEL_TILE, 8):
108              vtl${ABC[C:C+8]} = _mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vzero);
109              vtr${ABC[C:C+8]} = _mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vzero);
110              vbl${ABC[C:C+8]} = _mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vzero);
111              vbr${ABC[C:C+8]} = _mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vzero);
112          $else:
113            $for C in range(0, CHANNEL_TILE, 8):
114              vtl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vtl${ABC[C:C+8]}), 8);
115              vtr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vtr${ABC[C:C+8]}), 8);
116              vbl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vbl${ABC[C:C+8]}), 8);
117              vbr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vbr${ABC[C:C+8]}), 8);
118
        // Horizontal pass: vt = tr*ah + tl*(2048 - ah) for the top row (Q11),
        // vd = the same blend applied to the bottom-minus-top differences,
        // so that vacc = (vt << 11) + vd*av equals the full bilinear sum in Q22.
119        $for C in range(0, CHANNEL_TILE, 8):
120          const __m128i vdr${ABC[C:C+8]} = _mm_sub_epi16(vbr${ABC[C:C+8]}, vtr${ABC[C:C+8]});
121          const __m128i vt${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah);
122          const __m128i vdl${ABC[C:C+8]} = _mm_sub_epi16(vbl${ABC[C:C+8]}, vtl${ABC[C:C+8]});
123          const __m128i vt${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah);
124
125        $for C in range(0, CHANNEL_TILE, 8):
126          const __m128i vd${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah);
127          const __m128i vd${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah);
128
        // Vertical pass: multiply vd by alphav. SSE2 lacks _mm_mullo_epi32,
        // so the low 32 bits of the product are rebuilt from 16-bit multiplies:
        // (mulhi_epu16 << 16) + mullo_epi16. The _mm_add_epi16 is equivalent to
        // a 32-bit add because the low 16 bits of the shifted term are zero,
        // so no carry can cross a 16-bit lane boundary.
129        $if SSE == 4:
130          $for C in range(0, CHANNEL_TILE, 4):
131            __m128i vacc${ABC[C:C+4]} = _mm_mullo_epi32(vd${ABC[C:C+4]}, valphav);
132        $else:
133          $for C in range(0, CHANNEL_TILE, 4):
134            __m128i vacc${ABC[C:C+4]} = _mm_slli_epi32(_mm_mulhi_epu16(vd${ABC[C:C+4]}, valphav), 16);
135
136          $for C in range(0, CHANNEL_TILE, 4):
137            vacc${ABC[C:C+4]} = _mm_add_epi16(_mm_mullo_epi16(vd${ABC[C:C+4]}, valphav), vacc${ABC[C:C+4]});
138
        // vacc = (vt << 11) + vd*av: scale vt from Q11 up to Q22 and add.
139        $for C in range(0, CHANNEL_TILE, 4):
140          vacc${ABC[C:C+4]} = _mm_add_epi32(_mm_slli_epi32(vt${ABC[C:C+4]}, 11), vacc${ABC[C:C+4]});
141
        // Rounding rescale: (vacc + 2^21) >> 22 — arithmetic shift for S8,
        // logical for U8. _mm_add_epi16 is safe: vrounding's low 16 bits are
        // zero, so the add never carries across a 16-bit lane.
142        $for C in range(0, CHANNEL_TILE, 4):
143          vacc${ABC[C:C+4]} = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc${ABC[C:C+4]}, vrounding), 22);
144
        // Narrow with saturation: 32 -> 16 bits (signed), then 16 -> 8 bits
        // (signed for S8, unsigned for U8), and store CHANNEL_TILE bytes.
145        $for C in range(0, CHANNEL_TILE, 8):
146          const __m128i vacc${ABC[C:C+8]} = _mm_packs_epi32(vacc${ABC[C:C+4]}, vacc${ABC[C+4:C+8]});
147
148        $for C in range(0, CHANNEL_TILE, 16):
149          $if C + 8 < CHANNEL_TILE:
150            const __m128i vo${ABC[C:C+16]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]});
151          $else:
152            const __m128i vo${ABC[C:C+8]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C:C+8]});
153
154        _mm_storeu_si128((__m128i*) output, vo${ABC[0:16]});
155        $for C in range(16, CHANNEL_TILE, 16):
156          $if C + 8 < CHANNEL_TILE:
157            _mm_storeu_si128((__m128i*) (output + ${C}), vo${ABC[C:C+16]});
158          $else:
159            _mm_storel_epi64((__m128i*) (output + ${C}), vo${ABC[C:C+8]});
160        output += ${CHANNEL_TILE};
161      }
    // Process remaining full groups of 8 channels; same arithmetic as the
    // main loop above, one 8-channel group at a time.
162    for (; c >= 8 * sizeof(${XINT8_T}); c -= 8 * sizeof(${XINT8_T})) {
163      $if SSE == 4:
164        const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
165        i0 += 8;
166        const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
167        i1 += 8;
168        const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
169        i2 += 8;
170        const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
171        i3 += 8;
172      $else:
173        __m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0);
174        i0 += 8;
175        __m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1);
176        i1 += 8;
177        __m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2);
178        i2 += 8;
179        __m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3);
180        i3 += 8;
181
      // SSE2 widening to 16 bits (zero-extend U8, sign-extend S8).
182      $if SSE != 4:
183        $if DATATYPE == "U8":
184          __m128i vzero = _mm_setzero_si128();
185          vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero);
186          vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero);
187          vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero);
188          vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero);
189        $else:
190          vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8);
191          vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8);
192          vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8);
193          vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8);
194
      // Horizontal blend of top row (vt) and of bottom-minus-top (vd).
195      const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567);
196      const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah);
197      const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567);
198      const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah);
199
200      const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah);
201      const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah);
202
      // Vertical blend: vd * alphav (SSE2 emulates the 32-bit low multiply).
203      $if SSE == 4:
204        __m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav);
205        __m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav);
206      $else:
207        __m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16);
208        __m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16);
209
210        vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
211        vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);
212
213      vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123);
214      vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567);
215
      // Rounding shift from Q22 back to integer pixel values.
216      vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22);
217      vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22);
218
219      const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567);
220
221      const __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567);
222
223      _mm_storel_epi64((__m128i*) output, vo01234567);
224      output += 8;
225    }
    // Tail: 1-7 remaining channels. A full 8-byte group is computed (the
    // function's XNN_OOB_READS annotation covers the over-read), then only
    // the valid low bytes are stored via 4/2/1-byte pieces.
226    if XNN_UNLIKELY(c != 0) {
227      $if SSE == 4:
228        const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
229        const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
230        const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
231        const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
232      $else:
233        __m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0);
234        __m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1);
235        __m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2);
236        __m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3);
237
238      $if SSE != 4:
239        $if DATATYPE == "U8":
240          __m128i vzero = _mm_setzero_si128();
241          vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero);
242          vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero);
243          vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero);
244          vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero);
245        $else:
246          vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8);
247          vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8);
248          vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8);
249          vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8);
250
251      const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567);
252      const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah);
253      const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567);
254      const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah);
255
256      const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah);
257      const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah);
258
259      $if SSE == 4:
260        __m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav);
261        __m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav);
262      $else:
263        __m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16);
264        __m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16);
265
266        vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
267        vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);
268
269      vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123);
270      vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567);
271
272      vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22);
273      vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22);
274
275      const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567);
276
277      __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567);
278
      // Store low 4, then 2, then 1 byte(s) depending on the bits of c,
      // shifting consumed bytes out of the register/scalar each time.
279      if (c & (4 * sizeof(${XINT8_T}))) {
280        unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vo01234567));
281        output += 4;
282        vo01234567 = _mm_srli_epi64(vo01234567, 32);
283      }
284      $if SSE == 4:
285        if (c & (2 * sizeof(${XINT8_T}))) {
286          unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vo01234567, 0));
287          output += 2;
288          vo01234567 = _mm_srli_epi32(vo01234567, 16);
289        }
290        if (c & (1 * sizeof(${XINT8_T}))) {
291          *output++ = (uint8_t) _mm_extract_epi8(vo01234567, 0);
292        }
293      $else:
294        uint32_t vo0123 = (uint32_t) _mm_cvtsi128_si32(vo01234567);
295        if (c & (2 * sizeof(${XINT8_T}))) {
296          unaligned_store_u16(output, (uint16_t) vo0123);
297          output += 2;
298          vo0123 >>= 16;
299        }
300        if (c & (1 * sizeof(${XINT8_T}))) {
301          *output++ = (uint8_t) vo0123;
302        }
303    }

    // Advance output by the caller-provided byte increment between pixels.
304
305    output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
306  } while (--output_pixels != 0);
307}
308