xref: /aosp_15_r20/external/libopus/dnn/parse_lpcnet_weights.c (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1 /* Copyright (c) 2023 Amazon */
2 /*
3    Redistribution and use in source and binary forms, with or without
4    modification, are permitted provided that the following conditions
5    are met:
6 
7    - Redistributions of source code must retain the above copyright
8    notice, this list of conditions and the following disclaimer.
9 
10    - Redistributions in binary form must reproduce the above copyright
11    notice, this list of conditions and the following disclaimer in the
12    documentation and/or other materials provided with the distribution.
13 
14    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
15    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
16    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
17    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
18    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
19    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
20    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
21    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
22    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
23    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 */
26 
27 #ifdef HAVE_CONFIG_H
28 #include "config.h"
29 #endif
30 
31 #include <string.h>
32 #include <stdlib.h>
33 #include "nnet.h"
34 #include "os_support.h"
35 
/* Number of weights stored per sparse block: linear_init() requires a
   sparse weight array to hold SPARSE_BLOCK_SIZE entries for every block
   counted in its index. This equals 4 inputs x 8 outputs, matching the
   granularity find_idx_check() validates (presumably the storage layout
   of a block -- confirm against the sparse matrix kernels). */
#define SPARSE_BLOCK_SIZE 32
37 
/* Decode one record of a serialized weight blob.
 *
 * A record is a WEIGHT_BLOCK_SIZE-byte WeightHead immediately followed by
 * block_size payload bytes, of which the first `size` bytes belong to the
 * array. On success `array` is filled with pointers into the blob (nothing
 * is copied), *data is moved past the record and *len shrinks by the same
 * number of bytes.
 *
 * Returns the array size in bytes (never negative), or -1 when the header
 * does not fit in *len, is inconsistent, or has an unterminated name; on
 * failure *data, *len and *array are not modified.
 */
int parse_record(const void **data, int *len, WeightArray *array) {
  unsigned char *base = (unsigned char *)*data;
  WeightHead *hdr = (WeightHead *)*data;
  int consumed;

  /* The || chain is evaluated left to right, so no header field is read
     unless the whole header lies within the buffer. */
  if (*len < WEIGHT_BLOCK_SIZE ||
      hdr->block_size < hdr->size ||
      hdr->block_size > *len - WEIGHT_BLOCK_SIZE ||
      hdr->name[sizeof(hdr->name) - 1] != 0 ||
      hdr->size < 0) {
    return -1;
  }

  consumed = WEIGHT_BLOCK_SIZE + hdr->block_size;

  array->name = hdr->name;
  array->type = hdr->type;
  array->size = hdr->size;
  array->data = (void *)(base + WEIGHT_BLOCK_SIZE);

  *data = (void *)(base + consumed);
  *len -= consumed;
  return hdr->size;
}
54 
/* Parse a serialized weight blob into a list of WeightArray entries.
 *
 * `data` must hold `len` bytes made of consecutive records as accepted by
 * parse_record(). On success *list receives a newly allocated array
 * (release it with opus_free()) whose entries point into `data`, so `data`
 * must outlive *list. The entry after the last parsed one is a terminator
 * with name == NULL (and its other fields zeroed).
 *
 * Returns the number of arrays parsed, or -1 on a malformed or empty
 * record or on allocation failure; in that case *list is set to NULL and
 * no memory is leaked.
 */
int parse_weights(WeightArray **list, const void *data, int len)
{
  int nb_arrays=0;
  int capacity=20;
  *list = opus_alloc(capacity*sizeof(WeightArray));
  if (*list == NULL) return -1;
  while (len > 0) {
    int ret;
    WeightArray array = {NULL, 0, 0, 0};
    ret = parse_record(&data, &len, &array);
    /* parse_record() returns the array size; a zero-size record is
       rejected just like a malformed one. */
    if (ret <= 0) goto fail;
    if (nb_arrays+1 >= capacity) {
      /* Make sure there's room for the ending NULL element too. */
      WeightArray *grown;
      capacity = capacity*3/2;
      grown = opus_realloc(*list, capacity*sizeof(WeightArray));
      /* On failure the old block is still valid: free it instead of
         overwriting our only pointer to it. */
      if (grown == NULL) goto fail;
      *list = grown;
    }
    (*list)[nb_arrays++] = array;
  }
  {
    /* Fully initialize the terminator: lookups that miss return a pointer
       to it, and callers may read its size/data fields. */
    WeightArray terminator = {NULL, 0, 0, 0};
    (*list)[nb_arrays] = terminator;
  }
  return nb_arrays;
fail:
  opus_free(*list);
  *list = NULL;
  return -1;
}
79 }
80 
find_array_entry(const WeightArray * arrays,const char * name)81 static const void *find_array_entry(const WeightArray *arrays, const char *name) {
82   while (arrays->name && strcmp(arrays->name, name) != 0) arrays++;
83   return arrays;
84 }
85 
find_array_check(const WeightArray * arrays,const char * name,int size)86 static const void *find_array_check(const WeightArray *arrays, const char *name, int size) {
87   const WeightArray *a = find_array_entry(arrays, name);
88   if (a->name && a->size == size) return a->data;
89   else return NULL;
90 }
91 
opt_array_check(const WeightArray * arrays,const char * name,int size,int * error)92 static const void *opt_array_check(const WeightArray *arrays, const char *name, int size, int *error) {
93   const WeightArray *a = find_array_entry(arrays, name);
94   *error = (a->name != NULL && a->size != size);
95   if (a->name && a->size == size) return a->data;
96   else return NULL;
97 }
98 
find_idx_check(const WeightArray * arrays,const char * name,int nb_in,int nb_out,int * total_blocks)99 static const void *find_idx_check(const WeightArray *arrays, const char *name, int nb_in, int nb_out, int *total_blocks) {
100   int remain;
101   const int *idx;
102   const WeightArray *a = find_array_entry(arrays, name);
103   *total_blocks = 0;
104   if (a == NULL) return NULL;
105   idx = a->data;
106   remain = a->size/sizeof(int);
107   while (remain > 0) {
108     int nb_blocks;
109     int i;
110     nb_blocks = *idx++;
111     if (remain < nb_blocks+1) return NULL;
112     for (i=0;i<nb_blocks;i++) {
113       int pos = *idx++;
114       if (pos+3 >= nb_in || (pos&0x3)) return NULL;
115     }
116     nb_out -= 8;
117     remain -= nb_blocks+1;
118     *total_blocks += nb_blocks;
119   }
120   if (nb_out != 0) return NULL;
121   return a->data;
122 }
123 
/* Bind a LinearLayer to named arrays from a weight list.
 *
 * Every non-NULL name must refer to an array of the size implied by
 * nb_inputs/nb_outputs; a NULL name leaves the matching field NULL.
 * Special cases:
 *  - float_weights is optional: a missing array leaves the field NULL,
 *    but an existing array of the wrong size is an error.
 *  - with weights_idx, weights/float_weights are sized as
 *    SPARSE_BLOCK_SIZE entries per block counted in the index instead of
 *    nb_inputs*nb_outputs.
 *  - when weights is given, a scale array of nb_outputs entries is
 *    required, so `scale` must be non-NULL.
 *
 * The layer's pointers reference the list's data; nothing is copied.
 * Returns 0 on success, 1 if a required array is missing or mis-sized.
 */
