// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// NOTE(review): this is an xngen C template (not Nim, despite the extension
// hint).  Lines starting with `$` are generator directives; the rest is the
// emitted C.  The template produces SSE2 tile-transpose micro-kernels for
// 8/16/32/64-bit elements.  One 128-bit vector holds TILE_SIZE = 128/SIZE
// elements, so a TILE_SIZE x TILE_SIZE tile is transposed with
// NUM_ITERS = log2(TILE_SIZE) unpack (butterfly) stages.
//
// Template parameters:
//   IN_PTRS  - "MULTI": keep TILE_SIZE separate input-row pointers;
//              "REUSE": walk a single input pointer row by row.
//   OUT_PTRS - "MULTI":  TILE_SIZE output-row pointers, clamped for narrow
//                        blocks;
//              "SWITCH": one output pointer plus a fall-through switch over
//                        the remaining column count;
//              "MOV":    one output pointer walked backwards across columns
//                        with conditional pointer moves.
$import math
$assert IN_PTRS in ["MULTI", "REUSE"]
$assert OUT_PTRS in ["MULTI", "SWITCH", "MOV"]
$assert SIZE in [8, 16, 32, 64]
$TILE_SIZE = int(128/SIZE)
$NUM_ITERS = int(math.log2(TILE_SIZE))

#include <immintrin.h>

#include <assert.h>

#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/transpose.h>
#include <xnnpack/unaligned.h>


