xref: /aosp_15_r20/external/XNNPACK/src/f32-gemm/scalar.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6$assert ACTIVATION in ["LINEAR", "RELU", "MINMAX"]
7$assert ACTIVATION != "LINEAR" or not WASM
8#include <assert.h>
9
10#include <xnnpack/gemm.h>
11#include <xnnpack/math.h>
12
13
14$MIN_F32 = "__builtin_wasm_min_f32" if WASM else "math_min_f32"
15$MAX_F32 = "__builtin_wasm_max_f32" if WASM else "math_max_f32"
16$KERNEL = "gemminc" if INC else "gemm"
17$SUFFIX = {"LINEAR": "", "RELU": "_relu", "MINMAX": "_minmax"}[ACTIVATION]
18$PARAMS = {"LINEAR": "xnn_f32_default_params", "RELU": "xnn_f32_relu_params", "MINMAX": "xnn_f32_minmax_params"}[ACTIVATION]
19void xnn_f32_${KERNEL}${SUFFIX}_ukernel_${MR}x${NR}__${"wasm" if WASM else "scalar"}(
20    size_t mr,
21    size_t nc,
22    size_t kc,
23    const float* restrict a,
24    size_t a_stride,
25    const float* restrict w,
26    float* restrict c,
27    size_t cm_stride,
28    size_t cn_stride,
29    $if INC:
30      const float*restrict acc,
31    const union ${PARAMS} params[restrict XNN_MIN_ELEMENTS(1)])
32{
33  assert(mr != 0);
34  assert(mr <= ${MR});
35  assert(nc != 0);
36  assert(kc != 0);
37  assert(kc % sizeof(float) == 0);
38  assert(a != NULL);
39  assert(w != NULL);
40  assert(c != NULL);
41  $if INC:
42    assert(acc != NULL);
43
44  const float* a0 = a;
45  float* c0 = c;
46  $for M in range(1, MR):
47    const float* a${M} = (const float*) ((uintptr_t) a${M-1} + a_stride);
48    float* c${M} = (float*) ((uintptr_t) c${M-1} + cm_stride);
49    $if M % 2 == 0:
50      if XNN_UNPREDICTABLE(mr <= ${M}) {
51        a${M} = a${M-1};
52        c${M} = c${M-1};
53      }
54    $elif M + 1 == MR:
55      if XNN_UNPREDICTABLE(mr != ${M+1}) {
56        a${M} = a${M-1};
57        c${M} = c${M-1};
58      }
59    $else:
60      if XNN_UNPREDICTABLE(mr < ${M+1}) {
61        a${M} = a${M-1};
62        c${M} = c${M-1};
63      }
64
65  $if ACTIVATION == "MINMAX":
66    const float vmin = params->scalar.min;
67    const float vmax = params->scalar.max;
68  do {
69    $if INC:
70      $for M in range(MR):
71        $for N in range(NR):
72          float vacc${M}${N} = acc[${M*NR+N}];
73      acc += ${MR*NR};
74    $else:
75      $for N in range(NR):
76        float vacc0${N} = w[${N}];
77      w += ${NR};
78      $for M in range(1, MR):
79        $for N in range(NR):
80          float vacc${M}${N} = vacc0${N};
81
82    size_t k = kc;
83    do {
84      $for M in range(MR):
85        const float va${M} = *a${M}++;
86
87      $for N in range(NR):
88        const float vb${N} = w[${N}];
89      w += ${NR};
90
91      $for M in range(MR):
92        $for N in range(NR):
93          vacc${M}${N} = math_muladd_f32(va${M}, vb${N}, vacc${M}${N});
94
95      k -= sizeof(float);
96    } while (k != 0);
97
98    $if ACTIVATION == "MINMAX":
99      $for M in range(MR):
100        $for N in range(NR):
101          vacc${M}${N} = ${MAX_F32}(vacc${M}${N}, vmin);
102
103      $for M in range(MR):
104        $for N in range(NR):
105          vacc${M}${N} = ${MIN_F32}(vacc${M}${N}, vmax);
106    $elif ACTIVATION == "RELU":
107      $for M in range(MR):
108        $for N in range(NR):
109          vacc${M}${N} = ${MAX_F32}(vacc${M}${N}, 0.0f);
110
111    if XNN_LIKELY(nc >= ${NR}) {
112      $for M in reversed(range(MR)):
113        $for N in range(NR):
114          c${M}[${N}] = vacc${M}${N};
115        c${M} = (float*) ((uintptr_t) c${M} + cn_stride);
116
117      $for M in reversed(range(MR)):
118        a${M} = (const void*) ((uintptr_t) a${M} - kc);
119
120      nc -= ${NR};
121    } else {
122      $for LOG2N in reversed(range(NR.bit_length() - 1)):
123        if (nc & ${1 << LOG2N}) {
124          $for M in reversed(range(MR)):
125            $for N in range(1 << LOG2N):
126              c${M}[${N}] = vacc${M}${N};
127            $if LOG2N != 0:
128              $for N in range(1 << (LOG2N - 1)):
129                vacc${M}${N} = vacc${M}${N + (1 << LOG2N)};
130              c${M} += ${1 << LOG2N};
131        }
132
133      nc = 0;
134    }
135  } while (nc != 0);
136}
137