int linear_init(LinearLayer *layer, const WeightArray *arrays,
  const char *bias,
  const char *subias,
  const char *weights,
  const char *float_weights,
  const char *weights_idx,
  const char *diag,
  const char *scale,
  int nb_inputs,
  int nb_outputs)
{
  int err;
  layer->bias = NULL;
  layer->subias = NULL;
  layer->weights = NULL;
  layer->float_weights = NULL;
  layer->weights_idx = NULL;
  layer->diag = NULL;
  layer->scale = NULL;
  if (bias != NULL) {
    if ((layer->bias = find_array_check(arrays, bias, nb_outputs*sizeof(layer->bias[0]))) == NULL) return 1;
  }
  if (subias != NULL) {
    if ((layer->subias = find_array_check(arrays, subias, nb_outputs*sizeof(layer->subias[0]))) == NULL) return 1;
  }
  if (weights_idx != NULL) {
    int total_blocks;
    if ((layer->weights_idx = find_idx_check(arrays, weights_idx, nb_inputs, nb_outputs, &total_blocks)) == NULL) return 1;
    if (weights != NULL) {
      if ((layer->weights = find_array_check(arrays, weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->weights[0]))) == NULL) return 1;
    }
    if (float_weights != NULL) {
      layer->float_weights = opt_array_check(arrays, float_weights, SPARSE_BLOCK_SIZE*total_blocks*sizeof(layer->float_weights[0]), &err);
      if (err) return 1;
    }
  } else {
    if (weights != NULL) {
      if ((layer->weights = find_array_check(arrays, weights, nb_inputs*nb_outputs*sizeof(layer->weights[0]))) == NULL) return 1;
    }
    if (float_weights != NULL) {
      layer->float_weights = opt_array_check(arrays, float_weights, nb_inputs*nb_outputs*sizeof(layer->float_weights[0]), &err);
      if (err) return 1;
    }
  }
  if (diag != NULL) {
    if ((layer->diag = find_array_check(arrays, diag, nb_outputs*sizeof(layer->diag[0]))) == NULL) return 1;
  }
  if (weights != NULL) {
    /* weights always require a scale array. A NULL scale name used to be
       passed down to strcmp() (undefined behaviour); fail cleanly instead. */
    if (scale == NULL) return 1;
    if ((layer->scale = find_array_check(arrays, scale, nb_outputs*sizeof(layer->scale[0]))) == NULL) return 1;
  }
  layer->nb_inputs = nb_inputs;
  layer->nb_outputs = nb_outputs;
  return 0;
}
178 
/* Bind a Conv2dLayer to named arrays from a weight list.
 *
 * bias, when named, must exist with out_channels entries. float_weights,
 * when named, is optional, but if present it must hold exactly
 * in_channels*out_channels*ktime*kheight entries. The layer's pointers
 * reference the list's data. The dimensions are stored in the layer only
 * on success.
 *
 * Returns 0 on success, 1 on a missing bias or a mis-sized array.
 */
