xref: /aosp_15_r20/external/executorch/backends/cadence/hifi/kernels/kernels.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <inttypes.h>
12 #include <stddef.h>
13 #include <xa_type_def.h>
14 /* For NNLIB APIs */
15 #include "xa_nnlib_kernels_api.h"
16 
17 /* Potential NNLIB function/APIs */
18 extern "C" WORD32 xa_nn_elm_add_broadcast_4D_f32xf32_f32(
19     FLOAT32* __restrict__ p_out,
20     const WORD32* const p_out_shape,
21     const FLOAT32* __restrict__ p_inp1,
22     const WORD32* const p_inp1_shape,
23     const FLOAT32* __restrict__ p_inp2,
24     const WORD32* const p_inp2_shape);
25 
26 extern "C" WORD32 xa_nn_elm_div_broadcast_4D_f32xf32_f32(
27     FLOAT32* __restrict__ p_out,
28     const WORD32* const p_out_shape,
29     const FLOAT32* __restrict__ p_inp1,
30     const WORD32* const p_inp1_shape,
31     const FLOAT32* __restrict__ p_inp2,
32     const WORD32* const p_inp2_shape);
33 
34 extern "C" WORD32 xa_nn_elm_div_mode_f32xf32_f32(
35     FLOAT32* __restrict__ p_out,
36     const FLOAT32* __restrict__ p_inp1,
37     const FLOAT32* __restrict__ p_inp2,
38     WORD32 num_elm,
39     WORD32 mode);
40 
41 extern "C" WORD32 xa_nn_elm_div_mode_broadcast_4D_f32xf32_f32(
42     FLOAT32* __restrict__ p_out,
43     const WORD32* const p_out_shape,
44     const FLOAT32* __restrict__ p_inp1,
45     const WORD32* const p_inp1_shape,
46     const FLOAT32* __restrict__ p_inp2,
47     const WORD32* const p_inp2_shape,
48     WORD32 mode);
49 
50 extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32(
51     FLOAT32* __restrict__ p_out,
52     const WORD32* const p_out_shape,
53     const FLOAT32* __restrict__ p_inp1,
54     const WORD32* const p_inp1_shape,
55     const FLOAT32* __restrict__ p_inp2,
56     const WORD32* const p_inp2_shape);
57 
58 extern "C" WORD32 xa_nn_elm_where_f32xf32_f32(
59     FLOAT32* __restrict__ p_out,
60     const FLOAT32* __restrict__ p_inp1,
61     const FLOAT32* __restrict__ p_inp2,
62     const unsigned char* __restrict__ p_condition,
63     WORD32 num_elm);
64 
65 extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32(
66     FLOAT32* __restrict__ p_out,
67     const WORD32* const p_out_shape,
68     const FLOAT32* __restrict__ p_inp1,
69     const WORD32* const p_inp1_shape,
70     const FLOAT32* __restrict__ p_inp2,
71     const WORD32* const p_inp2_shape,
72     const unsigned char* __restrict__ p_condition,
73     const WORD32* const p_condition_shape);
74 
75 extern "C" WORD32 xa_nn_reduce_mean_4D_f32_f32(
76     FLOAT32* __restrict__ p_out,
77     const WORD32* const p_out_shape,
78     const FLOAT32* __restrict__ p_inp,
79     const WORD32* const p_inp_shape,
80     const WORD32* __restrict__ p_axis,
81     WORD32 num_out_dims,
82     WORD32 num_inp_dims,
83     WORD32 num_axis_dims,
84     void* __restrict__ p_scratch_in);
85 
86 namespace cadence {
87 namespace impl {
88 namespace HiFi {
89 namespace kernels {
90 
91 void memcpy(void* dst, const void* src, size_t num_bytes);
92 
93 WORD32 matmul_asym8uxasym8u_asym8u(
94     UWORD8* __restrict__ p_out, // output uint8 matrix
95     const UWORD8* __restrict__ p_mat1, // weight uint8 matrix
96     const UWORD8* __restrict__ p_vec1, // input uint8 matrix
97     const WORD32* __restrict__ p_bias, // bias int32 vec
98     WORD32 rows, // rows of p_mat1
99     WORD32 cols1, // columns of p_mat1
100     WORD32 row_stride1, // row stride of p_mat1
101     WORD32 vec_count, // rows of p_mat2
102     WORD32 vec_offset, // vec_offset of p_mat2.
103     WORD32 out_offset, // out_offset, i.e., offset of next output element
104     WORD32 out_stride, // out_stride, i.e., stride to go to next output row
105     WORD32 mat1_zero_bias, // zero_point of p_mat1
106     WORD32 vec1_zero_bias, // zero_point of p_vec1
107     const WORD32* __restrict__ out_multiplier,
108     const WORD32* __restrict__ out_shift,
109     WORD32 out_zero_bias,
110     bool per_channel_quantized = false); // per-channel quantized weight
111 
112 WORD32 xa_nn_matmul_asym8uxasym8u_asym8u(
113     UWORD8* __restrict__ p_out,
114     const UWORD8* __restrict__ p_mat1,
115     const UWORD8* __restrict__ p_mat2,
116     const WORD32* __restrict__ p_bias,
117     WORD32 rows,
118     WORD32 cols,
119     WORD32 row_stride,
120     WORD32 vec_count,
121     WORD32 vec_offset,
122     WORD32 out_offset,
123     WORD32 out_stride,
124     WORD32 mat1_zero_bias,
125     WORD32 vec1_zero_bias,
126     WORD32 out_multiplier,
127     WORD32 out_shift,
128     WORD32 out_zero_bias);
129 
130 template <typename T>
131 T quantize(const float x, float scale, int32_t zero_point);
132 
133 template <typename T>
134 float dequantize(const T x, float scale, int32_t zero_point);
135 
136 template <typename T>
137 void quantize(
138     T* __restrict__ y,
139     const float* __restrict__ x,
140     float scale,
141     int32_t zero_point,
142     size_t size);
143 
144 // Deuantize an int8_t/uint8_t/int16_t array to an fp32 array
145 template <typename T>
146 void dequantize(
147     float* __restrict__ y,
148     const T* __restrict__ x,
149     float scale,
150     int32_t zero_point,
151     size_t size);
152 
153 }; // namespace kernels
154 }; // namespace HiFi
155 }; // namespace impl
156 }; // namespace cadence
157