// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR % 4 == 0
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert MINMAX in ["MINMAX", "PMINMAX"]
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/spmm.h>


$WASM_F32X4_MIN={"MINMAX": "wasm_f32x4_min", "PMINMAX": "wasm_f32x4_pmin"}[MINMAX]
$WASM_F32X4_MAX={"MINMAX": "wasm_f32x4_max", "PMINMAX": "wasm_f32x4_pmax"}[MINMAX]
$ARCH_SUFFIX = "_x86" if MINMAX == "PMINMAX" else "_arm"
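// Sparse-matrix times dense-matrix (SpMM) microkernel for f32 with min/max clamping,
// templated for WAsm SIMD. Parameter semantics below are inferred from how the values
// are used in this kernel:
//   mc            - size of the dense (batch/pixel) dimension, in bytes (multiple of sizeof(float))
//   nc            - number of output channels
//   input         - dense input values
//   weights       - per-channel weight stream; the first entry per channel is loaded
//                   unconditionally (presumably the channel bias), followed by the
//                   channel's non-zero weights
//   widx_dmap     - byte-offset deltas applied to the input pointer between non-zero weights
//   nidx_nnzmap   - number of non-zero weights per output channel
//   output_stride - byte stride between consecutive output channels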
void xnn_f32_spmm_minmax_ukernel_${MR}x${NR}__wasmsimd${ARCH_SUFFIX}${"_x" + str(UNROLL) if UNROLL > 1 else ""}(
    size_t mc,
    size_t nc,
    const float*restrict input,
    const float*restrict weights,
    const int32_t*restrict widx_dmap,
    const uint32_t*restrict nidx_nnzmap,
    float*restrict output,
    size_t output_stride,
    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(mc != 0);
  assert(mc % sizeof(float) == 0);
  assert(nc != 0);

  const v128_t vmin = wasm_v128_load64_splat(params->wasmsimd.min);
  const v128_t vmax = wasm_v128_load64_splat(params->wasmsimd.max);
  size_t output_decrement = output_stride * nc - ${MR} * sizeof(float);
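  // Main loop: process the dense dimension in blocks of ${MR} floats (note that mc counts bytes).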
  while XNN_LIKELY(mc >= ${MR} * sizeof(float)) {
    const float*restrict w = weights;
    const int32_t* dmap = widx_dmap;
    const uint32_t* nnzmap = nidx_nnzmap;
    size_t n = nc;
    do {
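      // nnz is this output channel's non-zero count; the initial accumulator load picks up
      // the channel's leading weights entry (presumably the bias).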
      uint32_t nnz = *nnzmap++;
      $if UNROLL > 1:
        v128_t vacc0123x0 = wasm_v128_load32_splat(w);
        w += 1;
        $for K in range(1, UNROLL):
          v128_t vacc0123x${K} = wasm_f32x4_const_splat(0.0f);
        $for M in range(4, MR, 4):
          v128_t vacc${ABC[M:M+4]}x0 = vacc0123x0;
          $for K in range(1, UNROLL):
            v128_t vacc${ABC[M:M+4]}x${K} = wasm_f32x4_const_splat(0.0f);
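        // Unrolled inner loop: keep ${UNROLL} independent accumulator sets so the
        // multiply-add chains for consecutive non-zero weights do not serialize.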
        for (; nnz >= ${UNROLL}; nnz -= ${UNROLL}) {
          $for K in range(UNROLL):
            const intptr_t diff${K} = dmap[${K}];
          dmap += ${UNROLL};
          $for K in range(UNROLL):
            const v128_t vi0123x${K} = wasm_v128_load(input);
            $for M in range(4, MR, 4):
              const v128_t vi${ABC[M:M+4]}x${K} = wasm_v128_load(input + ${M});
            input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff${K});
            const v128_t vw${K} = wasm_v128_load32_splat(w);
            w += 1;
            $for M in range(0, MR, 4):
              vacc${ABC[M:M+4]}x${K} = wasm_f32x4_add(vacc${ABC[M:M+4]}x${K}, wasm_f32x4_mul(vi${ABC[M:M+4]}x${K}, vw${K}));
        }
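        // Sum the ${UNROLL} partial accumulators into a single accumulator per 4-element group.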
        $for M in range(0, MR, 4):
          v128_t vacc${ABC[M:M+4]} = vacc${ABC[M:M+4]}x0;
        $for K in range(1, UNROLL):
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, vacc${ABC[M:M+4]}x${K});
      $else:
        v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1;
        $for M in range(4, MR, 4):
          v128_t vacc${ABC[M:M+4]} = vacc0123;
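      // Process the remaining non-zero weights (all of them when the kernel is not unrolled) one at a time.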
      if XNN_LIKELY(nnz != 0) {
        do {
          const intptr_t diff = *dmap++;
          const v128_t vi0123 = wasm_v128_load(input);
          $for M in range(4, MR, 4):
            const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
          input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
          const v128_t vw = wasm_v128_load32_splat(w); w += 1;
          $for M in range(0, MR, 4):
            vacc${ABC[M:M+4]} = wasm_f32x4_add(vacc${ABC[M:M+4]}, wasm_f32x4_mul(vi${ABC[M:M+4]}, vw));
        } while (--nnz != 0);
      }
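      // Clamp the accumulated values to [min, max] before storing them.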
      $for M in range(0, MR, 4):
        v128_t vout${ABC[M:M+4]} = ${WASM_F32X4_MIN}(vmax, vacc${ABC[M:M+4]});
      $for M in range(0, MR, 4):
        vout${ABC[M:M+4]} = ${WASM_F32X4_MAX}(vmin, vout${ABC[M:M+4]});
      wasm_v128_store(output, vout0123);
      $for M in range(4, MR, 4):
        wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
      output = (float*restrict) ((uintptr_t) output + output_stride);
    } while (--n != 0);
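    // Rewind output to the start of the next ${MR}-element block (output_decrement undoes the
    // nc strides taken above) and advance the dense input by ${MR} floats.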
    output = (float*restrict) ((uintptr_t) output - output_decrement);
    input += ${MR};
    mc -= ${MR} * sizeof(float);
  }
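  // Tail: handle a remainder of fewer than ${MR} elements along the dense dimension using
  // progressively smaller power-of-two block sizes (SUBMR) generated below.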
  if XNN_UNLIKELY(mc != 0) {
    $for LOG2M in reversed(range((MR - 1).bit_length())):
      $SUBMR = 1 << LOG2M
      $if SUBMR * 2 >= MR:
        output_decrement += ${MR - SUBMR} * sizeof(float);
      $else:
        output_decrement += ${SUBMR} * sizeof(float);
      if (mc & (${SUBMR} * sizeof(float))) {
        const float*restrict w = weights;
        const int32_t* dmap = widx_dmap;
        const uint32_t* nnzmap = nidx_nnzmap;
        size_t n = nc;
        do {
          uint32_t nnz = *nnzmap++;
          $if SUBMR == 1:
            v128_t vacc0 = wasm_v128_load32_splat(w); w += 1;
          $elif SUBMR == 2:
            v128_t vacc01 = wasm_v128_load32_splat(w); w += 1;
          $else:
            v128_t vacc0123 = wasm_v128_load32_splat(w); w += 1;
          $for M in range(4, SUBMR, 4):
            v128_t vacc${ABC[M:M+4]} = vacc0123;
          if XNN_LIKELY(nnz != 0) {
            do {
              const intptr_t diff = *dmap++;
              $if SUBMR >= 4:
                const v128_t vi0123 = wasm_v128_load(input);
              $elif SUBMR == 2:
                const v128_t vi01 = wasm_v128_load64_splat(input);
              $elif SUBMR == 1:
                const v128_t vi0 = wasm_v128_load32_splat(input);
              $for M in range(4, SUBMR, 4):
                const v128_t vi${ABC[M:M+4]} = wasm_v128_load(input + ${M});
              input = (const float*restrict) ((uintptr_t) input + (uintptr_t) diff);
              const v128_t vw = wasm_v128_load32_splat(w); w += 1;
              $if SUBMR == 1:
                vacc${ABC[0]} = wasm_f32x4_add(vacc${ABC[0]}, wasm_f32x4_mul(vi${ABC[0]}, vw));
              $else:
                $for M in range(0, SUBMR, 4):
                  vacc${ABC[M:min(M+4,SUBMR)]} = wasm_f32x4_add(vacc${ABC[M:min(M+4,SUBMR)]}, wasm_f32x4_mul(vi${ABC[M:min(M+4,SUBMR)]}, vw));
            } while (--nnz != 0);
          }
          $if SUBMR == 1:
            v128_t vout${ABC[0]} = ${WASM_F32X4_MIN}(vmax, vacc${ABC[0]});
            vout${ABC[0]} = ${WASM_F32X4_MAX}(vmin, vout${ABC[0]});
          $else:
            $for M in range(0, SUBMR, 4):
              v128_t vout${ABC[M:min(M+4,SUBMR)]} = ${WASM_F32X4_MIN}(vmax, vacc${ABC[M:min(M+4,SUBMR)]});
            $for M in range(0, SUBMR, 4):
              vout${ABC[M:min(M+4,SUBMR)]} = ${WASM_F32X4_MAX}(vmin, vout${ABC[M:min(M+4,SUBMR)]});
          $if SUBMR >= 4:
            wasm_v128_store(output, vout0123);
          $elif SUBMR == 2:
            *((double*) output) = wasm_f64x2_extract_lane(vout01, 0);
          $elif SUBMR == 1:
            *output = wasm_f32x4_extract_lane(vout0, 0);

          $for M in range(4, SUBMR, 4):
            wasm_v128_store(output + ${M}, vout${ABC[M:M+4]});
          output = (float*restrict) ((uintptr_t) output + output_stride);
        } while (--n != 0);
        output = (float*restrict) ((uintptr_t) output - output_decrement);
        input += ${SUBMR};
      }
  }
}