// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert CHANNEL_TILE % 8 == 0
$assert CHANNEL_TILE >= 8
$assert ROW_TILE >= 3
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/intrinsics-polyfill.h>


void xnn_f16_gavgpool_minmax_ukernel_${ROW_TILE}x__f16c_c${CHANNEL_TILE}(
    size_t rows,
    size_t channels,
    const void* input,
    size_t input_stride,
    const void* zero,
    void* output,
    const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(rows != 0);
  assert(rows <= ${ROW_TILE});
  assert(channels != 0);

  // Set up the ${ROW_TILE} input row pointers; rows beyond the valid count read from the zero buffer.
  const uint16_t* i0 = input;
  $for M in range(1, ROW_TILE):
    const uint16_t* i${M} = (const uint16_t*) ((uintptr_t) i${M-1} + input_stride);
    $if M % 2 == 1:
      if XNN_UNPREDICTABLE(rows < ${M+1}) {
        i${M} = (const uint16_t*) zero;
      }
    $else:
      if XNN_UNPREDICTABLE(rows <= ${M}) {
        i${M} = (const uint16_t*) zero;
      }
  uint16_t* o = (uint16_t*) output;

  // Averaging scale and output clamping bounds (pre-broadcast to 8 lanes in the params struct).
  const __m256 vscale = _mm256_load_ps(params->avx.scale);
  const __m256 vmin = _mm256_load_ps(params->avx.min);
  const __m256 vmax = _mm256_load_ps(params->avx.max);
  // Main loop: accumulate ${ROW_TILE} input rows for ${CHANNEL_TILE} channels per iteration.
  for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) {
    $for M in range(2):
      const __m256 vi${M}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M}));
      $for C in range(8, CHANNEL_TILE, 8):
        const __m256 vi${M}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i${M} + ${C})));
      i${M} += ${CHANNEL_TILE};

    $for C in range(0, CHANNEL_TILE, 8):
      $if C == 0:
        const __m256 vi2x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
      $else:
        const __m256 vi2x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i2 + ${C})));
      __m128i vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_add_ps(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]}), _MM_FROUND_NO_EXC);
    i2 += ${CHANNEL_TILE};

    $for M in range(2, ROW_TILE):
      $for C in range(0, CHANNEL_TILE, 8):
        $if M + 1 != ROW_TILE:
          $if C == 0:
            const __m256 vi${M+1}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1}));
          $else:
            const __m256 vi${M+1}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i${M+1} + ${C})));
          $if C + 8 == CHANNEL_TILE:
            i${M+1} += ${CHANNEL_TILE};
        vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vi${M}x${ABC[C:C+8]}), _MM_FROUND_NO_EXC);

    // Scale the accumulated sums, then clamp the results to [min, max].
    $for C in range(0, CHANNEL_TILE, 8):
      vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vscale), _MM_FROUND_NO_EXC);

    $for C in range(0, CHANNEL_TILE, 8):
      __m256 vout${ABC[C:C+8]} = _mm256_max_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vmin);

    $for C in range(0, CHANNEL_TILE, 8):
      vout${ABC[C:C+8]} = _mm256_min_ps(vout${ABC[C:C+8]}, vmax);

    _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC));
    $for C in range(8, CHANNEL_TILE, 8):
      _mm_storeu_si128((__m128i*) (o + ${C}), _mm256_cvtps_ph(vout${ABC[C:C+8]}, _MM_FROUND_NO_EXC));
    o += ${CHANNEL_TILE};
  }
  // Remainder: process the leftover channels 8 at a time, then store the final 4/2/1 elements as needed.
  if XNN_UNLIKELY(channels != 0) {
    ${"do " if CHANNEL_TILE > 8 else ""}{
      $for M in range(2):
        const __m256 vi${M}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M}));
        $if CHANNEL_TILE > 8:
          i${M} += 8;

      const __m256 vi2x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
      __m128i vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_add_ps(vi0x${ABC[0:8]}, vi1x${ABC[0:8]}), _MM_FROUND_NO_EXC);
      $if CHANNEL_TILE > 8:
        i2 += 8;

      $for M in range(2, ROW_TILE):
        $if M + 1 != ROW_TILE:
          const __m256 vi${M+1}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1}));
          $if CHANNEL_TILE > 8:
            i${M+1} += 8;
        vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vi${M}x${ABC[0:8]}), _MM_FROUND_NO_EXC);

      vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vscale), _MM_FROUND_NO_EXC);
      __m256 vout${ABC[0:8]} = _mm256_max_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vmin);
      vout${ABC[0:8]} = _mm256_min_ps(vout${ABC[0:8]}, vmax);

      $if CHANNEL_TILE > 8:
        if XNN_LIKELY(channels >= 8) {
          _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC));
          o += 8;
          channels -= 8;
        } else {
          __m128i vh${ABC[0:8]} = _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC);
          if (channels & 4) {
            _mm_storel_epi64((__m128i*) o, vh${ABC[0:8]});
            o += 4;
            vh${ABC[0:8]} = _mm_unpackhi_epi64(vh${ABC[0:8]}, vh${ABC[0:8]});
          }
          if (channels & 2) {
            _mm_storeu_si32(o, vh${ABC[0:8]});
            o += 2;
            vh${ABC[0:8]} = _mm_srli_epi64(vh${ABC[0:8]}, 32);
          }
          if (channels & 1) {
            *o = (uint16_t) _mm_extract_epi16(vh${ABC[0:8]}, 0);
          }
          channels = 0;
        }
      $else:
        __m128i vh${ABC[0:8]} = _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC);
        if (channels & 4) {
          _mm_storel_epi64((__m128i*) o, vh${ABC[0:8]});
          o += 4;
          vh${ABC[0:8]} = _mm_unpackhi_epi64(vh${ABC[0:8]}, vh${ABC[0:8]});
        }
        if (channels & 2) {
          _mm_storeu_si32(o, vh${ABC[0:8]});
          o += 2;
          vh${ABC[0:8]} = _mm_srli_epi64(vh${ABC[0:8]}, 32);
        }
        if (channels & 1) {
          *o = (uint16_t) _mm_extract_epi16(vh${ABC[0:8]}, 0);
        }
    }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""}
  }
}