xref: /aosp_15_r20/external/XNNPACK/src/s16-window/scalar.c.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1// Copyright 2022 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$assert BATCH_TILE >= 1
7#include <assert.h>
8#include <stddef.h>
9#include <stdint.h>
10
11#include <xnnpack/math.h>
12#include <xnnpack/window.h>
13
14
// Scalar "window" micro-kernel template: multiplies int16 input elements by
// int16 weights, shifts the 32-bit product right by `shift`, clamps the result
// to [INT16_MIN, INT16_MAX] and stores it as int16.
//
// Lines starting with `$` and `${...}` placeholders are template directives
// (presumably expanded by the build-time generator before compilation);
// BATCH_TILE is the number of elements handled per unrolled loop iteration.
//
// Parameters:
//   rows       - number of rows to process; must be non-zero because the outer
//                do/while runs its body before testing the counter.
//   batch_size - number of int16 elements (not bytes) per row; must be non-zero.
//   input      - rows * batch_size elements read sequentially; the pointer is
//                never rewound between rows.
//   weights    - batch_size elements; the same weights are re-read from the
//                start for every row.
//   output     - rows * batch_size elements written sequentially.
//   shift      - right-shift amount applied to each product; asserted < 32.
//
// The product of two int16 values widened to int32 cannot overflow int32.
// NOTE(review): math_asr_s32 is presumably an arithmetic (sign-preserving)
// shift right, and math_max_s32/math_min_s32 the usual two-operand max/min -
// confirm against xnnpack/math.h.
15void xnn_s16_window_ukernel__scalar_x${BATCH_TILE}(
16    size_t rows,
17    size_t batch_size,
18    const int16_t* input,
19    const int16_t* weights,
20    int16_t* output,
21    uint32_t shift)
22{
23  assert(rows > 0);
24  assert(batch_size != 0);
25  assert(input != NULL);
26  assert(weights != NULL);
27  assert(output != NULL);
28  assert(shift < 32);
29
30  do {
31    size_t n = batch_size;  // elements still to process in the current row
32    const int16_t* w = weights;  // restart from the first weight for every row
33    $if BATCH_TILE > 1:
34      for (; n >= ${BATCH_TILE}; n -= ${BATCH_TILE}) {  // unrolled path: BATCH_TILE elements per iteration
35        $for N in range(BATCH_TILE):
36          const int16_t vi${N} = input[${N}];
37        input += ${BATCH_TILE};
38
39        $for N in range(BATCH_TILE):
40          const int16_t w${N} = w[${N}];
41        w += ${BATCH_TILE};
42
43        $for N in range(BATCH_TILE):
44          int32_t vout${N} = (int32_t) vi${N} * (int32_t) w${N};
45
46        $for N in range(BATCH_TILE):
47          vout${N} = math_asr_s32(vout${N}, shift);
48
49        $for N in range(BATCH_TILE):
50          vout${N} = math_max_s32(vout${N}, INT16_MIN);
51
52        $for N in range(BATCH_TILE):
53          vout${N} = math_min_s32(vout${N}, INT16_MAX);
54
55        $for N in range(BATCH_TILE):
56          output[${N}] = (int16_t) vout${N};
57
58        output += ${BATCH_TILE};
59      }
60
61    if XNN_UNLIKELY(n != 0) {  // elements not consumed by the unrolled loop (all of them when BATCH_TILE is 1)
62      do {
63        const int32_t vi = (int32_t) *input++;
64        const int32_t vw = (int32_t) *w++;
65        int32_t vout = vi * vw;
66        vout = math_asr_s32(vout, shift);
67        vout = math_max_s32(vout, INT16_MIN);
68        vout = math_min_s32(vout, INT16_MAX);
69        *output++ = (int16_t) vout;  // value is already clamped to the int16 range
70      } while (--n != 0);
71    }
72  } while (--rows != 0);  // input/output keep advancing, so the next row follows contiguously
73}
74
74