// Transposes a block_height x block_width matrix of uint${SIZE}_t elements.
// input_stride / output_stride are row strides in bytes.  XNN_OOB_READS
// documents that the kernel may read (never write) past the logical block.
void xnn_x${SIZE}_transposec_ukernel__${TILE_SIZE}x${TILE_SIZE}_${IN_PTRS.lower()}_${OUT_PTRS.lower()}_sse2(
    const uint${SIZE}_t* input,
    uint${SIZE}_t* output,
    size_t input_stride,
    size_t output_stride,
    size_t block_width,
    size_t block_height) XNN_OOB_READS
{
  assert(output_stride >= block_height * sizeof(uint${SIZE}_t));
  assert(input_stride >= block_width * sizeof(uint${SIZE}_t));

  const size_t tile_height = ${TILE_SIZE};
  const size_t tile_width = ${TILE_SIZE};
  const size_t tile_hbytes = tile_height * sizeof(uint${SIZE}_t);
  const size_t tile_wbytes = tile_width * sizeof(uint${SIZE}_t);
  // Pointer rewind applied after a full column of tiles has been processed:
  // step one tile to the right, back up over the rows consumed.
  const size_t input_reset = tile_wbytes - round_down_po2(block_height, tile_height) * input_stride;
  $if IN_PTRS == "MULTI":
    const size_t input_offset = tile_height * input_stride;
  $if OUT_PTRS == "MOV":
    // "MOV" keeps `o` pre-decremented by tile_hbytes (see its init below),
    // so the reset compensates for that extra bias.
    const size_t output_reset = tile_width * output_stride - round_down_po2(block_height, 2) * sizeof(uint${SIZE}_t) - tile_hbytes;
  $else:
    const size_t output_reset = tile_width * output_stride - round_down_po2(block_height, 2) * sizeof(uint${SIZE}_t);

  // Input row pointers.
  $if IN_PTRS == "MULTI":
    const uint${SIZE}_t* i0 = input;
    $for N in range(1, TILE_SIZE):
      const uint${SIZE}_t* i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
  $else:
    const uint${SIZE}_t* i0 = input;
  // Output pointer(s), per the OUT_PTRS strategy.
  $if OUT_PTRS == "MULTI":
    uint${SIZE}_t* o0 = (uint${SIZE}_t*) output;
    $for N in range(1, TILE_SIZE):
      uint${SIZE}_t* o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N-1} + output_stride);
  $elif OUT_PTRS == "SWITCH":
    uint${SIZE}_t* o = (uint${SIZE}_t*) output;
  $else:
    uint${SIZE}_t* o = (uint${SIZE}_t*) ((uintptr_t) output - tile_hbytes);
  // Negative stride used to walk output rows backwards.  For SWITCH with
  // SIZE == 64 (TILE_SIZE == 2) the switch has no fall-through cases, so the
  // variable would be unused and is not emitted.
  $if OUT_PTRS == "MOV":
    const size_t minus_output_stride = -output_stride;
  $elif OUT_PTRS == "SWITCH" and SIZE != 64:
    const size_t minus_output_stride = -output_stride;

  // Outer loop: one iteration per column group of up to tile_width columns.
  do {
    $if OUT_PTRS == "MULTI":
      // Clamp output-row pointers so rows beyond block_width alias o0
      // (harmlessly overwritten in order).
      if XNN_UNPREDICTABLE(block_width < 2) {
        o1 = o0;
      }
      $for N in range(2, TILE_SIZE, 2):
        if XNN_UNPREDICTABLE(block_width <= ${N}) {
          o${N} = o0;
        }
        if XNN_UNPREDICTABLE(block_width < ${N+2}) {
          o${N+1} = o0;
        }
    $elif OUT_PTRS == "MOV":
      // rem = index of the last valid output row for this column group.
      const size_t rem = min(block_width - 1, ${TILE_SIZE-1});
      const size_t oN_stride = rem * output_stride;
      const size_t oN_offset = oN_stride + tile_hbytes;
    $elif OUT_PTRS == "SWITCH":
      const size_t rem = min(block_width - 1, ${TILE_SIZE-1});
      const size_t oN_stride = rem * output_stride;
    size_t bh = block_height;
    // Main loop: full TILE_SIZE x TILE_SIZE tiles.
    for (; bh >= ${TILE_SIZE}; bh -= ${TILE_SIZE}) {
      // Load one vector per input row.
      $for N in range(TILE_SIZE):
        $if IN_PTRS == "REUSE":
          const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i0);
          i0 = (uint${SIZE}_t*) ((uintptr_t) i0 + input_stride);
        $else:
          const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
          i${N} = (uint${SIZE}_t*) ((uintptr_t) i${N} + input_offset);

      // Butterfly stage 1: interleave adjacent rows at element width.
      $for N in range(TILE_SIZE >> 1):
        const __m128i v${NUM_ITERS-1}_${N*2} = _mm_unpacklo_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});
        const __m128i v${NUM_ITERS-1}_${N*2+1} = _mm_unpackhi_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});

      // Butterfly stage 2: interleave at double element width.
      $if NUM_ITERS>=2:
        $for N in range(0, TILE_SIZE, 4):
          const __m128i v${NUM_ITERS-2}_${N} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
          const __m128i v${NUM_ITERS-2}_${N+1} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
          const __m128i v${NUM_ITERS-2}_${N+2} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});
          const __m128i v${NUM_ITERS-2}_${N+3} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});

      // Butterfly stage 3.
      $if NUM_ITERS>=3:
        $for M in range(0, TILE_SIZE, 8):
          $for N in range(0, 4):
            const __m128i v${NUM_ITERS-3}_${M+2*N} = _mm_unpacklo_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});
            const __m128i v${NUM_ITERS-3}_${M+2*N+1} = _mm_unpackhi_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});

      // Butterfly stage 4 (8-bit elements only): 64-bit interleave.
      $if NUM_ITERS>=4:
        $for N in range(TILE_SIZE >> 1):
          const __m128i v0_${N*2} = _mm_unpacklo_epi64(v1_${N}, v1_${N+8});
          const __m128i v0_${N*2+1} = _mm_unpackhi_epi64(v1_${N}, v1_${N+8});

      // Store the transposed tile; v0_${"{N}"} now holds output row N.
      $if OUT_PTRS == "SWITCH":
        // Fall-through switch (Duff's-device style): start at the last valid
        // output row and fall through down to row 0.
        uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
        switch (rem) {
          $for N in reversed(range(2, TILE_SIZE)):
            case ${N}:
              _mm_storeu_si128((__m128i*) oN, v0_${N});
              oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
          case 1:
            _mm_storeu_si128((__m128i*) oN, v0_1);
          case 0:
            _mm_storeu_si128((__m128i*) o, v0_0);
            o = (uint${SIZE}_t*) ((uintptr_t) o + tile_hbytes);
            break;
          default:
            XNN_UNREACHABLE;
        }
      $elif OUT_PTRS == "MOV":
        // Walk from the last valid row back to row 0, advancing `o` only
        // when the corresponding row exists (branch-predictable compares).
        o = (uint${SIZE}_t*) ((uintptr_t) o + oN_offset);
        _mm_storeu_si128((__m128i*) o, v0_${TILE_SIZE-1});
        uint${SIZE}_t *oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
        $for N in reversed(range(2, TILE_SIZE, 2)):
          if XNN_UNPREDICTABLE(block_width > ${N+1}) {
            o = oN;
          }
          _mm_storeu_si128((__m128i*) o, v0_${N});
          oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
          if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
            o = oN;
          }
          _mm_storeu_si128((__m128i*) o, v0_${N-1});
          oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
        if XNN_UNPREDICTABLE(block_width > 1) {
          o = oN;
        }
        _mm_storeu_si128((__m128i*) o, v0_0);
      $else:
        $for N in reversed(range(TILE_SIZE)):
          _mm_storeu_si128((__m128i*) o${N}, v0_${N});
          o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N} + tile_hbytes);
    }
    $if OUT_PTRS == "MOV":
      o = (uint${SIZE}_t*) ((uintptr_t) o + tile_hbytes);

    // Tail: fewer than TILE_SIZE rows remain (bh in [1, TILE_SIZE-1]).
    if (bh != 0) {
      $if IN_PTRS == "REUSE":
        const __m128i v${NUM_ITERS}_0 = _mm_loadu_si128((const __m128i*) i0);
        // Derive per-row pointers from i0, clamping rows past bh to the
        // previous row so loads stay in bounds.
        $for N in range(1, TILE_SIZE - 1, 2):
          const uint${SIZE}_t *i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
          if XNN_UNPREDICTABLE(bh < ${N+1}) {
            i${N} = i${N-1};
          }
          const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
          const uint${SIZE}_t *i${N+1} = (const uint${SIZE}_t*) ((uintptr_t) i${N} + input_stride);
          if XNN_UNPREDICTABLE(bh <= ${N+1}) {
            i${N+1} = i${N};
          }
          const __m128i v${NUM_ITERS}_${N+1} = _mm_loadu_si128((const __m128i*) i${N+1});
      $else:
        const __m128i v${NUM_ITERS}_0 = _mm_loadu_si128((const __m128i*) i0);
        // Clamp existing row pointers past bh to i0 to keep loads in bounds.
        $for N in range(1, TILE_SIZE - 1, 2):
          if XNN_UNPREDICTABLE(bh < ${N+1}) {
            i${N} = i0;
          }
          const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
          if XNN_UNPREDICTABLE(bh <= ${N+1}) {
            i${N+1} = i0;
          }
          const __m128i v${NUM_ITERS}_${N+1} = _mm_loadu_si128((const __m128i*) i${N+1});
      // The last row never exists in the tail (bh < TILE_SIZE); a dummy
      // vector keeps the butterfly pattern uniform.
      const __m128i v${NUM_ITERS}_${TILE_SIZE-1} = _mm_undefined_si128();

      // Same butterfly stages as the main loop; the final-stage vectors are
      // non-const because the partial stores below shift data through them.
      $CONST = "const "
      $if NUM_ITERS == 1:
        $CONST = ""
      $for N in range(TILE_SIZE >> 1):
        ${CONST}__m128i v${NUM_ITERS-1}_${N*2} = _mm_unpacklo_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});
        ${CONST}__m128i v${NUM_ITERS-1}_${N*2+1} = _mm_unpackhi_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});

      $if NUM_ITERS == 2:
        $CONST = ""
      $if NUM_ITERS>=2:
        $for N in range(0, TILE_SIZE, 4):
          ${CONST}__m128i v${NUM_ITERS-2}_${N} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
          ${CONST}__m128i v${NUM_ITERS-2}_${N+1} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
          ${CONST}__m128i v${NUM_ITERS-2}_${N+2} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});
          ${CONST}__m128i v${NUM_ITERS-2}_${N+3} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});

      $if NUM_ITERS == 3:
        $CONST = ""
      $if NUM_ITERS>=3:
        $for M in range(0, TILE_SIZE, 8):
          $for N in range(0, 4):
            ${CONST}__m128i v${NUM_ITERS-3}_${M+2*N} = _mm_unpacklo_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});
            ${CONST}__m128i v${NUM_ITERS-3}_${M+2*N+1} = _mm_unpackhi_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});

      $if NUM_ITERS>=4:
        $for N in range(TILE_SIZE >> 1):
          __m128i v0_${N*2} = _mm_unpacklo_epi64(v1_${N}, v1_${N+8});
          __m128i v0_${N*2+1} = _mm_unpackhi_epi64(v1_${N}, v1_${N+8});

      // Progressive partial stores: write 8 bytes, then 4, then 2, then 1,
      // according to the bits of bh, shifting stored data out of the vectors
      // between steps.
      if (bh & ${TILE_SIZE>>1}) {
        $if OUT_PTRS == "SWITCH":
          uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
          switch (rem) {
            $for N in reversed(range(2, TILE_SIZE)):
              case ${N}:
                _mm_storel_epi64((__m128i*) oN, v0_${N});
                oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
            case 1:
              _mm_storel_epi64((__m128i*) oN, v0_1);
            case 0:
              _mm_storel_epi64((__m128i*) o, v0_0);
              break;
            default:
              XNN_UNREACHABLE;
          }
          $if NUM_ITERS > 1:
            o += ${TILE_SIZE>>1};
        $elif OUT_PTRS == "MOV":
          o = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
          _mm_storel_epi64((__m128i*) o, v0_${TILE_SIZE-1});
          uint${SIZE}_t *oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
          $for N in reversed(range(2, TILE_SIZE, 2)):
            if XNN_UNPREDICTABLE(block_width > ${N+1}) {
              o = oN;
            }
            _mm_storel_epi64((__m128i*) o, v0_${N});
            oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
              o = oN;
            }
            _mm_storel_epi64((__m128i*) o, v0_${N-1});
            oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
          if XNN_UNPREDICTABLE(block_width > 1) {
            o = oN;
          }
          _mm_storel_epi64((__m128i*) o, v0_0);
          $if NUM_ITERS > 1:
            o += ${TILE_SIZE>>1};
        $else:
          $for N in reversed(range(TILE_SIZE)):
            _mm_storel_epi64((__m128i*) o${N}, v0_${N});
            $if NUM_ITERS>1:
              o${N} += ${TILE_SIZE>>1};
        // Move the not-yet-stored high half down into the low half.
        $if NUM_ITERS > 1:
          $for N in range(TILE_SIZE):
            v0_${N} = _mm_unpackhi_epi64(v0_${N}, v0_${N});
      }

      $if NUM_ITERS>1:
        if (bh & ${TILE_SIZE>>2}) {
          $if OUT_PTRS == "SWITCH":
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            switch (rem) {
              $for N in reversed(range(2, TILE_SIZE)):
                case ${N}:
                  unaligned_store_u32(oN, (uint32_t) _mm_cvtsi128_si32(v0_${N}));
                  oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
              case 1:
                unaligned_store_u32(oN, (uint32_t) _mm_cvtsi128_si32(v0_1));
              case 0:
                unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_0));
                break;
              default:
                XNN_UNREACHABLE;
            }
            $if NUM_ITERS > 2:
              o += ${TILE_SIZE>>2};
          $elif OUT_PTRS == "MOV":
            o = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_${TILE_SIZE-1}));
            uint${SIZE}_t *oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            $for N in reversed(range(2, TILE_SIZE, 2)):
              if XNN_UNPREDICTABLE(block_width > ${N+1}) {
                o = oN;
              }
              unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_${N}));
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
              if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
                o = oN;
              }
              unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_${N-1}));
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            if XNN_UNPREDICTABLE(block_width > 1) {
              o = oN;
            }
            unaligned_store_u32(o, (uint32_t) _mm_cvtsi128_si32(v0_0));
            $if NUM_ITERS > 2:
              o += ${TILE_SIZE>>2};
          $else:
            $for N in reversed(range(TILE_SIZE)):
              unaligned_store_u32(o${N}, (uint32_t) _mm_cvtsi128_si32(v0_${N}));
              $if NUM_ITERS>2:
                o${N} += ${TILE_SIZE>>2};
          // Shift the stored 32 bits out of each vector.
          $if NUM_ITERS > 2:
            $for N in range(TILE_SIZE):
              v0_${N} = _mm_srli_epi64(v0_${N}, 32);
        }
      $if NUM_ITERS>2:
        if (bh & ${TILE_SIZE>>3}) {
          $if OUT_PTRS == "SWITCH":
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            switch (rem) {
              $for N in reversed(range(2, TILE_SIZE)):
                case ${N}:
                  unaligned_store_u16(oN, (uint16_t) _mm_cvtsi128_si32(v0_${N}));
                  oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
              case 1:
                unaligned_store_u16(oN, (uint16_t) _mm_cvtsi128_si32(v0_1));
              case 0:
                unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_0));
                break;
              default:
                XNN_UNREACHABLE;
            }
            $if NUM_ITERS>3:
              o += ${TILE_SIZE>>3};
          $elif OUT_PTRS == "MOV":
            o = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_${TILE_SIZE-1}));
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            $for N in reversed(range(2, TILE_SIZE, 2)):
              if XNN_UNPREDICTABLE(block_width > ${N+1}) {
                o = oN;
              }
              unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_${N}));
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
              if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
                o = oN;
              }
              unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_${N-1}));
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            if XNN_UNPREDICTABLE(block_width > 1) {
              o = oN;
            }
            unaligned_store_u16(o, (uint16_t) _mm_cvtsi128_si32(v0_0));
            $if NUM_ITERS > 3:
              o += ${TILE_SIZE>>3};
          $else:
            $for N in reversed(range(TILE_SIZE)):
              unaligned_store_u16(o${N}, (uint16_t) _mm_cvtsi128_si32(v0_${N}));
              $if NUM_ITERS>3:
                o${N} += ${TILE_SIZE>>3};
          // Shift the stored 16 bits out of each vector.
          $if NUM_ITERS>3:
            $for N in range(TILE_SIZE):
              v0_${N} = _mm_srli_epi32(v0_${N}, 16);
        }
      $if SIZE == 8:
        // Final single-byte row (8-bit elements only).
        if (bh & 1) {
          $if OUT_PTRS == "SWITCH":
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            switch (rem) {
              $for N in reversed(range(2, TILE_SIZE)):
                case ${N}:
                  *oN = (uint8_t) _mm_cvtsi128_si32(v0_${N});
                  oN = (uint${SIZE}_t*) ((uintptr_t) oN + minus_output_stride);
              case 1:
                *oN = (uint8_t) _mm_cvtsi128_si32(v0_1);
              case 0:
                *o = (uint8_t) _mm_cvtsi128_si32(v0_0);
                break;
              default:
                XNN_UNREACHABLE;
            }
          $elif OUT_PTRS == "MOV":
            o = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
            *o = (uint8_t) _mm_cvtsi128_si32(v0_${TILE_SIZE-1});
            uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            $for N in reversed(range(2, TILE_SIZE, 2)):
              if XNN_UNPREDICTABLE(block_width > ${N+1}) {
                o = oN;
              }
              *o = (uint8_t) _mm_cvtsi128_si32(v0_${N});
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
              if XNN_UNPREDICTABLE(block_width >= ${N+1}) {
                o = oN;
              }
              *o = (uint8_t) _mm_cvtsi128_si32(v0_${N-1});
              oN = (uint${SIZE}_t*) ((uintptr_t) o + minus_output_stride);
            if XNN_UNPREDICTABLE(block_width > 1) {
              o = oN;
            }
            *o = (uint8_t) _mm_cvtsi128_si32(v0_0);
        }
    }

    // Advance to the next column group of tiles.
    $if IN_PTRS == "MULTI":
      i0 = (const uint${SIZE}_t*) ((uintptr_t) i0 + input_reset);
      $for N in range(1, TILE_SIZE):
        i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
    $else:
      i0 = (const uint${SIZE}_t*) ((uintptr_t) i0 + input_reset);
    $if OUT_PTRS == "MULTI":
      o0 = (uint${SIZE}_t*) ((uintptr_t) o0 + output_reset);
      $for N in range(1, TILE_SIZE):
        o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N} + output_reset);
    $else:
      o = (uint${SIZE}_t*) ((uintptr_t) o + output_reset);
    // doz() is a saturating subtract — presumably "difference or zero" from
    // xnnpack/math.h; TODO confirm.
    block_width = doz(block_width, tile_width);
  } while (block_width != 0);
}