xref: /aosp_15_r20/external/libopus/dnn/parse_lpcnet_weights.c (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li /* Copyright (c) 2023 Amazon */
2*a58d3d2aSXin Li /*
3*a58d3d2aSXin Li    Redistribution and use in source and binary forms, with or without
4*a58d3d2aSXin Li    modification, are permitted provided that the following conditions
5*a58d3d2aSXin Li    are met:
6*a58d3d2aSXin Li 
7*a58d3d2aSXin Li    - Redistributions of source code must retain the above copyright
8*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer.
9*a58d3d2aSXin Li 
10*a58d3d2aSXin Li    - Redistributions in binary form must reproduce the above copyright
11*a58d3d2aSXin Li    notice, this list of conditions and the following disclaimer in the
12*a58d3d2aSXin Li    documentation and/or other materials provided with the distribution.
13*a58d3d2aSXin Li 
14*a58d3d2aSXin Li    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15*a58d3d2aSXin Li    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16*a58d3d2aSXin Li    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17*a58d3d2aSXin Li    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
18*a58d3d2aSXin Li    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19*a58d3d2aSXin Li    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20*a58d3d2aSXin Li    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
21*a58d3d2aSXin Li    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
22*a58d3d2aSXin Li    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
23*a58d3d2aSXin Li    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24*a58d3d2aSXin Li    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25*a58d3d2aSXin Li */
26*a58d3d2aSXin Li 
27*a58d3d2aSXin Li #ifdef HAVE_CONFIG_H
28*a58d3d2aSXin Li #include "config.h"
29*a58d3d2aSXin Li #endif
30*a58d3d2aSXin Li 
31*a58d3d2aSXin Li #include <string.h>
32*a58d3d2aSXin Li #include <stdlib.h>
33*a58d3d2aSXin Li #include "nnet.h"
34*a58d3d2aSXin Li #include "os_support.h"
35*a58d3d2aSXin Li 
36*a58d3d2aSXin Li #define SPARSE_BLOCK_SIZE 32
37*a58d3d2aSXin Li 
/* Parses one weight record from the blob at *data, which holds *len bytes.
 *
 * A record is a WEIGHT_BLOCK_SIZE-byte WeightHead followed by block_size
 * bytes of payload, the first `size` bytes of which are the array data.
 * On success, fills *array (name and data point into the blob; nothing is
 * copied), advances *data / decrements *len past the whole record, and
 * returns the array size in bytes (which may be 0).
 * Returns -1 if the record is truncated or its header is inconsistent; in
 * that case *data, *len and *array are left untouched. */
int parse_record(const void **data, int *len, WeightArray *array) {
  const unsigned char *record = (const unsigned char *)*data;
  const WeightHead *h = (const WeightHead *)record;
  if (*len < WEIGHT_BLOCK_SIZE) return -1;
  if (h->size < 0) return -1;
  if (h->block_size < h->size) return -1;
  /* The payload must fit in what remains after the header. */
  if (h->block_size > *len-WEIGHT_BLOCK_SIZE) return -1;
  /* The name must be NUL-terminated inside its fixed-size field. */
  if (h->name[sizeof(h->name)-1] != 0) return -1;
  array->name = h->name;
  array->type = h->type;
  array->size = h->size;
  array->data = record + WEIGHT_BLOCK_SIZE;

  *data = record + WEIGHT_BLOCK_SIZE + h->block_size;
  *len -= WEIGHT_BLOCK_SIZE + h->block_size;
  return array->size;
}
54*a58d3d2aSXin Li 
/* Splits a weight blob into an array of WeightArray descriptors.
 *
 * On success, *list receives a heap array (release with opus_free) holding
 * one entry per record followed by a terminating entry whose fields are all
 * zero/NULL, and the number of records is returned. The entries point into
 * `data`, which must outlive *list.
 * Returns -1 and sets *list to NULL on a malformed or empty (size <= 0)
 * record, or on allocation failure. */
int parse_weights(WeightArray **list, const void *data, int len)
{
  int nb_arrays=0;
  int capacity=20;
  WeightArray terminator = {NULL, 0, 0, 0};
  WeightArray *arrays = opus_alloc(capacity*sizeof(WeightArray));
  if (arrays == NULL) goto fail;
  while (len > 0) {
    WeightArray array = {NULL, 0, 0, 0};
    if (parse_record(&data, &len, &array) <= 0) goto fail;
    if (nb_arrays+1 >= capacity) {
      /* Make sure there's room for the ending NULL element too. */
      WeightArray *grown;
      capacity = capacity*3/2;
      /* Keep the old buffer until realloc succeeds so it can be freed. */
      grown = opus_realloc(arrays, capacity*sizeof(WeightArray));
      if (grown == NULL) goto fail;
      arrays = grown;
    }
    arrays[nb_arrays++] = array;
  }
  /* Fully initialize the terminator: lookups stop on name == NULL, and its
     other fields must not be left as uninitialized heap memory. */
  arrays[nb_arrays] = terminator;
  *list = arrays;
  return nb_arrays;
fail:
  opus_free(arrays);
  *list = NULL;
  return -1;
}
80*a58d3d2aSXin Li 
find_array_entry(const WeightArray * arrays,const char * name)81*a58d3d2aSXin Li static const void *find_array_entry(const WeightArray *arrays, const char *name) {
82*a58d3d2aSXin Li   while (arrays->name && strcmp(arrays->name, name) != 0) arrays++;
83*a58d3d2aSXin Li   return arrays;
84*a58d3d2aSXin Li }
85*a58d3d2aSXin Li 
find_array_check(const WeightArray * arrays,const char * name,int size)86*a58d3d2aSXin Li static const void *find_array_check(const WeightArray *arrays, const char *name, int size) {
87*a58d3d2aSXin Li   const WeightArray *a = find_array_entry(arrays, name);
88*a58d3d2aSXin Li   if (a->name && a->size == size) return a->data;
89*a58d3d2aSXin Li   else return NULL;
90*a58d3d2aSXin Li }
91*a58d3d2aSXin Li 
opt_array_check(const WeightArray * arrays,const char * name,int size,int * error)92*a58d3d2aSXin Li static const void *opt_array_check(const WeightArray *arrays, const char *name, int size, int *error) {
93*a58d3d2aSXin Li   const WeightArray *a = find_array_entry(arrays, name);
94*a58d3d2aSXin Li   *error = (a->name != NULL && a->size != size);
95*a58d3d2aSXin Li   if (a->name && a->size == size) return a->data;
96*a58d3d2aSXin Li   else return NULL;
97*a58d3d2aSXin Li }
98*a58d3d2aSXin Li 
find_idx_check(const WeightArray * arrays,const char * name,int nb_in,int nb_out,int * total_blocks)99*a58d3d2aSXin Li static const void *find_idx_check(const WeightArray *arrays, const char *name, int nb_in, int nb_out, int *total_blocks) {
100*a58d3d2aSXin Li   int remain;
101*a58d3d2aSXin Li   const int *idx;
102*a58d3d2aSXin Li   const WeightArray *a = find_array_entry(arrays, name);
103*a58d3d2aSXin Li   *total_blocks = 0;
104*a58d3d2aSXin Li   if (a == NULL) return NULL;
105*a58d3d2aSXin Li   idx = a->data;
106*a58d3d2aSXin Li   remain = a->size/sizeof(int);
107*a58d3d2aSXin Li   while (remain > 0) {
108*a58d3d2aSXin Li     int nb_blocks;
109*a58d3d2aSXin Li     int i;
110*a58d3d2aSXin Li     nb_blocks = *idx++;
111*a58d3d2aSXin Li     if (remain < nb_blocks+1) return NULL;
112*a58d3d2aSXin Li     for (i=0;i<nb_blocks;i++) {
113*a58d3d2aSXin Li       int pos = *idx++;
114*a58d3d2aSXin Li       if (pos+3 >= nb_in || (pos&0x3)) return NULL;
115*a58d3d2aSXin Li     }
116*a58d3d2aSXin Li     nb_out -= 8;
117*a58d3d2aSXin Li     remain -= nb_blocks+1;
118*a58d3d2aSXin Li     *total_blocks += nb_blocks;
119*a58d3d2aSXin Li   }
120*a58d3d2aSXin Li   if (nb_out != 0) return NULL;
121*a58d3d2aSXin Li   return a->data;
122*a58d3d2aSXin Li }
123*a58d3d2aSXin Li 
/* Binds the arrays of a LinearLayer by name from a parsed weight list.
 *
 * Every name argument may be NULL, which leaves the corresponding layer
 * field NULL. When given:
 * - bias, subias, diag: required, nb_outputs entries each;
 * - weights_idx: required; it selects the sparse layout, where the weight
 *   arrays hold SPARSE_BLOCK_SIZE*total_blocks entries instead of
 *   nb_inputs*nb_outputs;
 * - weights: required, and then `scale` (nb_outputs entries) is required too;
 * - float_weights: optional. An absent array is accepted, but a
 *   size mismatch is an error.
 * Returns 0 on success or 1 on failure (the layer may be partially filled). */
int linear_init(LinearLayer *layer, const WeightArray *arrays,
  const char *bias,
  const char *subias,
  const char *weights,
  const char *float_weights,
  const char *weights_idx,
  const char *diag,
  const char *scale,
  int nb_inputs,
  int nb_outputs)
{
  int err;
  int nb_weights;
  layer->bias = NULL;
  layer->subias = NULL;
  layer->weights = NULL;
  layer->float_weights = NULL;
  layer->weights_idx = NULL;
  layer->diag = NULL;
  layer->scale = NULL;
  if (bias != NULL) {
    if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1;
  }
  if (subias != NULL) {
    if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1;
  }
  if (weights_idx != NULL) {
    int total_blocks;
    if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_inputs, nb_outputs, &total_blocks)) == NULL) return 1;
    nb_weights = SPARSE_BLOCK_SIZE*total_blocks;
  } else {
    nb_weights = nb_inputs*nb_outputs;
  }
  if (weights != NULL) {
    if ((layer->weights = find_array_check(arrays, weights, nb_weights*sizeof(layer->weights[0]))) == NULL) return 1;
  }
  if (float_weights != NULL) {
    layer->float_weights = opt_array_check(arrays, float_weights, nb_weights*sizeof(layer->float_weights[0]), &err);
    if (err) return 1;
  }
  if (diag != NULL) {
    if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1;
  }
  if (weights != NULL) {
    /* `scale` is required whenever `weights` is given. Report a missing
       name as an error instead of passing NULL to strcmp() in the lookup. */
    if (scale == NULL) return 1;
    if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1;
  }
  layer->nb_inputs = nb_inputs;
  layer->nb_outputs = nb_outputs;
  return 0;
}
178*a58d3d2aSXin Li 
/* Binds the arrays of a Conv2dLayer by name from a parsed weight list.
 *
 * `bias` (out_channels entries) is required when named. `float_weights`
 * (in_channels*out_channels*ktime*kheight entries) is optional: an absent
 * array is accepted, but a size mismatch is an error. Either name may be
 * NULL to leave the field NULL.
 * Returns 0 on success. Returns 1 on failure, in which case the dimension
 * fields are left unset. */
