// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert SSE in [2, 4]
$assert not XOP or AVX
$assert not AVX or SSE == 4
$assert DATATYPE in ["S8", "U8"]
$assert CHANNEL_TILE % 8 == 0
$assert CHANNEL_TILE >= 8
$assert PIXEL_TILE == 1
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

$if XOP:
  #if defined(__GNUC__) || defined(__clang__)
    #include <x86intrin.h>
  #else
    #include <immintrin.h>
    #include <ammintrin.h>
  #endif
$else:
  $SSE_HEADER = {2: "emmintrin.h", 4: "smmintrin.h"}[SSE]
  #include <${SSE_HEADER}>

#include <xnnpack/common.h>
#include <xnnpack/ibilinear.h>
#include <xnnpack/unaligned.h>


$XINT8_T = {"S8": "int8_t", "U8": "uint8_t"}[DATATYPE]
$_MM_CVTEPX8_EPI16 = {"S8": "_mm_cvtepi8_epi16", "U8": "_mm_cvtepu8_epi16"}[DATATYPE]
$_MM_SRXI_EPI32 = {"S8": "_mm_srai_epi32", "U8": "_mm_srli_epi32"}[DATATYPE]
$_MM_SRXI_EPI16 = {"S8": "_mm_srai_epi16", "U8": "_mm_srli_epi16"}[DATATYPE]
$_MM_PACKXS_EPI16 = {"S8": "_mm_packs_epi16", "U8": "_mm_packus_epi16"}[DATATYPE]
$ISA = "xop" if XOP else "avx" if AVX else {2: "sse2", 3: "ssse3", 4: "sse41"}[SSE]
void xnn_${DATATYPE.lower()}_ibilinear_ukernel__${ISA}_c${CHANNEL_TILE}${"" if PIXEL_TILE == 1 else "x%d" % PIXEL_TILE}(
    size_t output_pixels,
    size_t channels,
    const ${XINT8_T}**restrict input,
    size_t input_offset,
    const int16_t*restrict weights,
    ${XINT8_T}*restrict output,
    size_t output_increment) XNN_OOB_READS
{
  assert(output_pixels != 0);
  assert(channels != 0);

  do {
    const ${XINT8_T}* i0 = (const ${XINT8_T}*) ((uintptr_t) input[0] + input_offset);
    const ${XINT8_T}* i1 = (const ${XINT8_T}*) ((uintptr_t) input[1] + input_offset);
    const ${XINT8_T}* i2 = (const ${XINT8_T}*) ((uintptr_t) input[2] + input_offset);
    const ${XINT8_T}* i3 = (const ${XINT8_T}*) ((uintptr_t) input[3] + input_offset);
    input += 4;

    const __m128i valpha = _mm_cvtsi32_si128(*((const int*) weights));
    weights += 2;
    __m128i valphah = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(0, 0, 0, 0));
    valphah = _mm_unpacklo_epi64(valphah, valphah);
    $if SSE == 2:
      __m128i valphav = _mm_shufflelo_epi16(valpha, _MM_SHUFFLE(1, 1, 1, 1));
      valphav = _mm_unpacklo_epi64(valphav, valphav);
    $else:
      __m128i valphav = _mm_srli_epi32(valpha, 16);
      valphav = _mm_shuffle_epi32(valphav, _MM_SHUFFLE(0, 0, 0, 0));

    $if SSE == 4:
      valphah = _mm_blend_epi16(valphah, _mm_sub_epi16(_mm_set1_epi32(0x08000000), valphah), 0xAA);
    $else:
      valphah = _mm_xor_si128(valphah, _mm_set1_epi32(0xFFFF0000));
      valphah = _mm_add_epi16(valphah, _mm_set1_epi32(0x08010000));

    const __m128i vrounding = _mm_set1_epi32(0x00200000);

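    // The pixel's two int16_t weights are Q11 fixed-point blend factors
    // (0x0800 represents 1.0). valphah now holds {alpha_h, 0x0800 - alpha_h} in
    // each 32-bit lane, so a single _mm_madd_epi16 against interleaved
    // {right, left} texels performs the horizontal interpolation; valphav holds
    // the vertical factor in every 16-bit lane, and vrounding (1 << 21) is 0.5
    // in the Q22 accumulator that is shifted right by 22 at the end.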
    size_t c = channels;
    $if CHANNEL_TILE > 8:
      for (; c >= ${CHANNEL_TILE} * sizeof(${XINT8_T}); c -= ${CHANNEL_TILE} * sizeof(${XINT8_T})) {
        $if SSE == 4:
          const __m128i vtl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
          const __m128i vtr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
          const __m128i vbl${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
          const __m128i vbr${ABC[0:8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
          $for C in range(8, CHANNEL_TILE, 8):
            const __m128i vtl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i0 + ${C})));
            const __m128i vtr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i1 + ${C})));
            const __m128i vbl${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i2 + ${C})));
            const __m128i vbr${ABC[C:C+8]} = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) (i3 + ${C})));
        $else:
          __m128i vtl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i0);
          __m128i vtr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i1);
          __m128i vbl${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i2);
          __m128i vbr${ABC[0:8]} = _mm_loadl_epi64((const __m128i*) i3);
          $for C in range(8, CHANNEL_TILE, 8):
            __m128i vtl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i0 + ${C}));
            __m128i vtr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i1 + ${C}));
            __m128i vbl${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i2 + ${C}));
            __m128i vbr${ABC[C:C+8]} = _mm_loadl_epi64((const __m128i*) (i3 + ${C}));
        i0 += ${CHANNEL_TILE};
        i1 += ${CHANNEL_TILE};
        i2 += ${CHANNEL_TILE};
        i3 += ${CHANNEL_TILE};

        $if SSE != 4:
          $if DATATYPE == "U8":
            __m128i vzero = _mm_setzero_si128();
            $for C in range(0, CHANNEL_TILE, 8):
              vtl${ABC[C:C+8]} = _mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vzero);
              vtr${ABC[C:C+8]} = _mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vzero);
              vbl${ABC[C:C+8]} = _mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vzero);
              vbr${ABC[C:C+8]} = _mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vzero);
          $else:
            $for C in range(0, CHANNEL_TILE, 8):
              vtl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtl${ABC[C:C+8]}, vtl${ABC[C:C+8]}), 8);
              vtr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vtr${ABC[C:C+8]}, vtr${ABC[C:C+8]}), 8);
              vbl${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbl${ABC[C:C+8]}, vbl${ABC[C:C+8]}), 8);
              vbr${ABC[C:C+8]} = _mm_srai_epi16(_mm_unpacklo_epi8(vbr${ABC[C:C+8]}, vbr${ABC[C:C+8]}), 8);

        $for C in range(0, CHANNEL_TILE, 8):
          const __m128i vdr${ABC[C:C+8]} = _mm_sub_epi16(vbr${ABC[C:C+8]}, vtr${ABC[C:C+8]});
          const __m128i vt${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah);
          const __m128i vdl${ABC[C:C+8]} = _mm_sub_epi16(vbl${ABC[C:C+8]}, vtl${ABC[C:C+8]});
          const __m128i vt${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vtr${ABC[C:C+8]}, vtl${ABC[C:C+8]}), valphah);

        $for C in range(0, CHANNEL_TILE, 8):
          const __m128i vd${ABC[C:C+4]} = _mm_madd_epi16(_mm_unpacklo_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah);
          const __m128i vd${ABC[C+4:C+8]} = _mm_madd_epi16(_mm_unpackhi_epi16(vdr${ABC[C:C+8]}, vdl${ABC[C:C+8]}), valphah);

        $if SSE == 4:
          $for C in range(0, CHANNEL_TILE, 4):
            __m128i vacc${ABC[C:C+4]} = _mm_mullo_epi32(vd${ABC[C:C+4]}, valphav);
        $else:
          $for C in range(0, CHANNEL_TILE, 4):
            __m128i vacc${ABC[C:C+4]} = _mm_slli_epi32(_mm_mulhi_epu16(vd${ABC[C:C+4]}, valphav), 16);

          $for C in range(0, CHANNEL_TILE, 4):
            vacc${ABC[C:C+4]} = _mm_add_epi16(_mm_mullo_epi16(vd${ABC[C:C+4]}, valphav), vacc${ABC[C:C+4]});

        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = _mm_add_epi32(_mm_slli_epi32(vt${ABC[C:C+4]}, 11), vacc${ABC[C:C+4]});

        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc${ABC[C:C+4]}, vrounding), 22);

        $for C in range(0, CHANNEL_TILE, 8):
          const __m128i vacc${ABC[C:C+8]} = _mm_packs_epi32(vacc${ABC[C:C+4]}, vacc${ABC[C+4:C+8]});

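        // The 32-bit accumulators were narrowed to 16 bits with signed
        // saturation above; a second saturating pack narrows them to 8 bits,
        // clamping to the output range.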
        $for C in range(0, CHANNEL_TILE, 16):
          $if C + 8 < CHANNEL_TILE:
            const __m128i vo${ABC[C:C+16]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]});
          $else:
            const __m128i vo${ABC[C:C+8]} = ${_MM_PACKXS_EPI16}(vacc${ABC[C:C+8]}, vacc${ABC[C:C+8]});

        _mm_storeu_si128((__m128i*) output, vo${ABC[0:16]});
        $for C in range(16, CHANNEL_TILE, 16):
          $if C + 8 < CHANNEL_TILE:
            _mm_storeu_si128((__m128i*) (output + ${C}), vo${ABC[C:C+16]});
          $else:
            _mm_storel_epi64((__m128i*) (output + ${C}), vo${ABC[C:C+8]});
        output += ${CHANNEL_TILE};
      }
    for (; c >= 8 * sizeof(${XINT8_T}); c -= 8 * sizeof(${XINT8_T})) {
      $if SSE == 4:
        const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
        i0 += 8;
        const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
        i1 += 8;
        const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
        i2 += 8;
        const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
        i3 += 8;
      $else:
        __m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0);
        i0 += 8;
        __m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1);
        i1 += 8;
        __m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2);
        i2 += 8;
        __m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3);
        i3 += 8;

      $if SSE != 4:
        $if DATATYPE == "U8":
          __m128i vzero = _mm_setzero_si128();
          vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero);
          vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero);
          vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero);
          vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero);
        $else:
          vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8);
          vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8);
          vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8);
          vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8);

      const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567);
      const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah);
      const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567);
      const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah);

      const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah);
      const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah);

      $if SSE == 4:
        __m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav);
        __m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav);
      $else:
        __m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16);
        __m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16);

        vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
        vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);

      vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123);
      vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567);

      vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22);
      vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22);

      const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567);

      const __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567);

      _mm_storel_epi64((__m128i*) output, vo01234567);
      output += 8;
    }
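    // Handle the remaining 1-7 channels. The 64-bit loads may read past the
    // end of the row; this is permitted because the kernel is declared
    // XNN_OOB_READS.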
    if XNN_UNLIKELY(c != 0) {
      $if SSE == 4:
        const __m128i vtl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i0));
        const __m128i vtr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i1));
        const __m128i vbl01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i2));
        const __m128i vbr01234567 = ${_MM_CVTEPX8_EPI16}(_mm_loadl_epi64((const __m128i*) i3));
      $else:
        __m128i vtl01234567 = _mm_loadl_epi64((const __m128i*) i0);
        __m128i vtr01234567 = _mm_loadl_epi64((const __m128i*) i1);
        __m128i vbl01234567 = _mm_loadl_epi64((const __m128i*) i2);
        __m128i vbr01234567 = _mm_loadl_epi64((const __m128i*) i3);

      $if SSE != 4:
        $if DATATYPE == "U8":
          __m128i vzero = _mm_setzero_si128();
          vtl01234567 = _mm_unpacklo_epi8(vtl01234567, vzero);
          vtr01234567 = _mm_unpacklo_epi8(vtr01234567, vzero);
          vbl01234567 = _mm_unpacklo_epi8(vbl01234567, vzero);
          vbr01234567 = _mm_unpacklo_epi8(vbr01234567, vzero);
        $else:
          vtl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtl01234567, vtl01234567), 8);
          vtr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vtr01234567, vtr01234567), 8);
          vbl01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbl01234567, vbl01234567), 8);
          vbr01234567 = _mm_srai_epi16(_mm_unpacklo_epi8(vbr01234567, vbr01234567), 8);

      const __m128i vdr01234567 = _mm_sub_epi16(vbr01234567, vtr01234567);
      const __m128i vt0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vtr01234567, vtl01234567), valphah);
      const __m128i vdl01234567 = _mm_sub_epi16(vbl01234567, vtl01234567);
      const __m128i vt4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vtr01234567, vtl01234567), valphah);

      const __m128i vd0123 = _mm_madd_epi16(_mm_unpacklo_epi16(vdr01234567, vdl01234567), valphah);
      const __m128i vd4567 = _mm_madd_epi16(_mm_unpackhi_epi16(vdr01234567, vdl01234567), valphah);

      $if SSE == 4:
        __m128i vacc0123 = _mm_mullo_epi32(vd0123, valphav);
        __m128i vacc4567 = _mm_mullo_epi32(vd4567, valphav);
      $else:
        __m128i vacc0123 = _mm_slli_epi32(_mm_mulhi_epu16(vd0123, valphav), 16);
        __m128i vacc4567 = _mm_slli_epi32(_mm_mulhi_epu16(vd4567, valphav), 16);

        vacc0123 = _mm_add_epi16(_mm_mullo_epi16(vd0123, valphav), vacc0123);
        vacc4567 = _mm_add_epi16(_mm_mullo_epi16(vd4567, valphav), vacc4567);

      vacc0123 = _mm_add_epi32(_mm_slli_epi32(vt0123, 11), vacc0123);
      vacc4567 = _mm_add_epi32(_mm_slli_epi32(vt4567, 11), vacc4567);

      vacc0123 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc0123, vrounding), 22);
      vacc4567 = ${_MM_SRXI_EPI32}(_mm_add_epi16(vacc4567, vrounding), 22);

      const __m128i vacc01234567 = _mm_packs_epi32(vacc0123, vacc4567);

      __m128i vo01234567 = ${_MM_PACKXS_EPI16}(vacc01234567, vacc01234567);

      if (c & (4 * sizeof(${XINT8_T}))) {
        unaligned_store_u32(output, (uint32_t) _mm_cvtsi128_si32(vo01234567));
        output += 4;
        vo01234567 = _mm_srli_epi64(vo01234567, 32);
      }
      $if SSE == 4:
        if (c & (2 * sizeof(${XINT8_T}))) {
          unaligned_store_u16(output, (uint16_t) _mm_extract_epi16(vo01234567, 0));
          output += 2;
          vo01234567 = _mm_srli_epi32(vo01234567, 16);
        }
        if (c & (1 * sizeof(${XINT8_T}))) {
          *output++ = (uint8_t) _mm_extract_epi8(vo01234567, 0);
        }
      $else:
        uint32_t vo0123 = (uint32_t) _mm_cvtsi128_si32(vo01234567);
        if (c & (2 * sizeof(${XINT8_T}))) {
          unaligned_store_u16(output, (uint16_t) vo0123);
          output += 2;
          vo0123 >>= 16;
        }
        if (c & (1 * sizeof(${XINT8_T}))) {
          *output++ = (uint8_t) vo0123;
        }
    }

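    // Advance to the next output pixel; output_increment covers any gap
    // between the channels written above and the next pixel's output.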
    output = (${XINT8_T}*) ((uintptr_t) output + output_increment);
  } while (--output_pixels != 0);
}