xref: /aosp_15_r20/external/libopus/dnn/nnet.h (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li /* Copyright (c) 2018 Mozilla
2*a58d3d2aSXin Li    Copyright (c) 2017 Jean-Marc Valin */
3*a58d3d2aSXin Li /*
4*a58d3d2aSXin Li    Redistribution and use in source and binary forms, with or without
5*a58d3d2aSXin Li    modification, are permitted provided that the following conditions
6*a58d3d2aSXin Li    are met:
7*a58d3d2aSXin Li 
8*a58d3d2aSXin Li    - Redistributions of source code must retain the above copyright
9*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer.
10*a58d3d2aSXin Li 
11*a58d3d2aSXin Li    - Redistributions in binary form must reproduce the above copyright
12*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer in the
13*a58d3d2aSXin Li    documentation and/or other materials provided with the distribution.
14*a58d3d2aSXin Li 
15*a58d3d2aSXin Li    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
16*a58d3d2aSXin Li    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
17*a58d3d2aSXin Li    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
18*a58d3d2aSXin Li    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
19*a58d3d2aSXin Li    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20*a58d3d2aSXin Li    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21*a58d3d2aSXin Li    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22*a58d3d2aSXin Li    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
23*a58d3d2aSXin Li    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
24*a58d3d2aSXin Li    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25*a58d3d2aSXin Li    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26*a58d3d2aSXin Li */
27*a58d3d2aSXin Li 
28*a58d3d2aSXin Li #ifndef NNET_H_
29*a58d3d2aSXin Li #define NNET_H_
30*a58d3d2aSXin Li 
31*a58d3d2aSXin Li #include <stddef.h>
32*a58d3d2aSXin Li #include "opus_types.h"
33*a58d3d2aSXin Li 
/* Activation-function selectors. By naming, these are the values intended for
   the `int activation` parameter of the compute_* functions declared in this
   header (e.g. compute_generic_dense(), compute_activation_c()) -- confirm the
   full mapping in nnet.c before adding new ones. */
#define ACTIVATION_LINEAR  0
#define ACTIVATION_SIGMOID 1
#define ACTIVATION_TANH    2
#define ACTIVATION_RELU    3
#define ACTIVATION_SOFTMAX 4
#define ACTIVATION_SWISH   5
40*a58d3d2aSXin Li 
/* Weight-blob format constants. WEIGHT_BLOCK_SIZE (64) equals the size of the
   WeightHead record when int is 4 bytes (4 + 4*4 + 44 = 64).
   NOTE(review): presumably the serialized blob is laid out in blocks of this
   size; confirm against parse_weights(). */
#define WEIGHT_BLOB_VERSION 0
#define WEIGHT_BLOCK_SIZE 64
/* One named weight array: an identifier, a WEIGHT_TYPE_* tag, a size and an
   untyped pointer to the contents. The model tables declared further down
   (lpcnet_arrays[] and friends) are arrays of this record; linear_init() and
   conv2d_init() take such a table plus field names, so entries are presumably
   looked up by `name`. */
typedef struct {
  const char *name;   /* array identifier */
  int type, size;     /* WEIGHT_TYPE_* tag; size (units presumably bytes -- confirm) */
  const void *data;   /* array contents, interpreted according to `type` */
} WeightArray;
49*a58d3d2aSXin Li 
/* Element-type tags for WeightArray.type (and presumably WeightHead.type).
   The lowercase suffixes mirror C type names, which suggests the tag is formed
   by token pasting (WEIGHT_TYPE_##T) somewhere in the weight tooling.
   NOTE(review): confirm before renaming or converting these to an enum. */
#define WEIGHT_TYPE_float 0
#define WEIGHT_TYPE_int 1
#define WEIGHT_TYPE_qweight 2
#define WEIGHT_TYPE_int8 3
54*a58d3d2aSXin Li 
/* Fixed-size header record of a serialized weight blob. With a 4-byte int
   the record occupies 4 + 4*4 + 44 = 64 bytes, the same as WEIGHT_BLOCK_SIZE,
   and has no internal padding. Field order and array lengths therefore must
   not change. */
typedef struct {
  char head[4];                          /* 4 leading tag bytes (presumably a magic marker -- confirm) */
  int version, type, size, block_size;   /* format version, WEIGHT_TYPE_* tag, payload size, block size */
  char name[44];                         /* array name, fixed 44-byte field */
} WeightHead;
63*a58d3d2aSXin Li 
64*a58d3d2aSXin Li /* Generic sparse affine transformation. */
65*a58d3d2aSXin Li typedef struct {
66*a58d3d2aSXin Li   const float *bias;
67*a58d3d2aSXin Li   const float *subias;
68*a58d3d2aSXin Li   const opus_int8 *weights;
69*a58d3d2aSXin Li   const float *float_weights;
70*a58d3d2aSXin Li   const int *weights_idx;
71*a58d3d2aSXin Li   const float *diag;
72*a58d3d2aSXin Li   const float *scale;
73*a58d3d2aSXin Li   int nb_inputs;
74*a58d3d2aSXin Li   int nb_outputs;
75*a58d3d2aSXin Li } LinearLayer;
76*a58d3d2aSXin Li 
/* 2-D convolution layer: a bias, float kernel weights, the input/output
   channel counts and the kernel extent along time (ktime) and height
   (kheight), per the field names. Initialized by conv2d_init() and consumed
   by compute_conv2d_c(). Unlike LinearLayer it carries no quantized or
   sparse weights. Pointers reference storage owned elsewhere. */
typedef struct {
  const float *bias;          /* bias values (presumably one per output channel) */
  const float *float_weights; /* kernel weights; memory order defined by compute_conv2d_c() */
  int in_channels;            /* number of input channels */
  int out_channels;           /* number of output channels */
  int ktime;                  /* kernel size along the time axis */
  int kheight;                /* kernel size along the height axis */
} Conv2dLayer;
86*a58d3d2aSXin Li 
87*a58d3d2aSXin Li 
88*a58d3d2aSXin Li void compute_generic_dense(const LinearLayer *layer, float *output, const float *input, int activation, int arch);
89*a58d3d2aSXin Li void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in, int arch);
90*a58d3d2aSXin Li void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation, int arch);
91*a58d3d2aSXin Li void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation, int arch);
92*a58d3d2aSXin Li void compute_glu(const LinearLayer *layer, float *output, const float *input, int arch);
93*a58d3d2aSXin Li void compute_gated_activation(const LinearLayer *layer, float *output, const float *input, int activation, int arch);
94*a58d3d2aSXin Li 
95*a58d3d2aSXin Li 
96*a58d3d2aSXin Li int parse_weights(WeightArray **list, const void *data, int len);
97*a58d3d2aSXin Li 
98*a58d3d2aSXin Li 
99*a58d3d2aSXin Li extern const WeightArray lpcnet_arrays[];
100*a58d3d2aSXin Li extern const WeightArray plcmodel_arrays[];
101*a58d3d2aSXin Li extern const WeightArray rdovaeenc_arrays[];
102*a58d3d2aSXin Li extern const WeightArray rdovaedec_arrays[];
103*a58d3d2aSXin Li extern const WeightArray fwgan_arrays[];
104*a58d3d2aSXin Li extern const WeightArray fargan_arrays[];
105*a58d3d2aSXin Li extern const WeightArray pitchdnn_arrays[];
106*a58d3d2aSXin Li extern const WeightArray lossgen_arrays[];
107*a58d3d2aSXin Li 
108*a58d3d2aSXin Li int linear_init(LinearLayer *layer, const WeightArray *arrays,
109*a58d3d2aSXin Li   const char *bias,
110*a58d3d2aSXin Li   const char *subias,
111*a58d3d2aSXin Li   const char *weights,
112*a58d3d2aSXin Li   const char *float_weights,
113*a58d3d2aSXin Li   const char *weights_idx,
114*a58d3d2aSXin Li   const char *diag,
115*a58d3d2aSXin Li   const char *scale,
116*a58d3d2aSXin Li   int nb_inputs,
117*a58d3d2aSXin Li   int nb_outputs);
118*a58d3d2aSXin Li 
119*a58d3d2aSXin Li int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays,
120*a58d3d2aSXin Li   const char *bias,
121*a58d3d2aSXin Li   const char *float_weights,
122*a58d3d2aSXin Li   int in_channels,
123*a58d3d2aSXin Li   int out_channels,
124*a58d3d2aSXin Li   int ktime,
125*a58d3d2aSXin Li   int kheight);
126*a58d3d2aSXin Li 
127*a58d3d2aSXin Li 
128*a58d3d2aSXin Li void compute_linear_c(const LinearLayer *linear, float *out, const float *in);
129*a58d3d2aSXin Li void compute_activation_c(float *output, const float *input, int N, int activation);
130*a58d3d2aSXin Li void compute_conv2d_c(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation);
131*a58d3d2aSXin Li 
132*a58d3d2aSXin Li 
/* Pull in SIMD-specific declarations when the build may target those ISAs.
   These includes must stay ahead of the OVERRIDE_* fallback wrappers later in
   this header. Nothing in this file defines OVERRIDE_*, so the arch headers
   are presumably where they get defined (confirm in arm/dnn_arm.h and
   x86/dnn_x86.h). */
