1// Copyright 2022 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert CHANNEL_TILE % 8 == 0 7$assert CHANNEL_TILE >= 8 8$assert ROW_TILE >= 3 9$assert ROW_SUBTILE >= 3 10$assert ROW_SUBTILE <= ROW_TILE 11$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 12#include <assert.h> 13 14#include <immintrin.h> 15 16#include <xnnpack/gavgpool.h> 17#include <xnnpack/intrinsics-polyfill.h> 18#include <xnnpack/math.h> 19 20 21void xnn_f16_gavgpool_minmax_ukernel_${ROW_TILE}p${ROW_SUBTILE}x__f16c_c${CHANNEL_TILE}( 22 size_t rows, 23 size_t channels, 24 const void* input, 25 size_t input_stride, 26 const void* zero, 27 void* buffer, 28 void* output, 29 const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 30{ 31 assert(rows > ${ROW_TILE}); 32 assert(channels != 0); 33 34 const uint16_t* i0 = input; 35 $for M in range(1, ROW_TILE): 36 const uint16_t* i${M} = (const uint16_t*) ((uintptr_t) i${M-1} + input_stride); 37 const size_t input_increment = ${ROW_TILE} * input_stride - round_up_po2(channels, 8) * sizeof(uint16_t); 38 39 uint16_t* b = buffer; 40 size_t c = channels; 41 for (; ${"c >= %d" % CHANNEL_TILE if CHANNEL_TILE > 8 else "c != 0"}; ${("c -= %d" if CHANNEL_TILE > 8 else "c = doz(c, %d)") % CHANNEL_TILE}) { 42 $for M in range(2): 43 $for C in range(0, CHANNEL_TILE, 8): 44 const __m256 vi${M}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M})); i${M} += 8; 45 46 $for C in range(0, CHANNEL_TILE, 8): 47 const __m256 vi2x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2)); i2 += 8; 48 __m128i vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_add_ps(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]}), _MM_FROUND_NO_EXC); 49 50 $for M in range(2, ROW_TILE): 51 $for C in range(0, CHANNEL_TILE, 8): 52 $if M + 1 != ROW_TILE: 53 const __m256 vi${M+1}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1})); i${M+1} += 8; 54 vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vi${M}x${ABC[C:C+8]}), _MM_FROUND_NO_EXC); 55 56 $for C in range(0, CHANNEL_TILE, 8): 57 _mm_store_si128((__m128i*) b, vacc${ABC[C:C+8]}); b += 8; 58 } 59 $if CHANNEL_TILE > 8: 60 if XNN_UNLIKELY(c != 0) { 61 do { 62 $for M in range(3): 63 const __m256 vi${M}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M})); i${M} += 8; 64 __m128i vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_add_ps(vi0x${ABC[0:8]}, vi1x${ABC[0:8]}), _MM_FROUND_NO_EXC); 65 66 $for M in range(2, ROW_TILE): 67 $if M + 1 != ROW_TILE: 68 const __m256 vi${M+1}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1})); i${M+1} += 8; 69 vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vi${M}x${ABC[0:8]}), _MM_FROUND_NO_EXC); 70 71 _mm_store_si128((__m128i*) b, vacc${ABC[0:8]}); b += 8; 72 73 c = doz(c, 8); 74 } while (c != 0); 75 } 76 77 for (rows -= ${ROW_TILE}; rows > ${ROW_SUBTILE}; rows -= ${ROW_SUBTILE}) { 78 $for M in range(ROW_SUBTILE): 79 i${M} = (const uint16_t*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment); 80 81 uint16_t* b = buffer; 82 size_t c = channels; 83 for (; ${"c >= %d" % CHANNEL_TILE if CHANNEL_TILE > 8 else "c != 0"}; ${("c -= %d" if CHANNEL_TILE > 8 else "c = doz(c, %d)") % CHANNEL_TILE}) { 84 __m128i vacc${ABC[0:8]} = _mm_loadu_si128((const __m128i*) b); 85 $for C in range(8, CHANNEL_TILE, 8): 86 __m128i vacc${ABC[C:C+8]} = _mm_loadu_si128((const __m128i*) (b + ${C})); 87 88 $for C in range(0, CHANNEL_TILE, 8): 89 const __m256 vi0x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); i0 += 8; 90 91 $for M in range(ROW_TILE): 92 $for C in range(0, CHANNEL_TILE, 8): 93 $if M + 1 != ROW_TILE: 94 const __m256 vi${M+1}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1})); i${M+1} += 8; 95 vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vi${M}x${ABC[C:C+8]}), _MM_FROUND_NO_EXC); 96 97 $for C in range(0, CHANNEL_TILE, 8): 98 _mm_store_si128((__m128i*) b, vacc${ABC[C:C+8]}); b += 8; 99 } 100 $if CHANNEL_TILE > 8: 101 if XNN_UNLIKELY(c != 0) { 102 do { 103 __m128i vacc${ABC[0:8]} = _mm_loadu_si128((const __m128i*) b); 104 const __m256 vi0x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); i0 += 8; 105 106 $for M in range(ROW_TILE): 107 $if M + 1 != ROW_TILE: 108 const __m256 vi${M+1}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1})); i${M+1} += 8; 109 vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vi${M}x${ABC[0:8]}), _MM_FROUND_NO_EXC); 110 111 _mm_store_si128((__m128i*) b, vacc${ABC[0:8]}); 112 b += 8; 113 114 c = doz(c, 8); 115 } while (c != 0); 116 } 117 } 118 119 i0 = (const uint16_t*) ((uintptr_t) i${ROW_TILE - ROW_SUBTILE} + input_increment); 120 $for M in range(1, ROW_SUBTILE): 121 i${M} = (const uint16_t*) ((uintptr_t) i${M + ROW_TILE - ROW_SUBTILE} + input_increment); 122 $if M % 2 == 1: 123 if XNN_UNPREDICTABLE(rows < ${M+1}) { 124 i${M} = (const uint16_t*) zero; 125 } 126 $else: 127 if XNN_UNPREDICTABLE(rows <= ${M}) { 128 i${M} = (const uint16_t*) zero; 129 } 130 uint16_t* o = (uint16_t*) output; 131 132 const __m256 vscale = _mm256_load_ps(params->avx.scale); 133 const __m256 vmin = _mm256_load_ps(params->avx.min); 134 const __m256 vmax = _mm256_load_ps(params->avx.max); 135 for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) { 136 $for C in range(0, CHANNEL_TILE, 8): 137 __m128i vacc${ABC[C:C+8]} = _mm_loadu_si128((const __m128i*) buffer); buffer = (uint16_t*) buffer + 8; 138 139 $for C in range(0, CHANNEL_TILE, 8): 140 const __m256 vi0x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); i0 += 8; 141 142 $for M in range(ROW_TILE): 143 $for C in range(0, CHANNEL_TILE, 8): 144 $if M + 1 != ROW_TILE: 145 const __m256 vi${M+1}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1})); i${M+1} += 8; 146 vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vi${M}x${ABC[C:C+8]}), _MM_FROUND_NO_EXC); 147 148 $for C in range(0, CHANNEL_TILE, 8): 149 vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vscale), _MM_FROUND_NO_EXC); 150 151 $for C in range(0, CHANNEL_TILE, 8): 152 __m256 vout${ABC[C:C+8]} = _mm256_max_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vmin); 153 154 $for C in range(0, CHANNEL_TILE, 8): 155 vout${ABC[C:C+8]} = _mm256_min_ps(vout${ABC[C:C+8]}, vmax); 156 157 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC)); 158 $for C in range(8, CHANNEL_TILE, 8): 159 _mm_storeu_si128((__m128i*) ((uint16_t*) o + ${C}), _mm256_cvtps_ph(vout${ABC[C:C+8]}, _MM_FROUND_NO_EXC)); 160 o += ${CHANNEL_TILE}; 161 } 162 if XNN_UNLIKELY(channels != 0) { 163 ${"do " if CHANNEL_TILE > 8 else ""}{ 164 __m128i vacc${ABC[0:8]} = _mm_loadu_si128((const __m128i*) buffer); buffer = (uint16_t*) buffer + 8; 165 166 const __m256 vi0x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); i0 += 8; 167 $for M in range(ROW_TILE): 168 $if M + 1 != ROW_TILE: 169 const __m256 vi${M+1}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1})); i${M+1} += 8; 170 vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vi${M}x${ABC[0:8]}), _MM_FROUND_NO_EXC); 171 172 vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vscale), _MM_FROUND_NO_EXC); 173 __m256 vout${ABC[0:8]} = _mm256_max_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vmin); 174 vout${ABC[0:8]} = _mm256_min_ps(vout${ABC[0:8]}, vmax); 175 176 $if CHANNEL_TILE > 8: 177 if XNN_LIKELY(channels >= 8) { 178 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC)); 179 o += 8; 180 channels -= 8; 181 } else { 182 __m128i vh${ABC[0:8]} = _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC); 183 if (channels & 4) { 184 _mm_storel_epi64((__m128i*) o, vh${ABC[0:8]}); 185 o += 4; 186 vh${ABC[0:8]} = _mm_unpackhi_epi64(vh${ABC[0:8]}, vh${ABC[0:8]}); 187 } 188 if (channels & 2) { 189 _mm_storeu_si32(o, vh${ABC[0:8]}); 190 o += 2; 191 vh${ABC[0:8]} = _mm_srli_epi64(vh${ABC[0:8]}, 32); 192 } 193 if (channels & 1) { 194 *o = (uint16_t) _mm_extract_epi16(vh${ABC[0:8]}, 0); 195 } 196 channels = 0; 197 } 198 $else: 199 __m128i vh${ABC[0:8]} = _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC); 200 if (channels & 4) { 201 _mm_storel_epi64((__m128i*) o, vh${ABC[0:8]}); 202 o += 4; 203 vh${ABC[0:8]} = _mm_unpackhi_epi64(vh${ABC[0:8]}, vh${ABC[0:8]}); 204 } 205 if (channels & 2) { 206 _mm_storeu_si32(o, vh${ABC[0:8]}); 207 o += 2; 208 vh${ABC[0:8]} = _mm_srli_epi64(vh${ABC[0:8]}, 32); 209 } 210 if (channels & 1) { 211 *o = (uint16_t) _mm_extract_epi16(vh${ABC[0:8]}, 0); 212 } 213 }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""} 214 } 215} 216