xref: /aosp_15_r20/external/XNNPACK/src/f32-dwconv/up-avx512.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6$assert CHANNEL_TILE % 16 == 0
7$assert KERNEL_TILE >= 2
8$assert ACCUMULATORS >= 1
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
10#include <assert.h>
11
12#include <immintrin.h>
13
14#include <xnnpack/dwconv.h>
15#include <xnnpack/intrinsics-polyfill.h>
16
17
18void xnn_f32_dwconv_minmax_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx512f${"" if ACCUMULATORS == 1 else "_acc%d" % ACCUMULATORS}(
19    size_t channels,
20    size_t output_width,
21    const float** input,
22    const float* weights,
23    float* output,
24    size_t input_stride,
25    size_t output_increment,
26    size_t input_offset,
27    const float* zero,
28    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
29{
30  assert(channels != 0);
31  assert(output_width != 0);
32
33  const __m512 vmax = _mm512_set1_ps(params->scalar.max);
34  const __m512 vmin = _mm512_set1_ps(params->scalar.min);
35  do {
36    $for K in range(KERNEL_TILE):
37      const float* i${K} = input[${K}];
38      assert(i${K} != NULL);
39      if XNN_UNPREDICTABLE(i${K} != zero) {
40        i${K} = (const float*) ((uintptr_t) i${K} + input_offset);
41      }
42    input = (const float**) ((uintptr_t) input + input_stride);
43
44    size_t c = channels;
45    const float* w = weights;
46    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
47      __m512 vacc${ABC[0:16]}p0 = _mm512_load_ps(w);
48      $for C in range(16, CHANNEL_TILE, 16):
49        __m512 vacc${ABC[C:C+16]}p0 = _mm512_load_ps(w + ${C});
50
51      $for K in range(KERNEL_TILE):
52
53        const __m512 vi${K}x${ABC[0:16]} = _mm512_loadu_ps(i${K});
54        $for C in range(16, CHANNEL_TILE, 16):
55          const __m512 vi${K}x${ABC[C:C+16]} = _mm512_loadu_ps(i${K} + ${C});
56        i${K} += ${CHANNEL_TILE};
57
58        $for C in range(0, CHANNEL_TILE, 16):
59          const __m512 vk${K}x${ABC[C:C+16]} = _mm512_load_ps(w + ${(K + 1) * CHANNEL_TILE + C});
60        $for C in range(0, CHANNEL_TILE, 16):
61          $if 1 <= K < ACCUMULATORS:
62            __m512 vacc${ABC[C:C+16]}p${K} = _mm512_mul_ps(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
63          $else:
64            vacc${ABC[C:C+16]}p${K % ACCUMULATORS} = _mm512_fmadd_ps(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]}, vacc${ABC[C:C+16]}p${K % ACCUMULATORS});
65
66      w += ${(KERNEL_TILE + 1) * CHANNEL_TILE};
67
68      $if ACCUMULATORS > 1:
69        // Add up all accumulators to vacc${ABC[0:CHANNEL_TILE]}p0
70        $ACC_SLICE = 1
71        $while ACC_SLICE < ACCUMULATORS:
72          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
73            $if A + ACC_SLICE < ACCUMULATORS:
74              $for C in range(0, CHANNEL_TILE, 16):
75                vacc${ABC[C:C+16]}p${A} = _mm512_add_ps(vacc${ABC[C:C+16]}p${A}, vacc${ABC[C:C+16]}p${A + ACC_SLICE});
76          $ACC_SLICE *= 2
77
78      $for C in range(0, CHANNEL_TILE, 16):
79        __m512 vacc${ABC[C:C+16]} = _mm512_max_ps(vacc${ABC[C:C+16]}p0, vmin);
80      $for C in range(0, CHANNEL_TILE, 16):
81        vacc${ABC[C:C+16]} = _mm512_min_ps(vacc${ABC[C:C+16]}, vmax);
82
83      _mm512_storeu_ps(output, vacc${ABC[0:16]});
84      $for C in range(16, CHANNEL_TILE, 16):
85        _mm512_storeu_ps(output + ${C}, vacc${ABC[C:C+16]});
86      output += ${CHANNEL_TILE};
87    }
88    $if CHANNEL_TILE > 16:
89      for (; c >= 16; c -= 16) {
90        __m512 vacc${ABC[0:16]}p0 = _mm512_load_ps(w);
91        $for K in range(KERNEL_TILE):
92
93          const __m512 vi${K}x${ABC[0:16]} = _mm512_loadu_ps(i${K});
94          i${K} += 16;
95
96          const __m512 vk${K}x${ABC[0:16]} = _mm512_load_ps(w + ${(K + 1) * CHANNEL_TILE});
97          $if 1 <= K < ACCUMULATORS:
98            __m512 vacc${ABC[0:16]}p${K} = _mm512_mul_ps(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]});
99          $else:
100            vacc${ABC[0:16]}p${K % ACCUMULATORS} = _mm512_fmadd_ps(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]}, vacc${ABC[0:16]}p${K % ACCUMULATORS});
101
102        w += 16;
103
104        $if ACCUMULATORS > 1:
105          // Add up all accumulators to vacc${ABC[0:16]}p0
106          $ACC_SLICE = 1
107          $while ACC_SLICE < ACCUMULATORS:
108            $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
109              $if A + ACC_SLICE < ACCUMULATORS:
110                vacc${ABC[0:16]}p${A} = _mm512_add_ps(vacc${ABC[0:16]}p${A}, vacc${ABC[0:16]}p${A + ACC_SLICE});
111            $ACC_SLICE *= 2
112
113        __m512 vacc${ABC[0:16]} = _mm512_max_ps(vacc${ABC[0:16]}p0, vmin);
114        vacc${ABC[0:16]} = _mm512_min_ps(vacc${ABC[0:16]}, vmax);
115
116        _mm512_storeu_ps(output, vacc${ABC[0:16]});
117        output += 16;
118      }
119    if XNN_UNLIKELY(c != 0) {
120      assert(c >= 1);
121      assert(c <= 16);
122      // Prepare mask for valid 32-bit elements (depends on nc).
123      const __mmask16 vmask = _cvtu32_mask16((uint16_t) ((uint32_t) (UINT32_C(1) << c) - UINT32_C(1)));
124
125      __m512 vacc${ABC[0:16]}p0 = _mm512_maskz_loadu_ps(vmask, w);
126      $for K in range(KERNEL_TILE):
127
128        const __m512 vi${K}x${ABC[0:16]} = _mm512_maskz_loadu_ps(vmask, i${K});
129        const __m512 vk${K}x${ABC[0:16]} = _mm512_maskz_loadu_ps(vmask, w + ${(K + 1) * CHANNEL_TILE});
130        $if 1 <= K < ACCUMULATORS:
131          __m512 vacc${ABC[0:16]}p${K} = _mm512_mul_ps(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]});
132        $else:
133          vacc${ABC[0:16]}p${K % ACCUMULATORS} = _mm512_fmadd_ps(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]}, vacc${ABC[0:16]}p${K % ACCUMULATORS});
134
135      $if ACCUMULATORS > 1:
136        // Add up all accumulators to vacc${ABC[0:16]}p0
137        $ACC_SLICE = 1
138        $while ACC_SLICE < ACCUMULATORS:
139          $for A in range(0, ACCUMULATORS, ACC_SLICE * 2):
140            $if A + ACC_SLICE < ACCUMULATORS:
141              vacc${ABC[0:16]}p${A} = _mm512_add_ps(vacc${ABC[0:16]}p${A}, vacc${ABC[0:16]}p${A + ACC_SLICE});
142          $ACC_SLICE *= 2
143
144      __m512 vacc${ABC[0:16]} = _mm512_max_ps(vacc${ABC[0:16]}p0, vmin);
145      vacc${ABC[0:16]} = _mm512_min_ps(vacc${ABC[0:16]}, vmax);
146
147      _mm512_mask_storeu_ps(output, vmask, vacc${ABC[0:16]});
148      output += c;
149    }
150
151    output = (float*) ((uintptr_t) output + output_increment);
152  } while (--output_width != 0);
153}
154