// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert CHANNEL_TILE % 8 == 0
$assert CHANNEL_TILE >= 8
$assert ROW_TILE >= 3
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/gavgpool.h>
#include <xnnpack/intrinsics-polyfill.h>


void xnn_f16_gavgpool_minmax_ukernel_${ROW_TILE}x__f16c_c${CHANNEL_TILE}(
    size_t rows,
    size_t channels,
    const void* input,
    size_t input_stride,
    const void* zero,
    void* output,
    const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(rows != 0);
  assert(rows <= ${ROW_TILE});
  assert(channels != 0);

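  // Set up one input pointer per row. Rows beyond the valid row count are
  // redirected to the zero buffer, so they add nothing to the running sums.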
  const uint16_t* i0 = input;
  $for M in range(1, ROW_TILE):
    const uint16_t* i${M} = (const uint16_t*) ((uintptr_t) i${M-1} + input_stride);
    $if M % 2 == 1:
      if XNN_UNPREDICTABLE(rows < ${M+1}) {
        i${M} = (const uint16_t*) zero;
      }
    $else:
      if XNN_UNPREDICTABLE(rows <= ${M}) {
        i${M} = (const uint16_t*) zero;
      }
  uint16_t* o = (uint16_t*) output;

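  // Load the averaging scale and the output clamping bounds from the params.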
  const __m256 vscale = _mm256_load_ps(params->avx.scale);
  const __m256 vmin = _mm256_load_ps(params->avx.min);
  const __m256 vmax = _mm256_load_ps(params->avx.max);
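  // Main loop: process full tiles of ${CHANNEL_TILE} channels per iteration.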
  for (; channels >= ${CHANNEL_TILE}; channels -= ${CHANNEL_TILE}) {
    $for M in range(2):
      const __m256 vi${M}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M}));
      $for C in range(8, CHANNEL_TILE, 8):
        const __m256 vi${M}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i${M} + ${C})));
      i${M} += ${CHANNEL_TILE};

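    // Initialize the accumulators with the sum of rows 0 and 1, interleaving
    // the additions with the loads of row 2.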
    $for C in range(0, CHANNEL_TILE, 8):
      $if C == 0:
        const __m256 vi2x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
      $else:
        const __m256 vi2x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i2 + ${C})));
      __m128i vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_add_ps(vi0x${ABC[C:C+8]}, vi1x${ABC[C:C+8]}), _MM_FROUND_NO_EXC);
    i2 += ${CHANNEL_TILE};

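    // Accumulate the remaining rows. The accumulators are kept in half
    // precision and round-tripped through single precision for each addition,
    // with the loads for the next row interleaved with the current additions.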
    $for M in range(2, ROW_TILE):
      $for C in range(0, CHANNEL_TILE, 8):
        $if M + 1 != ROW_TILE:
          $if C == 0:
            const __m256 vi${M+1}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1}));
          $else:
            const __m256 vi${M+1}x${ABC[C:C+8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i${M+1} + ${C})));
          $if C + 8 == CHANNEL_TILE:
            i${M+1} += ${CHANNEL_TILE};
        vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vi${M}x${ABC[C:C+8]}), _MM_FROUND_NO_EXC);

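    // Scale the accumulated sums to produce the average.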
    $for C in range(0, CHANNEL_TILE, 8):
      vacc${ABC[C:C+8]} = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vscale), _MM_FROUND_NO_EXC);

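    // Clamp the scaled results to the [min, max] output range.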
    $for C in range(0, CHANNEL_TILE, 8):
      __m256 vout${ABC[C:C+8]} = _mm256_max_ps(_mm256_cvtph_ps(vacc${ABC[C:C+8]}), vmin);

    $for C in range(0, CHANNEL_TILE, 8):
      vout${ABC[C:C+8]} = _mm256_min_ps(vout${ABC[C:C+8]}, vmax);

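    // Convert back to half precision and store ${CHANNEL_TILE} outputs.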
    _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC));
    $for C in range(8, CHANNEL_TILE, 8):
      _mm_storeu_si128((__m128i*) (o + ${C}), _mm256_cvtps_ph(vout${ABC[C:C+8]}, _MM_FROUND_NO_EXC));
    o += ${CHANNEL_TILE};
  }
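  // Remainder: handle the final group of fewer than ${CHANNEL_TILE} channels.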
  if XNN_UNLIKELY(channels != 0) {
    ${"do " if CHANNEL_TILE > 8 else ""}{
      $for M in range(2):
        const __m256 vi${M}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M}));
        $if CHANNEL_TILE > 8:
          i${M} += 8;

      const __m256 vi2x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
      __m128i vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_add_ps(vi0x${ABC[0:8]}, vi1x${ABC[0:8]}), _MM_FROUND_NO_EXC);
      $if CHANNEL_TILE > 8:
        i2 += 8;

      $for M in range(2, ROW_TILE):
        $if M + 1 != ROW_TILE:
          const __m256 vi${M+1}x${ABC[0:8]} = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i${M+1}));
          $if CHANNEL_TILE > 8:
            i${M+1} += 8;
        vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vi${M}x${ABC[0:8]}), _MM_FROUND_NO_EXC);

      vacc${ABC[0:8]} = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vscale), _MM_FROUND_NO_EXC);
      __m256 vout${ABC[0:8]} = _mm256_max_ps(_mm256_cvtph_ps(vacc${ABC[0:8]}), vmin);
      vout${ABC[0:8]} = _mm256_min_ps(vout${ABC[0:8]}, vmax);

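      // Store the remaining outputs; the final 1-7 channels use progressively
      // narrower stores (4, 2, and 1 elements).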
      $if CHANNEL_TILE > 8:
        if XNN_LIKELY(channels >= 8) {
          _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC));
          o += 8;
          channels -= 8;
        } else {
          __m128i vh${ABC[0:8]} = _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC);
          if (channels & 4) {
            _mm_storel_epi64((__m128i*) o, vh${ABC[0:8]});
            o += 4;
            vh${ABC[0:8]} = _mm_unpackhi_epi64(vh${ABC[0:8]}, vh${ABC[0:8]});
          }
          if (channels & 2) {
            _mm_storeu_si32(o, vh${ABC[0:8]});
            o += 2;
            vh${ABC[0:8]} = _mm_srli_epi64(vh${ABC[0:8]}, 32);
          }
          if (channels & 1) {
            *o = (uint16_t) _mm_extract_epi16(vh${ABC[0:8]}, 0);
          }
          channels = 0;
        }
      $else:
        __m128i vh${ABC[0:8]} = _mm256_cvtps_ph(vout${ABC[0:8]}, _MM_FROUND_NO_EXC);
        if (channels & 4) {
          _mm_storel_epi64((__m128i*) o, vh${ABC[0:8]});
          o += 4;
          vh${ABC[0:8]} = _mm_unpackhi_epi64(vh${ABC[0:8]}, vh${ABC[0:8]});
        }
        if (channels & 2) {
          _mm_storeu_si32(o, vh${ABC[0:8]});
          o += 2;
          vh${ABC[0:8]} = _mm_srli_epi64(vh${ABC[0:8]}, 32);
        }
        if (channels & 1) {
          *o = (uint16_t) _mm_extract_epi16(vh${ABC[0:8]}, 0);
        }
    }${" while (channels != 0);" if CHANNEL_TILE > 8 else ""}
  }
}