// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert BATCH_TILE % 8 == 0
$assert BATCH_TILE >= 8
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/vunary.h>

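// LeakyReLU (f32) micro-kernel for AVX: y[i] = x[i] for lanes with a clear sign bit,
// and y[i] = x[i] * slope for lanes with the sign bit set (negative inputs). n is the
// batch size in bytes; BATCH_TILE floats are processed per main-loop iteration.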
void xnn_f32_vlrelu_ukernel__avx_x${BATCH_TILE}(
    size_t n,
    const float* x,
    float* y,
    const union xnn_f32_lrelu_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);

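  // params->avx.slope holds the slope value replicated across all 8 lanes, so a single
  // aligned load yields the vector constant.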
  const __m256 vslope = _mm256_load_ps(params->avx.slope);
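  // Main loop: process BATCH_TILE floats per iteration.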
  for (; n >= ${BATCH_TILE} * sizeof(float); n -= ${BATCH_TILE} * sizeof(float)) {
    const __m256 vx${ABC[0:8]} = _mm256_loadu_ps(x);
    $for N in range(8, BATCH_TILE, 8):
      const __m256 vx${ABC[N:N+8]} = _mm256_loadu_ps(x + ${N});
    x += ${BATCH_TILE};

    $for N in range(0, BATCH_TILE, 8):
      __m256 vacc${ABC[N:N+8]} = _mm256_mul_ps(vx${ABC[N:N+8]}, vslope);

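    // _mm256_blendv_ps selects the scaled value (x * slope) for lanes whose sign bit is
    // set and passes the original x through for non-negative lanes.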
    $for N in range(0, BATCH_TILE, 8):
      vacc${ABC[N:N+8]} = _mm256_blendv_ps(vx${ABC[N:N+8]}, vacc${ABC[N:N+8]}, vx${ABC[N:N+8]});

    _mm256_storeu_ps(y, vacc${ABC[0:8]});
    $for N in range(8, BATCH_TILE, 8):
      _mm256_storeu_ps(y + ${N}, vacc${ABC[N:N+8]});
    y += ${BATCH_TILE};
  }
  $if BATCH_TILE > 8:
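    // Remainder loop for kernels with BATCH_TILE > 8: handle any full 8-element vectors
    // left over after the main loop.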
    for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
      const __m256 vx = _mm256_loadu_ps(x);
      x += 8;
      __m256 vacc = _mm256_mul_ps(vx, vslope);
      vacc = _mm256_blendv_ps(vx, vacc, vx);
      _mm256_storeu_ps(y, vacc);
      y += 8;
    }
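  // Tail: 1 to 7 floats remain. Build a lane mask by offsetting into mask_table and use
  // a masked load so that no elements past the end of x are accessed.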
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 7 * sizeof(float));
    const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - n));

    const __m256 vx = _mm256_maskload_ps(x, vmask);
    __m256 vacc = _mm256_mul_ps(vx, vslope);
    vacc = _mm256_blendv_ps(vx, vacc, vx);

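    // Store the 1-7 results using progressively narrower stores: 4 floats, then 2, then 1.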
    __m128 vacc_lo = _mm256_castps256_ps128(vacc);
    if (n & (4 * sizeof(float))) {
      _mm_storeu_ps(y, vacc_lo);
      vacc_lo = _mm256_extractf128_ps(vacc, 1);
      y += 4;
    }
    if (n & (2 * sizeof(float))) {
      _mm_storel_pi((__m64*) y, vacc_lo);
      vacc_lo = _mm_movehl_ps(vacc_lo, vacc_lo);
      y += 2;
    }
    if (n & (1 * sizeof(float))) {
      _mm_store_ss(y, vacc_lo);
    }
  }
}
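
// Note: this template is expanded by XNNPACK's generator scripts; BATCH_TILE values of
// 8 and 16 yield xnn_f32_vlrelu_ukernel__avx_x8 and xnn_f32_vlrelu_ukernel__avx_x16.
// A minimal calling sketch (the init-helper name below is an assumption, not verified
// against this revision):
//
//   union xnn_f32_lrelu_params params;
//   xnn_init_f32_lrelu_avx_params(&params, /*slope=*/0.01f);  // assumed helper
//   xnn_f32_vlrelu_ukernel__avx_x8(batch * sizeof(float), input, output, &params);
//
// Note that the first argument is the batch size in bytes, not in elements.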