#if defined(OPUS_ARM_MAY_HAVE_DOTPROD) || defined(OPUS_ARM_MAY_HAVE_NEON_INTR)
#include "arm/dnn_arm.h"
#endif

#if defined(OPUS_X86_MAY_HAVE_SSE2)
#include "x86/dnn_x86.h"
#endif
140*a58d3d2aSXin Li 
/* Portable fallbacks for the dispatch wrappers. Each one applies only when
   the matching OVERRIDE_* macro has not been defined (e.g. by an arch header
   included earlier). The fallback evaluates `arch` and discards it via
   (void), which also avoids unused-argument warnings, and then calls the
   portable _c kernel. */
#ifndef OVERRIDE_COMPUTE_LINEAR
#define compute_linear(linear, out, in, arch) ((void)(arch),compute_linear_c(linear, out, in))
#endif

#ifndef OVERRIDE_COMPUTE_ACTIVATION
#define compute_activation(output, input, N, activation, arch) ((void)(arch),compute_activation_c(output, input, N, activation))
#endif

#ifndef OVERRIDE_COMPUTE_CONV2D
#define compute_conv2d(conv, out, mem, in, height, hstride, activation, arch) ((void)(arch),compute_conv2d_c(conv, out, mem, in, height, hstride, activation))
#endif
152*a58d3d2aSXin Li 
/* Build-time hint for x86-64: if neither OPUS_X86_MAY_HAVE_SSE4_1 nor
   OPUS_X86_MAY_HAVE_AVX2 is defined, only the SSE/SSE2 baseline is available.
   The message recommends SSE4.1 (previously SSSE3) because SSE4.1 and AVX2
   are the only ISA levels this condition tests. Pointing users at SSSE3
   suggested a setting that this check does not look at. The _MSC_VER branch
   emits the same hint through #pragma message instead of #warning. */
#if defined(__x86_64__) && !defined(OPUS_X86_MAY_HAVE_SSE4_1) && !defined(OPUS_X86_MAY_HAVE_AVX2)
#if defined(_MSC_VER)
#pragma message ("Only SSE and SSE2 are available. On newer machines, enable SSE4.1/AVX/AVX2 to get better performance")
#else
#warning "Only SSE and SSE2 are available. On newer machines, enable SSE4.1/AVX/AVX2 using -march= to get better performance"
#endif
#endif
160*a58d3d2aSXin Li 
161*a58d3d2aSXin Li 
162*a58d3d2aSXin Li 
163*a58d3d2aSXin Li #endif /* NNET_H_ */
164