int conv2d_init(Conv2dLayer *layer, const WeightArray *arrays,
  const char *bias,
  const char *float_weights,
  int in_channels,
  int out_channels,
  int ktime,
  int kheight)
{
  int wrong_size;
  layer->float_weights = NULL;
  layer->bias = NULL;
  if (bias != NULL) {
    layer->bias = find_array_check(arrays, bias, out_channels*sizeof(layer->bias[0]));
    if (layer->bias == NULL) {
      return 1;
    }
  }
  if (float_weights != NULL) {
    layer->float_weights = opt_array_check(arrays, float_weights,
        in_channels*out_channels*ktime*kheight*sizeof(layer->float_weights[0]),
        &wrong_size);
    if (wrong_size) {
      return 1;
    }
  }
  layer->kheight = kheight;
  layer->ktime = ktime;
  layer->out_channels = out_channels;
  layer->in_channels = in_channels;
  return 0;
}
203 
204 
205 
#if 0
/* Debug harness (compiled out): map weights_blob.bin from the current
   directory, parse it with parse_weights() and print every array found. */
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <sys/stat.h>
#include <stdio.h>

int main(void)
{
  int fd;
  void *data;
  int len;
  int nb_arrays;
  int i;
  WeightArray *list;
  struct stat st;
  const char *filename = "weights_blob.bin";
  if (stat(filename, &st) != 0) {
    perror(filename);
    return 1;
  }
  len = st.st_size;
  fd = open(filename, O_RDONLY);
  if (fd < 0) {
    perror(filename);
    return 1;
  }
  data = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
  if (data == MAP_FAILED) {
    perror("mmap");
    close(fd);
    return 1;
  }
  printf("size is %d\n", len);
  nb_arrays = parse_weights(&list, data, len);
  if (nb_arrays < 0) {
    /* list is NULL on failure; do not index it. */
    fprintf(stderr, "failed to parse %s\n", filename);
  } else {
    for (i=0;i<nb_arrays;i++) {
      printf("found %s: size %d\n", list[i].name, list[i].size);
    }
    /* Terminator entry: expected to print a null pointer. */
    printf("%p\n", (const void *)list[i].name);
    opus_free(list);
  }
  munmap(data, len);
  close(fd);
  return nb_arrays < 0;
}
#endif
239