int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays,
  const char *bias,
  const char *float_weights,
  int in_channels,
  int out_channels,
  int ktime,
  int kheight)
{
  int size_mismatch;
  layer->bias = NULL;
  layer->float_weights = NULL;
  if (bias != NULL) {
    layer->bias = find_array_check(arrays, bias, out_channels*sizeof(layer->bias[0]));
    if (layer->bias == NULL) return 1;
  }
  if (float_weights != NULL) {
    int nb_weights = in_channels*out_channels*ktime*kheight;
    layer->float_weights = opt_array_check(arrays, float_weights, nb_weights*sizeof(layer->float_weights[0]), &size_mismatch);
    if (size_mismatch) return 1;
  }
  layer->in_channels = in_channels;
  layer->out_channels = out_channels;
  layer->ktime = ktime;
  layer->kheight = kheight;
  return 0;
}
203*a58d3d2aSXin Li 
204*a58d3d2aSXin Li 
205*a58d3d2aSXin Li 
#if 0
#include <fcntl.h>
#include <limits.h>
#include <sys/mman.h>
#include <unistd.h>
#include <sys/stat.h>
#include <stdio.h>

/* Debug harness (disabled): maps "weights_blob.bin" and lists its arrays. */
int main(void)
{
  int fd;
  void *data;
  int len;
  int nb_arrays;
  int i;
  WeightArray *list;
  struct stat st;
  const char *filename = "weights_blob.bin";
  fd = open(filename, O_RDONLY);
  if (fd < 0) {
    perror(filename);
    return 1;
  }
  /* fstat on the opened descriptor, so size and contents refer to the same file. */
  if (fstat(fd, &st) != 0 || st.st_size <= 0 || st.st_size > INT_MAX) {
    fprintf(stderr, "%s: unusable file size\n", filename);
    close(fd);
    return 1;
  }
  len = (int)st.st_size;
  data = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
  if (data == MAP_FAILED) {
    perror("mmap");
    close(fd);
    return 1;
  }
  printf("size is %d\n", len);
  nb_arrays = parse_weights(&list, data, len);
  if (nb_arrays < 0) {
    fprintf(stderr, "%s: failed to parse weights\n", filename);
  } else {
    for (i=0;i<nb_arrays;i++) {
      printf("found %s: size %d\n", list[i].name, list[i].size);
    }
    /* Terminator name; expected to print a null pointer. */
    printf("%p\n", (void *)list[i].name);
    opus_free(list);
  }
  munmap(data, len);
  close(fd);
  return nb_arrays < 0 ? 1 : 0;
}
#endif
239