1*a58d3d2aSXin Li /* Copyright (c) 2018 Mozilla 2*a58d3d2aSXin Li Copyright (c) 2017 Jean-Marc Valin */ 3*a58d3d2aSXin Li /* 4*a58d3d2aSXin Li Redistribution and use in source and binary forms, with or without 5*a58d3d2aSXin Li modification, are permitted provided that the following conditions 6*a58d3d2aSXin Li are met: 7*a58d3d2aSXin Li 8*a58d3d2aSXin Li - Redistributions of source code must retain the above copyright 9*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer. 10*a58d3d2aSXin Li 11*a58d3d2aSXin Li - Redistributions in binary form must reproduce the above copyright 12*a58d3d2aSXin Li notice, this list of conditions and the following disclaimer in the 13*a58d3d2aSXin Li documentation and/or other materials provided with the distribution. 14*a58d3d2aSXin Li 15*a58d3d2aSXin Li THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 16*a58d3d2aSXin Li ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 17*a58d3d2aSXin Li LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 18*a58d3d2aSXin Li A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE FOUNDATION OR 19*a58d3d2aSXin Li CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20*a58d3d2aSXin Li EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21*a58d3d2aSXin Li PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22*a58d3d2aSXin Li PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 23*a58d3d2aSXin Li LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 24*a58d3d2aSXin Li NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 25*a58d3d2aSXin Li SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26*a58d3d2aSXin Li */ 27*a58d3d2aSXin Li 28*a58d3d2aSXin Li #ifndef NNET_H_ 29*a58d3d2aSXin Li #define NNET_H_ 30*a58d3d2aSXin Li 31*a58d3d2aSXin Li #include <stddef.h> 32*a58d3d2aSXin Li #include "opus_types.h" 33*a58d3d2aSXin Li 34*a58d3d2aSXin Li #define ACTIVATION_LINEAR 0 35*a58d3d2aSXin Li #define ACTIVATION_SIGMOID 1 36*a58d3d2aSXin Li #define ACTIVATION_TANH 2 37*a58d3d2aSXin Li #define ACTIVATION_RELU 3 38*a58d3d2aSXin Li #define ACTIVATION_SOFTMAX 4 39*a58d3d2aSXin Li #define ACTIVATION_SWISH 5 40*a58d3d2aSXin Li 41*a58d3d2aSXin Li #define WEIGHT_BLOB_VERSION 0 42*a58d3d2aSXin Li #define WEIGHT_BLOCK_SIZE 64 43*a58d3d2aSXin Li typedef struct { 44*a58d3d2aSXin Li const char *name; 45*a58d3d2aSXin Li int type; 46*a58d3d2aSXin Li int size; 47*a58d3d2aSXin Li const void *data; 48*a58d3d2aSXin Li } WeightArray; 49*a58d3d2aSXin Li 50*a58d3d2aSXin Li #define WEIGHT_TYPE_float 0 51*a58d3d2aSXin Li #define WEIGHT_TYPE_int 1 52*a58d3d2aSXin Li #define WEIGHT_TYPE_qweight 2 53*a58d3d2aSXin Li #define WEIGHT_TYPE_int8 3 54*a58d3d2aSXin Li 55*a58d3d2aSXin Li typedef struct { 56*a58d3d2aSXin Li char head[4]; 57*a58d3d2aSXin Li int version; 58*a58d3d2aSXin Li int type; 59*a58d3d2aSXin Li int size; 60*a58d3d2aSXin Li int block_size; 61*a58d3d2aSXin Li char name[44]; 62*a58d3d2aSXin Li } WeightHead; 63*a58d3d2aSXin Li 64*a58d3d2aSXin Li /* Generic sparse affine transformation. */ 65*a58d3d2aSXin Li typedef struct { 66*a58d3d2aSXin Li const float *bias; 67*a58d3d2aSXin Li const float *subias; 68*a58d3d2aSXin Li const opus_int8 *weights; 69*a58d3d2aSXin Li const float *float_weights; 70*a58d3d2aSXin Li const int *weights_idx; 71*a58d3d2aSXin Li const float *diag; 72*a58d3d2aSXin Li const float *scale; 73*a58d3d2aSXin Li int nb_inputs; 74*a58d3d2aSXin Li int nb_outputs; 75*a58d3d2aSXin Li } LinearLayer; 76*a58d3d2aSXin Li 77*a58d3d2aSXin Li /* Generic sparse affine transformation. */ 78*a58d3d2aSXin Li typedef struct { 79*a58d3d2aSXin Li const float *bias; 80*a58d3d2aSXin Li const float *float_weights; 81*a58d3d2aSXin Li int in_channels; 82*a58d3d2aSXin Li int out_channels; 83*a58d3d2aSXin Li int ktime; 84*a58d3d2aSXin Li int kheight; 85*a58d3d2aSXin Li } Conv2dLayer; 86*a58d3d2aSXin Li 87*a58d3d2aSXin Li 88*a58d3d2aSXin Li void compute_generic_dense(const LinearLayer *layer, float *output, const float *input, int activation, int arch); 89*a58d3d2aSXin Li void compute_generic_gru(const LinearLayer *input_weights, const LinearLayer *recurrent_weights, float *state, const float *in, int arch); 90*a58d3d2aSXin Li void compute_generic_conv1d(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int activation, int arch); 91*a58d3d2aSXin Li void compute_generic_conv1d_dilation(const LinearLayer *layer, float *output, float *mem, const float *input, int input_size, int dilation, int activation, int arch); 92*a58d3d2aSXin Li void compute_glu(const LinearLayer *layer, float *output, const float *input, int arch); 93*a58d3d2aSXin Li void compute_gated_activation(const LinearLayer *layer, float *output, const float *input, int activation, int arch); 94*a58d3d2aSXin Li 95*a58d3d2aSXin Li 96*a58d3d2aSXin Li int parse_weights(WeightArray **list, const void *data, int len); 97*a58d3d2aSXin Li 98*a58d3d2aSXin Li 99*a58d3d2aSXin Li extern const WeightArray lpcnet_arrays[]; 100*a58d3d2aSXin Li extern const WeightArray plcmodel_arrays[]; 101*a58d3d2aSXin Li extern const WeightArray rdovaeenc_arrays[]; 102*a58d3d2aSXin Li extern const WeightArray rdovaedec_arrays[]; 103*a58d3d2aSXin Li extern const WeightArray fwgan_arrays[]; 104*a58d3d2aSXin Li extern const WeightArray fargan_arrays[]; 105*a58d3d2aSXin Li extern const WeightArray pitchdnn_arrays[]; 106*a58d3d2aSXin Li extern const WeightArray lossgen_arrays[]; 107*a58d3d2aSXin Li 108*a58d3d2aSXin Li int linear_init(LinearLayer *layer, const WeightArray *arrays, 109*a58d3d2aSXin Li const char *bias, 110*a58d3d2aSXin Li const char *subias, 111*a58d3d2aSXin Li const char *weights, 112*a58d3d2aSXin Li const char *float_weights, 113*a58d3d2aSXin Li const char *weights_idx, 114*a58d3d2aSXin Li const char *diag, 115*a58d3d2aSXin Li const char *scale, 116*a58d3d2aSXin Li int nb_inputs, 117*a58d3d2aSXin Li int nb_outputs); 118*a58d3d2aSXin Li 119*a58d3d2aSXin Li int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays, 120*a58d3d2aSXin Li const char *bias, 121*a58d3d2aSXin Li const char *float_weights, 122*a58d3d2aSXin Li int in_channels, 123*a58d3d2aSXin Li int out_channels, 124*a58d3d2aSXin Li int ktime, 125*a58d3d2aSXin Li int kheight); 126*a58d3d2aSXin Li 127*a58d3d2aSXin Li 128*a58d3d2aSXin Li void compute_linear_c(const LinearLayer *linear, float *out, const float *in); 129*a58d3d2aSXin Li void compute_activation_c(float *output, const float *input, int N, int activation); 130*a58d3d2aSXin Li void compute_conv2d_c(const Conv2dLayer *conv, float *out, float *mem, const float *in, int height, int hstride, int activation); 131*a58d3d2aSXin Li 132*a58d3d2aSXin Li 133*a58d3d2aSXin Li #if defined(OPUS_ARM_MAY_HAVE_DOTPROD) || defined(OPUS_ARM_MAY_HAVE_NEON_INTR) 134*a58d3d2aSXin Li #include "arm/dnn_arm.h" 135*a58d3d2aSXin Li #endif 136*a58d3d2aSXin Li 137*a58d3d2aSXin Li #if defined(OPUS_X86_MAY_HAVE_SSE2) 138*a58d3d2aSXin Li #include "x86/dnn_x86.h" 139*a58d3d2aSXin Li #endif 140*a58d3d2aSXin Li 141*a58d3d2aSXin Li #ifndef OVERRIDE_COMPUTE_LINEAR 142*a58d3d2aSXin Li #define compute_linear(linear, out, in, arch) ((void)(arch),compute_linear_c(linear, out, in)) 143*a58d3d2aSXin Li #endif 144*a58d3d2aSXin Li 145*a58d3d2aSXin Li #ifndef OVERRIDE_COMPUTE_ACTIVATION 146*a58d3d2aSXin Li #define compute_activation(output, input, N, activation, arch) ((void)(arch),compute_activation_c(output, input, N, activation)) 147*a58d3d2aSXin Li #endif 148*a58d3d2aSXin Li 149*a58d3d2aSXin Li #ifndef OVERRIDE_COMPUTE_CONV2D 150*a58d3d2aSXin Li #define compute_conv2d(conv, out, mem, in, height, hstride, activation, arch) ((void)(arch),compute_conv2d_c(conv, out, mem, in, height, hstride, activation)) 151*a58d3d2aSXin Li #endif 152*a58d3d2aSXin Li 153*a58d3d2aSXin Li #if defined(__x86_64__) && !defined(OPUS_X86_MAY_HAVE_SSE4_1) && !defined(OPUS_X86_MAY_HAVE_AVX2) 154*a58d3d2aSXin Li #if defined(_MSC_VER) 155*a58d3d2aSXin Li #pragma message ("Only SSE and SSE2 are available. On newer machines, enable SSSE3/AVX/AVX2 to get better performance") 156*a58d3d2aSXin Li #else 157*a58d3d2aSXin Li #warning "Only SSE and SSE2 are available. On newer machines, enable SSSE3/AVX/AVX2 using -march= to get better performance" 158*a58d3d2aSXin Li #endif 159*a58d3d2aSXin Li #endif 160*a58d3d2aSXin Li 161*a58d3d2aSXin Li 162*a58d3d2aSXin Li 163*a58d3d2aSXin Li #endif /* NNET_H_ */ 164